diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 22d8d1cb1ddcfe48006d31b6f020210988c332d6..fc36e37c53f5e532caddb1d4dfcb77d9483484d7 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -210,12 +210,22 @@ object SparkEnv extends Logging {
       "MapOutputTracker",
       new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
 
+    // Let the user specify short names for shuffle managers
+    val shortShuffleMgrNames = Map(
+      "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
+      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+    val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
+    val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
+    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
+
+    val shuffleMemoryManager = new ShuffleMemoryManager(conf)
+
     val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
       "BlockManagerMaster",
       new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf)
 
     val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
-      serializer, conf, securityManager, mapOutputTracker)
+      serializer, conf, securityManager, mapOutputTracker, shuffleManager)
 
     val connectionManager = blockManager.connectionManager
 
@@ -250,16 +260,6 @@ object SparkEnv extends Logging {
       "."
     }
 
-    // Let the user specify short names for shuffle managers
-    val shortShuffleMgrNames = Map(
-      "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
-      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
-    val shuffleMgrName = conf.get("spark.shuffle.manager", "hash")
-    val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
-    val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
-
-    val shuffleMemoryManager = new ShuffleMemoryManager(conf)
-
     // Warn about deprecated spark.cache.class property
     if (conf.contains("spark.cache.class")) {
       logWarning("The spark.cache.class property is no longer being used! Specify storage " +
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e8bbd298c631a02a51bbecd8d8933c8f3e64cfb4..e4c3d58905e7fca22673e779c64ee3fea50ba7da 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,6 +33,7 @@ import org.apache.spark.executor._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.util._
 
 private[spark] sealed trait BlockValues
@@ -57,11 +58,12 @@ private[spark] class BlockManager(
     maxMemory: Long,
     val conf: SparkConf,
     securityManager: SecurityManager,
-    mapOutputTracker: MapOutputTracker)
+    mapOutputTracker: MapOutputTracker,
+    shuffleManager: ShuffleManager)
   extends Logging {
 
   private val port = conf.getInt("spark.blockManager.port", 0)
-  val shuffleBlockManager = new ShuffleBlockManager(this)
+  val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager)
   val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
     conf.get("spark.local.dir", System.getProperty("java.io.tmpdir")))
   val connectionManager =
@@ -142,9 +144,10 @@ private[spark] class BlockManager(
       serializer: Serializer,
       conf: SparkConf,
       securityManager: SecurityManager,
-      mapOutputTracker: MapOutputTracker) = {
+      mapOutputTracker: MapOutputTracker,
+      shuffleManager: ShuffleManager) = {
     this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
-      conf, securityManager, mapOutputTracker)
+      conf, securityManager, mapOutputTracker, shuffleManager)
   }
 
   /**
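`BlockManager` now receives the `ShuffleManager` as an explicit constructor argument instead of reading `spark.shuffle.manager` itself, so the auxiliary constructor (the one that derives `maxMemory` from the conf) must forward it too. A simplified model of that two-constructor pattern; the types and the `maxMemFraction` knob are stand-ins, not the real signatures:

```scala
// Stand-in types; the real signatures carry many more parameters.
trait ShuffleManager
class SparkConf { def maxMemFraction: Double = 0.6 }  // illustrative knob

class BlockManager(
    execId: String,
    maxMemory: Long,
    conf: SparkConf,
    shuffleManager: ShuffleManager) {

  // Auxiliary constructor, mirroring the diff: derive maxMemory from the
  // conf and delegate -- any new parameter must appear in both signatures.
  def this(execId: String, conf: SparkConf, shuffleManager: ShuffleManager) =
    this(execId, BlockManager.getMaxMemory(conf), conf, shuffleManager)
}

object BlockManager {
  // Stand-in for the real heuristic behind BlockManager.getMaxMemory.
  def getMaxMemory(conf: SparkConf): Long =
    (Runtime.getRuntime.maxMemory * conf.maxMemFraction).toLong
}
```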
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 3565719b545451813e4f59906a0702f373abb44b..b8f5d3a5b02aa98a00ec2e384fbeb9ede2b3f0f8 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConversions._
 
 import org.apache.spark.Logging
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
@@ -62,7 +63,8 @@ private[spark] trait ShuffleWriterGroup {
  */
 // TODO: Factor this into a separate class for each ShuffleManager implementation
 private[spark]
-class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
+class ShuffleBlockManager(blockManager: BlockManager,
+                          shuffleManager: ShuffleManager) extends Logging {
   def conf = blockManager.conf
 
   // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
@@ -71,8 +73,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
     conf.getBoolean("spark.shuffle.consolidateFiles", false)
 
   // Are we using sort-based shuffle?
-  val sortBasedShuffle =
-    conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName
+  val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager]
 
   private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
 
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 75c2e09a6bbb84d0de3de7dfa97fb26e9b464181..aa83ea90ee9ee655e8cb724502063ec80e5e3be0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
 import java.util.concurrent.ArrayBlockingQueue
 
 import akka.actor._
 import util.Random
 
 import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -101,7 +102,7 @@ private[spark] object ThreadingTest {
       conf)
     val blockManager = new BlockManager(
       "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf,
-      new SecurityManager(conf), new MapOutputTrackerMaster(conf))
+      new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf))
     val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
     val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue))
     producers.foreach(_.start)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 94bb2c445d2e92538effbd4861f4de2e62ca82b6..20bac66105a69eff351c31f5c334413f7f23732c 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.TimeUnit
 import akka.actor._
 import akka.pattern.ask
 import akka.util.Timeout
+import org.apache.spark.shuffle.hash.HashShuffleManager
 
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.Matchers.any
@@ -61,6 +62,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
   conf.set("spark.authenticate", "false")
   val securityMgr = new SecurityManager(conf)
   val mapOutputTracker = new MapOutputTrackerMaster(conf)
+  val shuffleManager = new HashShuffleManager(conf)
 
   // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test
   conf.set("spark.kryoserializer.buffer.mb", "1")
@@ -71,8 +73,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
   def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
 
   private def makeBlockManager(maxMem: Long, name: String = "<driver>"): BlockManager = {
-    new BlockManager(
-      name, actorSystem, master, serializer, maxMem, conf, securityMgr, mapOutputTracker)
+    new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr,
+      mapOutputTracker, shuffleManager)
   }
 
   before {
@@ -791,7 +793,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
   test("block store put failure") {
     // Use Java serializer so we can create an unserializable error.
     store = new BlockManager("<driver>", actorSystem, master, new JavaSerializer(conf), 1200, conf,
-      securityMgr, mapOutputTracker)
+      securityMgr, mapOutputTracker, shuffleManager)
 
     // The put should fail since a1 is not serializable.
     class UnserializableClass
@@ -1007,7 +1009,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
 
   test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") {
     store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
-      securityMgr, mapOutputTracker)
+      securityMgr, mapOutputTracker, shuffleManager)
 
     val worker = spy(new BlockManagerWorker(store))
     val connManagerId = mock(classOf[ConnectionManagerId])
@@ -1054,7 +1056,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
 
   test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") {
     store = new BlockManager("<driver>", actorSystem, master, serializer, 1200, conf,
-      securityMgr, mapOutputTracker)
+      securityMgr, mapOutputTracker, shuffleManager)
 
     val worker = spy(new BlockManagerWorker(store))
     val connManagerId = mock(classOf[ConnectionManagerId])
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index b8299e2ea187fed891eb9458480c49d9bf40c49f..777579bc570db017b4a658d6aaf64d86c005b2cb 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.storage
 
 import java.io.{File, FileWriter}
 
 import scala.collection.mutable
 import scala.language.reflectiveCalls
+
+import org.apache.spark.shuffle.hash.HashShuffleManager
 
@@ -42,7 +44,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
   // so we coerce consolidation if not already enabled.
   testConf.set("spark.shuffle.consolidateFiles", "true")
 
-  val shuffleBlockManager = new ShuffleBlockManager(null) {
+  private val shuffleManager = new HashShuffleManager(testConf.clone)
+
+  val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) {
     override def conf = testConf.clone
     var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
     override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
@@ -148,7 +152,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
       actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))),
       confCopy)
     val store = new BlockManager("<driver>", actorSystem, master , serializer, confCopy,
-      securityManager, null)
+      securityManager, null, shuffleManager)
 
     try {
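`DiskBlockManagerSuite` keeps its existing trick of building a `ShuffleBlockManager` with a `null` `BlockManager` and overriding `conf` in an anonymous subclass; the only news is that a concrete `HashShuffleManager` is now injected alongside it. A distilled, self-contained version of that test-double pattern (types simplified, config key taken from the suite):

```scala
// Distilled test-double pattern: override the conf accessor so the class
// under test never touches its (null) BlockManager collaborator.
class Conf(settings: Map[String, String]) {
  def get(key: String, default: String): String = settings.getOrElse(key, default)
}

trait ShuffleManager
class HashShuffleManager(conf: Conf) extends ShuffleManager

class ShuffleBlockManager(blockManager: AnyRef, shuffleManager: ShuffleManager) {
  // In the real class this delegates to blockManager.conf.
  def conf: Conf = sys.error("needs a real BlockManager")
}

object TestDoubleDemo extends App {
  val testConf = new Conf(Map("spark.shuffle.consolidateFiles" -> "true"))
  val sbm = new ShuffleBlockManager(null, new HashShuffleManager(testConf)) {
    override def conf: Conf = testConf  // bypass the BlockManager dependency
  }
  println(sbm.conf.get("spark.shuffle.consolidateFiles", "false"))  // true
}
```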