diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
similarity index 83%
rename from core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
rename to core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index b8f5d3a5b02aa98a00ec2e384fbeb9ede2b3f0f8..76e3932a9bb91d5736de1dfefba20ef0b17e9e25 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -15,22 +15,22 @@
  * limitations under the License.
  */
 
-package org.apache.spark.storage
+package org.apache.spark.shuffle
 
 import java.io.File
+import java.nio.ByteBuffer
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConversions._
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
+import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
+import org.apache.spark.storage._
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
-import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.executor.ShuffleWriteMetrics
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -61,20 +61,18 @@ private[spark] trait ShuffleWriterGroup {
  * each block stored in each file. In order to find the location of a shuffle block, we search the
  * files within a ShuffleFileGroups associated with the block's reducer.
  */
-// TODO: Factor this into a separate class for each ShuffleManager implementation
+
 private[spark]
-class ShuffleBlockManager(blockManager: BlockManager,
-                          shuffleManager: ShuffleManager) extends Logging {
-  def conf = blockManager.conf
+class FileShuffleBlockManager(conf: SparkConf)
+  extends ShuffleBlockManager with Logging {
+
+  private lazy val blockManager = SparkEnv.get.blockManager
 
   // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
   // TODO: Remove this once the shuffle file consolidation feature is stable.
-  val consolidateShuffleFiles =
+  private val consolidateShuffleFiles =
     conf.getBoolean("spark.shuffle.consolidateFiles", false)
 
-  // Are we using sort-based shuffle?
-  val sortBasedShuffle = shuffleManager.isInstanceOf[SortShuffleManager]
-
   private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
 
   /**
@@ -93,22 +91,11 @@ class ShuffleBlockManager(blockManager: BlockManager,
     val completedMapTasks = new ConcurrentLinkedQueue[Int]()
   }
 
-  type ShuffleId = Int
   private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
 
   private val metadataCleaner =
     new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
 
-  /**
-   * Register a completed map without getting a ShuffleWriterGroup. Used by sort-based shuffle
-   * because it just writes a single file by itself.
-   */
-  def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = {
-    shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
-    val shuffleState = shuffleStates(shuffleId)
-    shuffleState.completedMapTasks.add(mapId)
-  }
-
   /**
    * Get a ShuffleWriterGroup for the given map task, which will register it as complete
    * when the writers are closed successfully
@@ -181,17 +168,30 @@ class ShuffleBlockManager(blockManager: BlockManager,
 
   /**
    * Returns the physical file segment in which the given BlockId is located.
-   * This function should only be called if shuffle file consolidation is enabled, as it is
-   * an error condition if we don't find the expected block.
    */
-  def getBlockLocation(id: ShuffleBlockId): FileSegment = {
-    // Search all file groups associated with this shuffle.
-    val shuffleState = shuffleStates(id.shuffleId)
-    for (fileGroup <- shuffleState.allFileGroups) {
-      val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
-      if (segment.isDefined) { return segment.get }
+  private def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+    if (consolidateShuffleFiles) {
+      // Search all file groups associated with this shuffle.
+      val shuffleState = shuffleStates(id.shuffleId)
+      val iter = shuffleState.allFileGroups.iterator
+      while (iter.hasNext) {
+        val segment = iter.next.getFileSegmentFor(id.mapId, id.reduceId)
+        if (segment.isDefined) { return segment.get }
+      }
+      throw new IllegalStateException("Failed to find shuffle block: " + id)
+    } else {
+      val file = blockManager.diskBlockManager.getFile(id)
+      new FileSegment(file, 0, file.length())
     }
-    throw new IllegalStateException("Failed to find shuffle block: " + id)
+  }
+
+  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+    val segment = getBlockLocation(blockId)
+    blockManager.diskStore.getBytes(segment)
+  }
+
+  override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] = {
+    Left(getBlockLocation(blockId))
   }
 
   /** Remove all the blocks / files and metadata related to a particular shuffle. */
@@ -207,14 +207,7 @@ class ShuffleBlockManager(blockManager: BlockManager,
   private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
     shuffleStates.get(shuffleId) match {
       case Some(state) =>
-        if (sortBasedShuffle) {
-          // There's a single block ID for each map, plus an index file for it
-          for (mapId <- state.completedMapTasks) {
-            val blockId = new ShuffleBlockId(shuffleId, mapId, 0)
-            blockManager.diskBlockManager.getFile(blockId).delete()
-            blockManager.diskBlockManager.getFile(blockId.name + ".index").delete()
-          }
-        } else if (consolidateShuffleFiles) {
+        if (consolidateShuffleFiles) {
           for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
             file.delete()
           }
@@ -240,13 +233,13 @@ class ShuffleBlockManager(blockManager: BlockManager,
     shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
   }
 
-  def stop() {
+  override def stop() {
     metadataCleaner.cancel()
   }
 }
 
 private[spark]
-object ShuffleBlockManager {
+object FileShuffleBlockManager {
   /**
    * A group of shuffle files, one per reducer.
    * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
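When spark.shuffle.consolidateFiles is on, several map tasks append to the same per-reducer files, so getBlockLocation above has to scan the shuffle's file groups for the (mapId, reduceId) segment rather than resolving a file by name. The toy model below sketches that lookup; Segment, ToyFileGroup and ConsolidationSketch are made-up names, not Spark's actual data structures.

// Toy model of the search performed by FileShuffleBlockManager.getBlockLocation when
// consolidation is enabled: a block is a (file, offset, length) triple, not a whole file.
case class Segment(file: String, offset: Long, length: Long)

class ToyFileGroup {
  // mapId -> one segment per reducer, recorded when a map task's writers are closed
  private val byMap = scala.collection.mutable.Map[Int, Array[Segment]]()
  def record(mapId: Int, segments: Array[Segment]): Unit = byMap(mapId) = segments
  def segmentFor(mapId: Int, reduceId: Int): Option[Segment] = byMap.get(mapId).map(_(reduceId))
}

object ConsolidationSketch {
  // Mirrors the while-loop above: return the first file group that knows about the block.
  def locate(groups: Seq[ToyFileGroup], mapId: Int, reduceId: Int): Segment = {
    val iter = groups.iterator
    while (iter.hasNext) {
      val segment = iter.next().segmentFor(mapId, reduceId)
      if (segment.isDefined) { return segment.get }
    }
    throw new IllegalStateException(s"Failed to find shuffle block: map=$mapId, reduce=$reduceId")
  }

  def main(args: Array[String]): Unit = {
    val group = new ToyFileGroup
    group.record(3, Array(Segment("merged_0", 0, 10), Segment("merged_1", 0, 7)))
    assert(locate(Seq(group), 3, 1) == Segment("merged_1", 0, 7))
  }
}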
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8bb9efc46cc583fa0a3ecc1ad66cd72dfca44243
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io._
+import java.nio.ByteBuffer
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.storage._
+
+/**
+ * Create and maintain the mapping between logical shuffle blocks and physical file locations.
+ * Data of shuffle blocks from the same map task are stored in a single consolidated data file.
+ * The offsets of the data blocks in the data file are stored in a separate index file.
+ *
+ * The data file is named after the shuffle data's shuffleBlockId with reduce ID set to 0, plus a
+ * ".data" suffix; the index file uses the same base name with an ".index" suffix.
+ *
+ */
+private[spark]
+class IndexShuffleBlockManager extends ShuffleBlockManager {
+
+  private lazy val blockManager = SparkEnv.get.blockManager
+
+  /**
+   * Map a (shuffleId, mapId) pair to the ShuffleBlockId with reduce ID 0 that names its output.
+   */
+  def consolidateId(shuffleId: Int, mapId: Int): ShuffleBlockId = {
+    ShuffleBlockId(shuffleId, mapId, 0)
+  }
+
+  def getDataFile(shuffleId: Int, mapId: Int): File = {
+    blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, 0))
+  }
+
+  private def getIndexFile(shuffleId: Int, mapId: Int): File = {
+    blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, 0))
+  }
+
+  /**
+   * Remove the data file and index file that contain the output data from one map.
+   */
+  def removeDataByMap(shuffleId: Int, mapId: Int): Unit = {
+    var file = getDataFile(shuffleId, mapId)
+    if (file.exists()) {
+      file.delete()
+    }
+
+    file = getIndexFile(shuffleId, mapId)
+    if (file.exists()) {
+      file.delete()
+    }
+  }
+
+  /**
+   * Write an index file with the offsets of each block, plus a final offset at the end for the
+   * end of the output file. This will be used by getBlockLocation to figure out where each block
+   * begins and ends.
+   */
+  def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+    val indexFile = getIndexFile(shuffleId, mapId)
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    try {
+      // We take in lengths of each block, need to convert it to offsets.
+      var offset = 0L
+      out.writeLong(offset)
+
+      for (length <- lengths) {
+        offset += length
+        out.writeLong(offset)
+      }
+    } finally {
+      out.close()
+    }
+  }
+
+  /**
+   * Get the location of a block in a map output file. Uses the index file we create for it.
+   */
+  private def getBlockLocation(blockId: ShuffleBlockId): FileSegment = {
+    // The block is actually going to be a range of a single map output file for this map, so
+    // find out the consolidated file, then the offset within that from our index
+    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
+
+    val in = new DataInputStream(new FileInputStream(indexFile))
+    try {
+      in.skip(blockId.reduceId * 8)
+      val offset = in.readLong()
+      val nextOffset = in.readLong()
+      new FileSegment(getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset)
+    } finally {
+      in.close()
+    }
+  }
+
+  override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = {
+    val segment = getBlockLocation(blockId)
+    blockManager.diskStore.getBytes(segment)
+  }
+
+  override def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer] = {
+    Left(getBlockLocation(blockId))
+  }
+
+  override def stop() = {}
+}
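To make the index-file contract concrete: writeIndexFile stores numPartitions + 1 cumulative offsets as longs, and getBlockLocation reads the pair at positions reduceId and reduceId + 1 to recover the (offset, length) of a block inside the single .data file. The standalone sketch below exercises that round trip; IndexFormatSketch is a hypothetical helper, not part of the patch.

import java.io._

object IndexFormatSketch {
  // Same layout as writeIndexFile: one long per partition boundary, starting at 0.
  def writeIndex(indexFile: File, lengths: Array[Long]): Unit = {
    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
    try {
      var offset = 0L
      out.writeLong(offset)
      for (length <- lengths) {
        offset += length
        out.writeLong(offset)
      }
    } finally {
      out.close()
    }
  }

  // Same lookup as getBlockLocation: skip reduceId longs, then read the two bounding offsets.
  def readSegment(indexFile: File, reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      in.skip(reduceId * 8L)
      val offset = in.readLong()
      val nextOffset = in.readLong()
      (offset, nextOffset - offset)   // (offset, length) within the data file
    } finally {
      in.close()
    }
  }

  def main(args: Array[String]): Unit = {
    val indexFile = File.createTempFile("shuffle_0_0_0", ".index")
    writeIndex(indexFile, Array(10L, 0L, 25L))      // partition lengths for three reducers
    assert(readSegment(indexFile, 2) == (10L, 25L)) // third block starts after 10 + 0 bytes
    indexFile.delete()
  }
}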
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..42405802500467fd287920e0cce3b5490fa4aa19
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.nio.ByteBuffer
+
+import org.apache.spark.storage.{FileSegment, ShuffleBlockId}
+
+private[spark]
+trait ShuffleBlockManager {
+  type ShuffleId = Int
+
+  /**
+   * Get shuffle block data managed by the local ShuffleBlockManager.
+   * @return Some(ByteBuffer) if the block is found, otherwise None.
+   */
+  def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer]
+
+  def getBlockData(blockId: ShuffleBlockId): Either[FileSegment, ByteBuffer]
+
+  def stop(): Unit
+}
+
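With this trait in place, shuffle block reads no longer go through DiskBlockManager.getBlockLocation; BlockManager (changed further below) asks the active ShuffleManager for its ShuffleBlockManager and calls getBytes or getBlockData on it. A hypothetical caller sketch, assuming a running SparkEnv:

import java.nio.ByteBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.storage.ShuffleBlockId

object ShuffleReadSketch {
  // HashShuffleManager exposes a FileShuffleBlockManager and SortShuffleManager an
  // IndexShuffleBlockManager; callers only depend on the common trait.
  def readShuffleBlock(id: ShuffleBlockId): Option[ByteBuffer] =
    SparkEnv.get.shuffleManager.shuffleBlockManager.getBytes(id)
}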
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 9c859b8b4a118904612ce6e69eb2789200f399d8..801ae54086053ced0091f7d84a561dabc6aeb3ba 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -49,8 +49,13 @@ private[spark] trait ShuffleManager {
       endPartition: Int,
       context: TaskContext): ShuffleReader[K, C]
 
-  /** Remove a shuffle's metadata from the ShuffleManager. */
-  def unregisterShuffle(shuffleId: Int)
+  /**
+   * Remove a shuffle's metadata from the ShuffleManager.
+   * @return true if the metadata was removed successfully, otherwise false.
+   */
+  def unregisterShuffle(shuffleId: Int): Boolean
+
+  def shuffleBlockManager: ShuffleBlockManager
 
   /** Shut down this ShuffleManager. */
   def stop(): Unit
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
index df98d18fa81936895f9b6370dfa2893c9f790db1..62e0629b3440058df710c67785327d8490608ada 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -25,6 +25,9 @@ import org.apache.spark.shuffle._
  * mapper (possibly reusing these across waves of tasks).
  */
 private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+
+  private val fileShuffleBlockManager = new FileShuffleBlockManager(conf)
+
   /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
   override def registerShuffle[K, V, C](
       shuffleId: Int,
@@ -49,12 +52,21 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager
   /** Get a writer for a given partition. Called on executors by map tasks. */
   override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
       : ShuffleWriter[K, V] = {
-    new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+    new HashShuffleWriter(
+      shuffleBlockManager, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
-  override def unregisterShuffle(shuffleId: Int): Unit = {}
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    shuffleBlockManager.removeShuffle(shuffleId)
+  }
+
+  override def shuffleBlockManager: FileShuffleBlockManager = {
+    fileShuffleBlockManager
+  }
 
   /** Shut down this ShuffleManager. */
-  override def stop(): Unit = {}
+  override def stop(): Unit = {
+    shuffleBlockManager.stop()
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 51e454d9313c924fb5f0e58f64dccd334713c46d..4b9454d75abb7cc2c1367a7941007828232e6727 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,14 +17,15 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
-import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
-import org.apache.spark.storage.{BlockObjectWriter}
-import org.apache.spark.serializer.Serializer
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.storage.BlockObjectWriter
 
 private[spark] class HashShuffleWriter[K, V](
+    shuffleBlockManager: FileShuffleBlockManager,
     handle: BaseShuffleHandle[K, V, _],
     mapId: Int,
     context: TaskContext)
@@ -43,7 +44,6 @@ private[spark] class HashShuffleWriter[K, V](
   metrics.shuffleWriteMetrics = Some(writeMetrics)
 
   private val blockManager = SparkEnv.get.blockManager
-  private val shuffleBlockManager = blockManager.shuffleBlockManager
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
   private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
     writeMetrics)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 6dcca47ea7c0c9ebfee3c55b587f8ce95efaa973..b727438ae7e47c13daf8f451e46db1c011f2f96c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.shuffle.sort
 
-import java.io.{DataInputStream, FileInputStream}
+import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.spark.{SparkConf, TaskContext, ShuffleDependency}
 import org.apache.spark.shuffle._
-import org.apache.spark.{TaskContext, ShuffleDependency}
 import org.apache.spark.shuffle.hash.HashShuffleReader
-import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId}
 
-private[spark] class SortShuffleManager extends ShuffleManager {
+private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
+
+  private val indexShuffleBlockManager = new IndexShuffleBlockManager()
+  private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
+
   /**
    * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
    */
@@ -52,29 +55,29 @@ private[spark] class SortShuffleManager extends ShuffleManager {
   /** Get a writer for a given partition. Called on executors by map tasks. */
   override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
       : ShuffleWriter[K, V] = {
-    new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+    val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
+    shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
+    new SortShuffleWriter(
+      shuffleBlockManager, baseShuffleHandle, mapId, context)
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
-  override def unregisterShuffle(shuffleId: Int): Unit = {}
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    if (shuffleMapNumber.containsKey(shuffleId)) {
+      val numMaps = shuffleMapNumber.remove(shuffleId)
+      (0 until numMaps).foreach { mapId =>
+        shuffleBlockManager.removeDataByMap(shuffleId, mapId)
+      }
+    }
+    true
+  }
 
-  /** Shut down this ShuffleManager. */
-  override def stop(): Unit = {}
+  override def shuffleBlockManager: IndexShuffleBlockManager = {
+    indexShuffleBlockManager
+  }
 
-  /** Get the location of a block in a map output file. Uses the index file we create for it. */
-  def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
-    // The block is actually going to be a range of a single map output file for this map, so
-    // figure out the ID of the consolidated file, then the offset within that from our index
-    val consolidatedId = blockId.copy(reduceId = 0)
-    val indexFile = diskManager.getFile(consolidatedId.name + ".index")
-    val in = new DataInputStream(new FileInputStream(indexFile))
-    try {
-      in.skip(blockId.reduceId * 8)
-      val offset = in.readLong()
-      val nextOffset = in.readLong()
-      new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset)
-    } finally {
-      in.close()
-    }
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {
+    shuffleBlockManager.stop()
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index b8c9ad46ab035bf5befd7f5f2ee0800a384ea28c..89a78d6982ba0fd6c1286a67f8d375436aecc60d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,29 +17,25 @@
 
 package org.apache.spark.shuffle.sort
 
-import java.io.File
-
 import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{IndexShuffleBlockManager, ShuffleWriter, BaseShuffleHandle}
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class SortShuffleWriter[K, V, C](
+    shuffleBlockManager: IndexShuffleBlockManager,
     handle: BaseShuffleHandle[K, V, C],
     mapId: Int,
     context: TaskContext)
   extends ShuffleWriter[K, V] with Logging {
 
   private val dep = handle.dependency
-  private val numPartitions = dep.partitioner.numPartitions
 
   private val blockManager = SparkEnv.get.blockManager
 
   private var sorter: ExternalSorter[K, V, _] = null
-  private var outputFile: File = null
-  private var indexFile: File = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -69,17 +65,10 @@ private[spark] class SortShuffleWriter[K, V, C](
       sorter.insertAll(records)
     }
 
-    // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
-    // serve different ranges of this file using an index file that we create at the end.
-    val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
-
-    outputFile = blockManager.diskBlockManager.getFile(blockId)
-    indexFile = blockManager.diskBlockManager.getFile(blockId.name + ".index")
-
-    val partitionLengths = sorter.writePartitionedFile(blockId, context)
-
-    // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
-    blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
+    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
+    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
+    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
+    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
 
     mapStatus = new MapStatus(blockManager.blockManagerId,
       partitionLengths.map(MapOutputTracker.compressSize))
@@ -95,13 +84,8 @@ private[spark] class SortShuffleWriter[K, V, C](
       if (success) {
         return Option(mapStatus)
       } else {
-        // The map task failed, so delete our output file if we created one
-        if (outputFile != null) {
-          outputFile.delete()
-        }
-        if (indexFile != null) {
-          indexFile.delete()
-        }
+        // The map task failed, so delete our output data.
+        shuffleBlockManager.removeDataByMap(dep.shuffleId, mapId)
         return None
       }
     } finally {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index d07e6a1b1836cad3f0c1dc6d8ab92894170e3306..e35b7fe62c753741f6bf58af5ae92424ea0f6531 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -197,7 +197,7 @@ object BlockFetcherIterator {
       for (id <- localBlocksToFetch) {
         try {
           readMetrics.localBlocksFetched += 1
-          results.put(new FetchResult(id, 0, () => getLocalFromDisk(id, serializer).get))
+          results.put(new FetchResult(id, 0, () => getLocalShuffleFromDisk(id, serializer).get))
           logDebug("Got local block " + id)
         } catch {
           case e: Exception => {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index c1756ac9054172e7bf228eeba201feb8f2823d91..a83a3f468ae5f3c33355e0fcd50cfac13e3c6b66 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -58,6 +58,11 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends Blo
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
+@DeveloperApi
+case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+  def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
+}
+
 @DeveloperApi
 case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
@@ -92,6 +97,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
@@ -104,6 +110,8 @@ object BlockId {
       RDDBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+    case SHUFFLE_DATA(shuffleId, mapId, reduceId) =>
+      ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
       ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId, field) =>
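The new ShuffleDataBlockId slots into the existing name parser. Although the SHUFFLE pattern is tried first, Scala regex patterns in a match must cover the whole string, so names ending in ".data" fall through to SHUFFLE_DATA. A quick round-trip illustration (assumed usage, not part of the patch):

import org.apache.spark.storage.{BlockId, ShuffleDataBlockId, ShuffleIndexBlockId}

object BlockIdNamingSketch {
  def main(args: Array[String]): Unit = {
    val data = ShuffleDataBlockId(0, 5, 0)
    assert(data.name == "shuffle_0_5_0.data")      // ".data" suffix per map output
    assert(BlockId("shuffle_0_5_0.data") == data)  // parsed back by the SHUFFLE_DATA pattern
    assert(BlockId("shuffle_0_5_0.index") == ShuffleIndexBlockId(0, 5, 0))
  }
}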
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index cfe5b6c50aea27738be457a5973194a2434fe27e..a714142763243cffb8a4ae04763d74a496f8c6e4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -64,8 +64,8 @@ private[spark] class BlockManager(
   extends BlockDataProvider with Logging {
 
   private val port = conf.getInt("spark.blockManager.port", 0)
-  val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager)
-  val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf)
+
+  val diskBlockManager = new DiskBlockManager(this, conf)
   val connectionManager =
     new ConnectionManager(port, conf, securityManager, "Connection manager for block manager")
 
@@ -83,7 +83,7 @@ private[spark] class BlockManager(
     val tachyonStorePath = s"$storeDir/$appFolderName/${this.executorId}"
     val tachyonMaster = conf.get("spark.tachyonStore.url",  "tachyon://localhost:19998")
     val tachyonBlockManager =
-      new TachyonBlockManager(shuffleBlockManager, tachyonStorePath, tachyonMaster)
+      new TachyonBlockManager(this, tachyonStorePath, tachyonMaster)
     tachyonInitialized = true
     new TachyonStore(this, tachyonBlockManager)
   }
@@ -215,7 +215,7 @@ private[spark] class BlockManager(
   override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
     val bid = BlockId(blockId)
     if (bid.isShuffle) {
-      Left(diskBlockManager.getBlockLocation(bid))
+      shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])
     } else {
       val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
       if (blockBytesOpt.isDefined) {
@@ -333,8 +333,14 @@ private[spark] class BlockManager(
    * shuffle blocks. It is safe to do so without a lock on block info since disk store
    * never deletes (recent) items.
    */
-  def getLocalFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
-    diskStore.getValues(blockId, serializer).orElse {
+  def getLocalShuffleFromDisk(
+      blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+
+    val shuffleBlockManager = shuffleManager.shuffleBlockManager
+    val values = shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]).map(
+      bytes => this.dataDeserialize(blockId, bytes, serializer))
+
+    values.orElse {
       throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be")
     }
   }
@@ -355,7 +361,8 @@ private[spark] class BlockManager(
     // As an optimization for map output fetches, if the block is for a shuffle, return it
     // without acquiring a lock; the disk store never deletes (recent) items so this should work
     if (blockId.isShuffle) {
-      diskStore.getBytes(blockId) match {
+      val shuffleBlockManager = shuffleManager.shuffleBlockManager
+      shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]) match {
         case Some(bytes) =>
           Some(bytes)
         case None =>
@@ -1045,7 +1052,6 @@ private[spark] class BlockManager(
 
   def stop(): Unit = {
     connectionManager.stop()
-    shuffleBlockManager.stop()
     diskBlockManager.stop()
     actorSystem.stop(slaveActor)
     blockInfo.clear()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
index c194e0fed336760c07d303c67155f23c5b2b2a1e..14ae2f38c567046f72d3eb9564b1e66ac6277cd6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala
@@ -21,7 +21,7 @@ import scala.concurrent.Future
 
 import akka.actor.{ActorRef, Actor}
 
-import org.apache.spark.{Logging, MapOutputTracker}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv}
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util.ActorLogReceive
 
@@ -55,7 +55,7 @@ class BlockManagerSlaveActor(
         if (mapOutputTracker != null) {
           mapOutputTracker.unregisterShuffle(shuffleId)
         }
-        blockManager.shuffleBlockManager.removeShuffle(shuffleId)
+        SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId)
       }
 
     case RemoveBroadcast(broadcastId, tellMaster) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index ec022ce9c048a976dc9bdf14d7b8c3d8591bcbd7..a715594f198c26717eea9a5a735fcdfd17b84820 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -21,11 +21,9 @@ import java.io.File
 import java.text.SimpleDateFormat
 import java.util.{Date, Random, UUID}
 
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
+import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.executor.ExecutorExitCode
-import org.apache.spark.network.netty.PathResolver
 import org.apache.spark.util.Utils
-import org.apache.spark.shuffle.sort.SortShuffleManager
 
 /**
  * Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -36,13 +34,11 @@ import org.apache.spark.shuffle.sort.SortShuffleManager
  * Block files are hashed among the directories listed in spark.local.dir (or in
  * SPARK_LOCAL_DIRS, if it's set).
  */
-private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, conf: SparkConf)
-  extends PathResolver with Logging {
+private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
+  extends Logging {
 
   private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
-
-  private val subDirsPerLocalDir =
-    shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+  private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
 
   /* Create one local directory for each path mentioned in spark.local.dir; then, inside this
    * directory, create multiple subdirectories that we will hash files into, in order to avoid
@@ -56,26 +52,6 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
 
   addShutdownHook()
 
-  /**
-   * Returns the physical file segment in which the given BlockId is located. If the BlockId has
-   * been mapped to a specific FileSegment by the shuffle layer, that will be returned.
-   * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
-   */
-  def getBlockLocation(blockId: BlockId): FileSegment = {
-    val env = SparkEnv.get  // NOTE: can be null in unit tests
-    if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) {
-      // For sort-based shuffle, let it figure out its blocks
-      val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager]
-      sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
-    } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) {
-      // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this
-      shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
-    } else {
-      val file = getFile(blockId.name)
-      new FileSegment(file, 0, file.length())
-    }
-  }
-
   def getFile(filename: String): File = {
     // Figure out which local directory it hashes to, and which subdirectory in that
     val hash = Utils.nonNegativeHash(filename)
@@ -105,7 +81,7 @@ private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager,
 
   /** Check if disk block manager has a block. */
   def containsBlock(blockId: BlockId): Boolean = {
-    getBlockLocation(blockId).file.exists()
+    getFile(blockId.name).exists()
   }
 
   /** List all the files currently stored on disk by the disk manager. */
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index c83261dd91b36cb758ec183562b1f1ee4d6a2b59..e9304f6bb45d004cea44b41170737fa08a9afe05 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{FileOutputStream, RandomAccessFile}
+import java.io.{File, FileOutputStream, RandomAccessFile}
 import java.nio.ByteBuffer
 import java.nio.channels.FileChannel.MapMode
 
@@ -34,7 +34,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
   val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L)
 
   override def getSize(blockId: BlockId): Long = {
-    diskManager.getBlockLocation(blockId).length
+    diskManager.getFile(blockId.name).length
   }
 
   override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
@@ -89,25 +89,33 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
     }
   }
 
-  override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
-    val segment = diskManager.getBlockLocation(blockId)
-    val channel = new RandomAccessFile(segment.file, "r").getChannel
+  private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = {
+    val channel = new RandomAccessFile(file, "r").getChannel
 
     try {
       // For small files, directly read rather than memory map
-      if (segment.length < minMemoryMapBytes) {
-        val buf = ByteBuffer.allocate(segment.length.toInt)
-        channel.read(buf, segment.offset)
+      if (length < minMemoryMapBytes) {
+        val buf = ByteBuffer.allocate(length.toInt)
+        channel.read(buf, offset)
         buf.flip()
         Some(buf)
       } else {
-        Some(channel.map(MapMode.READ_ONLY, segment.offset, segment.length))
+        Some(channel.map(MapMode.READ_ONLY, offset, length))
       }
     } finally {
       channel.close()
     }
   }
 
+  override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+    val file = diskManager.getFile(blockId.name)
+    getBytes(file, 0, file.length)
+  }
+
+  def getBytes(segment: FileSegment): Option[ByteBuffer] = {
+    getBytes(segment.file, segment.offset, segment.length)
+  }
+
   override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
     getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
   }
@@ -117,24 +125,25 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
    * shuffle short-circuit code.
    */
   def getValues(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
+    // TODO: Should bypass getBytes and use a stream-based implementation, so that
+    // we don't use a lot of memory during, e.g., an external sort merge.
     getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes, serializer))
   }
 
   override def remove(blockId: BlockId): Boolean = {
-    val fileSegment = diskManager.getBlockLocation(blockId)
-    val file = fileSegment.file
-    if (file.exists() && file.length() == fileSegment.length) {
+    val file = diskManager.getFile(blockId.name)
+    // If consolidation mode is used with the HashShuffleManager, the physical filename for the
+    // block is different from blockId.name. In that case the file returned here will not exist,
+    // so we avoid deleting the whole consolidated file by mistake.
+    if (file.exists()) {
       file.delete()
     } else {
-      if (fileSegment.length < file.length()) {
-        logWarning(s"Could not delete block associated with only a part of a file: $blockId")
-      }
       false
     }
   }
 
   override def contains(blockId: BlockId): Boolean = {
-    val file = diskManager.getBlockLocation(blockId).file
+    val file = diskManager.getFile(blockId.name)
     file.exists()
   }
 }
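The refactored getBytes is split into an offset/length-aware helper so shuffle managers can hand DiskStore a FileSegment, while the read strategy itself is unchanged: below spark.storage.memoryMapThreshold (2 * 4096 bytes by default above) the range is read into a heap buffer, above it the range is memory-mapped. A stripped-down sketch of just that decision, with MemoryMapSketch as a made-up stand-in for DiskStore:

import java.io.{File, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

object MemoryMapSketch {
  val minMemoryMapBytes = 2 * 4096L  // default threshold used in DiskStore

  def read(file: File, offset: Long, length: Long): ByteBuffer = {
    val channel = new RandomAccessFile(file, "r").getChannel
    try {
      if (length < minMemoryMapBytes) {
        // Small segments: a plain positioned read avoids the cost of setting up a mapping.
        val buf = ByteBuffer.allocate(length.toInt)
        channel.read(buf, offset)
        buf.flip()
        buf
      } else {
        // Large segments: map the byte range read-only instead of copying it onto the heap.
        channel.map(MapMode.READ_ONLY, offset, length)
      }
    } finally {
      channel.close()
    }
  }
}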
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index a6cbe3aa440ffee661221fe4a3150cc356bd53a9..6908a59a79e60fe251d4f0d0a4b3d14c800ca98e 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.Utils
  * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
  */
 private[spark] class TachyonBlockManager(
-    shuffleManager: ShuffleBlockManager,
+    blockManager: BlockManager,
     rootDirs: String,
     val master: String)
   extends Logging {
@@ -49,7 +49,7 @@ private[spark] class TachyonBlockManager(
 
   private val MAX_DIR_CREATION_ATTEMPTS = 10
   private val subDirsPerTachyonDir =
-    shuffleManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt
+    blockManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt
 
   // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName;
   // then, inside this directory, create multiple subdirectories that we will hash files into,
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 5d8a648d9551ed79fa2760c44ffcf50bd506de7c..782b979e2e93d4a2b4e59f118442721710d7b30d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -719,20 +719,20 @@ private[spark] class ExternalSorter[K, V, C](
   def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
 
   /**
-   * Write all the data added into this ExternalSorter into a file in the disk store, creating
-   * an .index file for it as well with the offsets of each partition. This is called by the
-   * SortShuffleWriter and can go through an efficient path of just concatenating binary files
-   * if we decided to avoid merge-sorting.
+   * Write all the data added into this ExternalSorter into a file in the disk store. This is
+   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+   * binary files if we decided to avoid merge-sorting.
    *
    * @param blockId block ID to write to. The index file will be blockId.name + ".index".
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
-  def writePartitionedFile(blockId: BlockId, context: TaskContext): Array[Long] = {
-    val outputFile = blockManager.diskBlockManager.getFile(blockId)
+  def writePartitionedFile(
+      blockId: BlockId,
+      context: TaskContext,
+      outputFile: File): Array[Long] = {
 
     // Track location of each range in the output file
-    val offsets = new Array[Long](numPartitions + 1)
     val lengths = new Array[Long](numPartitions)
 
     if (bypassMergeSort && partitionWriters != null) {
@@ -750,7 +750,6 @@ private[spark] class ExternalSorter[K, V, C](
           in.close()
           in = null
           lengths(i) = size
-          offsets(i + 1) = offsets(i) + lengths(i)
         }
       } finally {
         if (out != null) {
@@ -772,11 +771,7 @@ private[spark] class ExternalSorter[K, V, C](
           }
           writer.commitAndClose()
           val segment = writer.fileSegment()
-          offsets(id + 1) = segment.offset + segment.length
           lengths(id) = segment.length
-        } else {
-          // The partition is empty; don't create a new writer to avoid writing headers, etc
-          offsets(id + 1) = offsets(id)
         }
       }
     }
@@ -784,23 +779,6 @@ private[spark] class ExternalSorter[K, V, C](
     context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
     context.taskMetrics.diskBytesSpilled += diskBytesSpilled
 
-    // Write an index file with the offsets of each block, plus a final offset at the end for the
-    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
-    // out where each block begins and ends.
-
-    val diskBlockManager = blockManager.diskBlockManager
-    val indexFile = diskBlockManager.getFile(blockId.name + ".index")
-    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
-    try {
-      var i = 0
-      while (i < numPartitions + 1) {
-        out.writeLong(offsets(i))
-        i += 1
-      }
-    } finally {
-      out.close()
-    }
-
     lengths
   }
 
@@ -811,7 +789,7 @@ private[spark] class ExternalSorter[K, V, C](
     if (writer.isOpen) {
       writer.commitAndClose()
     }
-    blockManager.getLocalFromDisk(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+    blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
   }
 
   def stop(): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6061e544e79b4b832d1d06e9af3f8c9898fe29e9
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import java.io.{File, FileWriter}
+
+import scala.language.reflectiveCalls
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FileShuffleBlockManager
+import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
+
+class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
+  private val testConf = new SparkConf(false)
+
+  private def checkSegments(segment1: FileSegment, segment2: FileSegment) {
+    assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath)
+    assert (segment1.offset === segment2.offset)
+    assert (segment1.length === segment2.length)
+  }
+
+  test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
+
+    val conf = new SparkConf(false)
+    // Reset the serializer after EACH object write. This ensures that there are bytes appended
+    // after an object is written, so if a code path assumes writeObject marks the end of the data,
+    // this should flush those bugs out. This was a common bug in ExternalAppendOnlyMap, etc.
+    conf.set("spark.serializer.objectStreamReset", "1")
+    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+
+    sc = new SparkContext("local", "test", conf)
+
+    val shuffleBlockManager =
+      SparkEnv.get.shuffleManager.shuffleBlockManager.asInstanceOf[FileShuffleBlockManager]
+
+    val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf),
+      new ShuffleWriteMetrics)
+    for (writer <- shuffle1.writers) {
+      writer.write("test1")
+      writer.write("test2")
+    }
+    for (writer <- shuffle1.writers) {
+      writer.commitAndClose()
+    }
+
+    val shuffle1Segment = shuffle1.writers(0).fileSegment()
+    shuffle1.releaseWriters(success = true)
+
+    val shuffle2 = shuffleBlockManager.forMapTask(1, 2, 1, new JavaSerializer(conf),
+      new ShuffleWriteMetrics)
+
+    for (writer <- shuffle2.writers) {
+      writer.write("test3")
+      writer.write("test4")
+    }
+    for (writer <- shuffle2.writers) {
+      writer.commitAndClose()
+    }
+    val shuffle2Segment = shuffle2.writers(0).fileSegment()
+    shuffle2.releaseWriters(success = true)
+
+    // Now comes the test:
+    // Write to shuffle 3 and close it, but before registering it, check that the file segment
+    // recorded for the previous map task (shuffle2) is unchanged. Earlier, we inferred a block's
+    // length from the remaining data in the file, which could mess things up when concurrent
+    // reads and writes happen to the same shuffle group.
+
+    val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
+      new ShuffleWriteMetrics)
+    for (writer <- shuffle3.writers) {
+      writer.write("test3")
+      writer.write("test4")
+    }
+    for (writer <- shuffle3.writers) {
+      writer.commitAndClose()
+    }
+    // check before we register.
+    checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)).left.get)
+    shuffle3.releaseWriters(success = true)
+    checkSegments(shuffle2Segment, shuffleBlockManager.getBlockData(ShuffleBlockId(1, 2, 0)).left.get)
+    shuffleBlockManager.removeShuffle(1)
+
+  }
+
+
+  def writeToFile(file: File, numBytes: Int) {
+    val writer = new FileWriter(file, true)
+    for (i <- 0 until numBytes) writer.write(i)
+    writer.close()
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
index fbfcb5156d496603524baafac77ae2750bfd7208..3c86f6bafcaa328b5febacf76f8ee49ec97e93f2 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
@@ -60,11 +60,11 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     }
 
     // 3rd block is going to fail
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any())
-    doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+    doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
 
     val bmId = BlockManagerId("test-client", "test-client", 1)
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
@@ -76,24 +76,24 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
 
     iterator.initialize()
 
-    // Without exhausting the iterator, the iterator should be lazy and not call getLocalFromDisk.
-    verify(blockManager, times(0)).getLocalFromDisk(any(), any())
+    // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
+    verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
 
     assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
     // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully
     assert(iterator.next()._2.isDefined, "1st element should be defined but is not actually defined")
-    verify(blockManager, times(1)).getLocalFromDisk(any(), any())
+    verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
 
     assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
     assert(iterator.next()._2.isDefined, "2nd element should be defined but is not actually defined")
-    verify(blockManager, times(2)).getLocalFromDisk(any(), any())
+    verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
 
     assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
     // 3rd fetch should be failed
     intercept[Exception] {
       iterator.next()
     }
-    verify(blockManager, times(3)).getLocalFromDisk(any(), any())
+    verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
   }
 
 
@@ -115,11 +115,11 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     val optItr = mock(classOf[Option[Iterator[Any]]])
  
     // All blocks should be fetched successfully
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any())
-    doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
+    doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
 
     val bmId = BlockManagerId("test-client", "test-client", 1)
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
@@ -131,8 +131,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
 
     iterator.initialize()
 
-    // Without exhausting the iterator, the iterator should be lazy and not call getLocalFromDisk.
-    verify(blockManager, times(0)).getLocalFromDisk(any(), any())
+    // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
+    verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
 
     assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
     assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") 
@@ -145,7 +145,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
     assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
 
-    verify(blockManager, times(5)).getLocalFromDisk(any(), any())
+    verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
   }
 
   test("block fetch from remote fails using BasicBlockFetcherIterator") {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index bdcea07e5714f4bf48fd4e8cf58c4b8345bfa4cc..14ffadab99caeb99ab5e78a49b98224c55e3585e 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -49,6 +49,7 @@ import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.language.implicitConversions
 import scala.language.postfixOps
+import org.apache.spark.shuffle.ShuffleBlockManager
 
 class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
   with PrivateMethodTester {
@@ -823,11 +824,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
     // be nice to refactor classes involved in disk storage in a way that
     // allows for easier testing.
     val blockManager = mock(classOf[BlockManager])
-    val shuffleBlockManager = mock(classOf[ShuffleBlockManager])
-    when(shuffleBlockManager.conf).thenReturn(conf)
-    val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf)
-
     when(blockManager.conf).thenReturn(conf.clone.set(confKey, 0.toString))
+    val diskBlockManager = new DiskBlockManager(blockManager, conf)
+
     val diskStoreMapped = new DiskStore(blockManager, diskBlockManager)
     diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY)
     val mapped = diskStoreMapped.getBytes(blockId).get
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index aabaeadd7a071139cf063b535b4249286a479934..26082ded8ca7a73ee09bdb5438f252157575122e 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -26,6 +26,7 @@ import scala.language.reflectiveCalls
 
 import akka.actor.Props
 import com.google.common.io.Files
+import org.mockito.Mockito.{mock, when}
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
 import org.apache.spark.SparkConf
@@ -40,18 +41,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
   private var rootDir1: File = _
   private var rootDirs: String = _
 
-  // This suite focuses primarily on consolidation features,
-  // so we coerce consolidation if not already enabled.
-  testConf.set("spark.shuffle.consolidateFiles", "true")
-
-  private val shuffleManager = new HashShuffleManager(testConf.clone)
-
-  val shuffleBlockManager = new ShuffleBlockManager(null, shuffleManager) {
-    override def conf = testConf.clone
-    var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
-    override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
-  }
-
+  val blockManager = mock(classOf[BlockManager])
+  when(blockManager.conf).thenReturn(testConf)
   var diskBlockManager: DiskBlockManager = _
 
   override def beforeAll() {
@@ -73,22 +64,17 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
   override def beforeEach() {
     val conf = testConf.clone
     conf.set("spark.local.dir", rootDirs)
-    diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf)
-    shuffleBlockManager.idToSegmentMap.clear()
+    diskBlockManager = new DiskBlockManager(blockManager, conf)
   }
 
   override def afterEach() {
     diskBlockManager.stop()
-    shuffleBlockManager.idToSegmentMap.clear()
   }
 
   test("basic block creation") {
     val blockId = new TestBlockId("test")
-    assertSegmentEquals(blockId, blockId.name, 0, 0)
-
     val newFile = diskBlockManager.getFile(blockId)
     writeToFile(newFile, 10)
-    assertSegmentEquals(blockId, blockId.name, 0, 10)
     assert(diskBlockManager.containsBlock(blockId))
     newFile.delete()
     assert(!diskBlockManager.containsBlock(blockId))
@@ -101,127 +87,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
     assert(diskBlockManager.getAllBlocks.toSet === ids.toSet)
   }
 
-  test("block appending") {
-    val blockId = new TestBlockId("test")
-    val newFile = diskBlockManager.getFile(blockId)
-    writeToFile(newFile, 15)
-    assertSegmentEquals(blockId, blockId.name, 0, 15)
-    val newFile2 = diskBlockManager.getFile(blockId)
-    assert(newFile === newFile2)
-    writeToFile(newFile2, 12)
-    assertSegmentEquals(blockId, blockId.name, 0, 27)
-    newFile.delete()
-  }
-
-  test("block remapping") {
-    val filename = "test"
-    val blockId0 = new ShuffleBlockId(1, 2, 3)
-    val newFile = diskBlockManager.getFile(filename)
-    writeToFile(newFile, 15)
-    shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
-    assertSegmentEquals(blockId0, filename, 0, 15)
-
-    val blockId1 = new ShuffleBlockId(1, 2, 4)
-    val newFile2 = diskBlockManager.getFile(filename)
-    writeToFile(newFile2, 12)
-    shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
-    assertSegmentEquals(blockId1, filename, 15, 12)
-
-    assert(newFile === newFile2)
-    newFile.delete()
-  }
-
-  private def checkSegments(segment1: FileSegment, segment2: FileSegment) {
-    assert (segment1.file.getCanonicalPath === segment2.file.getCanonicalPath)
-    assert (segment1.offset === segment2.offset)
-    assert (segment1.length === segment2.length)
-  }
-
-  test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
-
-    val serializer = new JavaSerializer(testConf)
-    val confCopy = testConf.clone
-    // reset after EACH object write. This is to ensure that there are bytes appended after
-    // an object is written. So if the codepaths assume writeObject is end of data, this should
-    // flush those bugs out. This was common bug in ExternalAppendOnlyMap, etc.
-    confCopy.set("spark.serializer.objectStreamReset", "1")
-
-    val securityManager = new org.apache.spark.SecurityManager(confCopy)
-    // Do not use the shuffleBlockManager above !
-    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, confCopy,
-      securityManager)
-    val master = new BlockManagerMaster(
-      actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))),
-      confCopy)
-    val store = new BlockManager("<driver>", actorSystem, master , serializer, confCopy,
-      securityManager, null, shuffleManager)
-
-    try {
-
-      val shuffleManager = store.shuffleBlockManager
-
-      val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics)
-      for (writer <- shuffle1.writers) {
-        writer.write("test1")
-        writer.write("test2")
-      }
-      for (writer <- shuffle1.writers) {
-        writer.commitAndClose()
-      }
-
-      val shuffle1Segment = shuffle1.writers(0).fileSegment()
-      shuffle1.releaseWriters(success = true)
-
-      val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf),
-        new ShuffleWriteMetrics)
-
-      for (writer <- shuffle2.writers) {
-        writer.write("test3")
-        writer.write("test4")
-      }
-      for (writer <- shuffle2.writers) {
-        writer.commitAndClose()
-      }
-      val shuffle2Segment = shuffle2.writers(0).fileSegment()
-      shuffle2.releaseWriters(success = true)
-
-      // Now comes the test :
-      // Write to shuffle 3; and close it, but before registering it, check if the file lengths for
-      // previous task (forof shuffle1) is the same as 'segments'. Earlier, we were inferring length
-      // of block based on remaining data in file : which could mess things up when there is concurrent read
-      // and writes happening to the same shuffle group.
-
-      val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
-        new ShuffleWriteMetrics)
-      for (writer <- shuffle3.writers) {
-        writer.write("test3")
-        writer.write("test4")
-      }
-      for (writer <- shuffle3.writers) {
-        writer.commitAndClose()
-      }
-      // check before we register.
-      checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
-      shuffle3.releaseWriters(success = true)
-      checkSegments(shuffle2Segment, shuffleManager.getBlockLocation(ShuffleBlockId(1, 2, 0)))
-      shuffleManager.removeShuffle(1)
-    } finally {
-
-      if (store != null) {
-        store.stop()
-      }
-      actorSystem.shutdown()
-      actorSystem.awaitTermination()
-    }
-  }
-
-  def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
-    val segment = diskBlockManager.getBlockLocation(blockId)
-    assert(segment.file.getName === filename)
-    assert(segment.offset === offset)
-    assert(segment.length === length)
-  }
-
   def writeToFile(file: File, numBytes: Int) {
     val writer = new FileWriter(file, true)
     for (i <- 0 until numBytes) writer.write(i)
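With segment lookup gone from DiskBlockManager (getBlockLocation and the idToSegmentMap-backed fixture are removed above), the suite's remaining surface is plain file creation and enumeration. A self-contained sketch of that round trip follows, under the same assumptions as the previous sketch: package placement, a mocked conf, and an illustrative object name.

package org.apache.spark.storage

import java.io.{File, FileWriter}

import org.mockito.Mockito.{mock, when}

import org.apache.spark.SparkConf

object DiskBlockEnumerationSketch {
  // Mirrors the suite's helper: append numBytes single-byte writes to the file.
  private def writeToFile(file: File, numBytes: Int): Unit = {
    val writer = new FileWriter(file, true)
    for (i <- 0 until numBytes) writer.write(i)
    writer.close()
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false).set("spark.local.dir", System.getProperty("java.io.tmpdir"))
    val blockManager = mock(classOf[BlockManager])
    when(blockManager.conf).thenReturn(conf)
    val diskBlockManager = new DiskBlockManager(blockManager, conf)

    val ids = (1 to 5).map(i => TestBlockId("test_" + i))
    ids.foreach(id => writeToFile(diskBlockManager.getFile(id), 10))

    // Every id written should be visible, both individually and in bulk.
    assert(ids.forall(diskBlockManager.containsBlock))
    assert(diskBlockManager.getAllBlocks.toSet == ids.toSet)

    diskBlockManager.stop()
  }
}
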
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 300589394b96f4bcefe3b2d506f1a62f1baef254..fe8ffe6d97a05e2a4c840a02c79e8d939d44910a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -58,6 +58,8 @@ object MimaExcludes {
               "org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.storage.DiskStore.getValues"),
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.storage.MemoryStore.Entry")
           ) ++
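The new exclusion suggests DiskStore.getValues is removed or changed elsewhere in this patch, now that shuffle reads no longer go through DiskStore. For reference, the shape of such an entry in isolation looks like the hedged snippet below (hypothetical object name; the filter string matches the one added above).

import com.typesafe.tools.mima.core._

// Illustrative only: a MissingMethodProblem filter is keyed by the removed
// method's fully-qualified name.
object ExtraMimaExcludes {
  val filters = Seq(
    ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.storage.DiskStore.getValues")
  )
}
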
diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
index 17bf7c2541d13565bda59a4b1d08b60d2b30cc1d..db58eb642b56d9bcc0b05b90faded93ada5ec576 100644
--- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
@@ -20,10 +20,11 @@ package org.apache.spark.tools
 import java.util.concurrent.{CountDownLatch, Executors}
 import java.util.concurrent.atomic.AtomicLong
 
+import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.SparkContext
 import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.util.Utils
-import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * Internal utility for micro-benchmarking shuffle write performance.
@@ -50,13 +51,15 @@ object StoragePerfTester {
 
     System.setProperty("spark.shuffle.compress", "false")
     System.setProperty("spark.shuffle.sync", "true")
+    System.setProperty("spark.shuffle.manager",
+      "org.apache.spark.shuffle.hash.HashShuffleManager")
 
     // This is only used to instantiate a BlockManager. All thread scheduling is done manually.
     val sc = new SparkContext("local[4]", "Write Tester")
-    val blockManager = sc.env.blockManager
+    val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager]
 
     def writeOutputBytes(mapId: Int, total: AtomicLong) = {
-      val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
+      val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
         new KryoSerializer(sc.conf), new ShuffleWriteMetrics())
       val writers = shuffle.writers
       for (i <- 1 to recordsPerMap) {
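To make the new wiring concrete, a trimmed-down, standalone version of the benchmark's write path might look as follows. It assumes the forMapTask signature shown in this patch and lives in org.apache.spark.tools so that sc.env and the package-private shuffleBlockManager remain accessible; the object name, record, and bucket count are illustrative.

package org.apache.spark.tools

import org.apache.spark.SparkContext
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.shuffle.hash.HashShuffleManager

object WriterAccessSketch {
  def main(args: Array[String]): Unit = {
    // Pin the hash implementation so the cast below is safe.
    System.setProperty("spark.shuffle.manager",
      "org.apache.spark.shuffle.hash.HashShuffleManager")

    val sc = new SparkContext("local", "writer-access-sketch")
    try {
      // Writers now come from the shuffle manager, not from BlockManager.
      val hashShuffleManager = sc.env.shuffleManager.asInstanceOf[HashShuffleManager]
      val shuffle = hashShuffleManager.shuffleBlockManager.forMapTask(
        1, 0, 2, new KryoSerializer(sc.conf), new ShuffleWriteMetrics())

      for (writer <- shuffle.writers) {
        writer.write(("key", "value"))
        writer.commitAndClose()
      }
      shuffle.releaseWriters(success = true)
    } finally {
      sc.stop()
    }
  }
}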