diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 24d97da6eb6e0fdef25c0b78f5d6c427fc28cb76..1dc71a04282e52940e9297f284b480a9d046dad7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -146,26 +146,26 @@ private[spark] class ShuffleMapTask(
     metrics = Some(context.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
-    var shuffle: ShuffleBlocks = null
-    var buckets: ShuffleWriterGroup = null
+    val shuffleBlockManager = blockManager.shuffleBlockManager
+    var shuffle: ShuffleWriterGroup = null
+    var success = false
 
     try {
       // Obtain all the block writers for shuffle blocks.
       val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
-      shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
-      buckets = shuffle.acquireWriters(partitionId)
+      shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
 
       // Write the map output to its associated buckets.
       for (elem <- rdd.iterator(split, context)) {
         val pair = elem.asInstanceOf[Product2[Any, Any]]
         val bucketId = dep.partitioner.getPartition(pair._1)
-        buckets.writers(bucketId).write(pair)
+        shuffle.writers(bucketId).write(pair)
       }
 
       // Commit the writes. Get the size of each bucket block (total block size).
       var totalBytes = 0L
       var totalTime = 0L
-      val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+      val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
         writer.commit()
         val size = writer.fileSegment().length
         totalBytes += size
@@ -179,19 +179,20 @@ private[spark] class ShuffleMapTask(
       shuffleMetrics.shuffleWriteTime = totalTime
       metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
 
+      success = true
       new MapStatus(blockManager.blockManagerId, compressedSizes)
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
      // and throw the exception upstream to Spark.
-      if (buckets != null) {
-        buckets.writers.foreach(_.revertPartialWrites())
+      if (shuffle != null) {
+        shuffle.writers.foreach(_.revertPartialWrites())
       }
       throw e
     } finally {
       // Release the writers back to the shuffle block manager.
-      if (shuffle != null && buckets != null) {
-        buckets.writers.foreach(_.close())
-        shuffle.releaseWriters(buckets)
+      if (shuffle != null && shuffle.writers != null) {
+        shuffle.writers.foreach(_.close())
+        shuffle.releaseWriters(success)
       }
       // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 8f4d69df274680b26966f803f1377d2ed9f76dfd..ccc05f5f6582f836037c14fa43af35155d475acc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, OutputStream}
+import java.io.{File, InputStream, OutputStream}
 import java.nio.{ByteBuffer, MappedByteBuffer}
 
 import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
@@ -47,7 +47,7 @@ private[spark] class BlockManager(
   extends Logging {
 
   val shuffleBlockManager = new ShuffleBlockManager(this)
-  val diskBlockManager = new DiskBlockManager(
+  val diskBlockManager = new DiskBlockManager(shuffleBlockManager,
     System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
 
   private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -517,15 +517,11 @@ private[spark] class BlockManager(
    * This is currently used for writing shuffle files out. Callers should handle error
    * cases.
    */
-  def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
+  def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int)
     : BlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
-    val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
     val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
     writer.registerCloseEventHandler(() => {
-      if (shuffleBlockManager.consolidateShuffleFiles) {
-        diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
-      }
       val myInfo = new ShuffleBlockInfo()
       blockInfo.put(blockId, myInfo)
       myInfo.markReady(writer.fileSegment().length)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 32d2dd06943a0952f7a6763397cbb81000b17933..e49c191c70a1176e9ac0a25994343f4556bcb8ac 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -78,11 +78,11 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
 
 /** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
 class DiskBlockObjectWriter(
-                             blockId: BlockId,
-                             file: File,
-                             serializer: Serializer,
-                             bufferSize: Int,
-                             compressStream: OutputStream => OutputStream)
+    blockId: BlockId,
+    file: File,
+    serializer: Serializer,
+    bufferSize: Int,
+    compressStream: OutputStream => OutputStream)
   extends BlockObjectWriter(blockId)
   with Logging
 {
@@ -111,8 +111,8 @@ class DiskBlockObjectWriter(
   private var fos: FileOutputStream = null
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
-  private var initialPosition = 0L
-  private var lastValidPosition = 0L
+  private val initialPosition = file.length()
+  private var lastValidPosition = initialPosition
   private var initialized = false
 
   private var _timeWriting = 0L
@@ -120,7 +120,6 @@ class DiskBlockObjectWriter(
     fos = new FileOutputStream(file, true)
     ts = new TimeTrackingOutputStream(fos)
     channel = fos.getChannel()
-    initialPosition = channel.position
     lastValidPosition = initialPosition
     bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
     objOut = serializer.newInstance().serializeStream(bs)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index bcb58ad9467e6c8ff6fcf611ec570edaebb5c735..fcd2e9798295596b48966c4b4b2529f4730a7ea3 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,12 +20,11 @@ package org.apache.spark.storage
 import java.io.File
 import java.text.SimpleDateFormat
 import java.util.{Date, Random}
-import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.Logging
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.util.Utils
 
 /**
  * Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -35,7 +34,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
  *
  * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
  */
-private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
+private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+  extends PathResolver with Logging {
 
   private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@@ -47,54 +47,23 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
   private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
   private var shuffleSender : ShuffleSender = null
 
-  // Stores only Blocks which have been specifically mapped to segments of files
-  // (rather than the default, which maps a Block to a whole file).
-  // This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
-  private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
-
-  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
-
   addShutdownHook()
 
-  /**
-   * Creates a logical mapping from the given BlockId to a segment of a file.
-   * This will cause any accesses of the logical BlockId to be directed to the specified
-   * physical location.
-   */
-  def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
-    blockToFileSegmentMap.put(blockId, fileSegment)
-  }
-
   /**
    * Returns the phyiscal file segment in which the given BlockId is located.
    * If the BlockId has been mapped to a specific FileSegment, that will be returned.
    * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
    */
   def getBlockLocation(blockId: BlockId): FileSegment = {
-    if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
-      blockToFileSegmentMap.get(blockId).get
+    if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
+      shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
     } else {
       val file = getFile(blockId.name)
       new FileSegment(file, 0, file.length())
     }
   }
 
-  /**
-   * Simply returns a File to place the given Block into. This does not physically create the file.
-   * If filename is given, that file will be used. Otherwise, we will use the BlockId to get
-   * a unique filename.
-   */
-  def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
-    val actualFilename = if (filename == "") blockId.name else filename
-    val file = getFile(actualFilename)
-    if (!allowAppending && file.exists()) {
-      throw new IllegalStateException(
-        "Attempted to create file that already exists: " + actualFilename)
-    }
-    file
-  }
-
-  private def getFile(filename: String): File = {
+  def getFile(filename: String): File = {
     // Figure out which local directory it hashes to, and which subdirectory in that
     val hash = Utils.nonNegativeHash(filename)
     val dirId = hash % localDirs.length
@@ -119,6 +88,8 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
     new File(subDir, filename)
   }
 
+  def getFile(blockId: BlockId): File = getFile(blockId.name)
+
   private def createLocalDirs(): Array[File] = {
     logDebug("Creating local directories at root dirs '" + rootDirs + "'")
     val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
@@ -151,10 +122,6 @@ private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver wit
     }
   }
 
-  private def cleanup(cleanupTime: Long) {
-    blockToFileSegmentMap.clearOldValues(cleanupTime)
-  }
-
   private def addShutdownHook() {
     localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
     Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index a3c496f9e05c517f198510095471ab6623b40d22..5a1e7b44440fdac533ae6256ba61c33d70552b7d 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -44,7 +44,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
     val bytes = _bytes.duplicate()
     logDebug("Attempting to put block " + blockId)
     val startTime = System.currentTimeMillis
-    val file = diskManager.createBlockFile(blockId, allowAppending = false)
+    val file = diskManager.getFile(blockId)
    val channel = new FileOutputStream(file).getChannel()
     while (bytes.remaining > 0) {
       channel.write(bytes)
@@ -64,7 +64,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
 
     logDebug("Attempting to write values for block " + blockId)
     val startTime = System.currentTimeMillis
-    val file = diskManager.createBlockFile(blockId, allowAppending = false)
+    val file = diskManager.getFile(blockId)
     val outputStream = new FileOutputStream(file)
     blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
     val length = file.length
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 066e45a12b8c7a8e9784a42eba63c373e1b44378..2f1b049ce4839f631c737d2d6cc0f9947ac2c93c 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -17,33 +17,45 @@
 
 package org.apache.spark.storage
 
+import java.io.File
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
+import scala.collection.JavaConversions._
+
 import org.apache.spark.serializer.Serializer
+import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
+import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 
-private[spark]
-class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+  val writers: Array[BlockObjectWriter]
 
-private[spark]
-trait ShuffleBlocks {
-  def acquireWriters(mapId: Int): ShuffleWriterGroup
-  def releaseWriters(group: ShuffleWriterGroup)
+  /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
+  def releaseWriters(success: Boolean)
 }
 
 /**
- * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
- * per reducer.
+ * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
+ * per reducer (this set of files is called a ShuffleFileGroup).
  *
  * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
 * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
- * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
- * it releases them for another task.
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
 * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
 *   - shuffleId: The unique id given to the entire shuffle stage.
 *   - bucketId: The id of the output partition (i.e., reducer id)
 *   - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
 *     time owns a particular fileId, and this id is returned to a pool when the task finishes.
+ * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
+ * that specifies where in a given file the actual block data is located.
+ *
+ * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within a ShuffleFileGroups associated with the block's reducer.
 */
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
@@ -52,45 +64,152 @@ class ShuffleBlockManager(blockManager: BlockManager) {
   val consolidateShuffleFiles =
     System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
 
-  var nextFileId = new AtomicInteger(0)
-  val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
+  private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
+  /**
+   * Contains all the state related to a particular shuffle. This includes a pool of unused
+   * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
+   */
+  private class ShuffleState() {
+    val nextFileId = new AtomicInteger(0)
+    val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+    val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+  }
+
+  type ShuffleId = Int
+  private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
+
+  private
+  val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
 
-  def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
-    new ShuffleBlocks {
-      // Get a group of writers for a map task.
-      override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
-        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        val fileId = getUnusedFileId()
-        val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+    new ShuffleWriterGroup {
+      shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+      private val shuffleState = shuffleStates(shuffleId)
+      private var fileGroup: ShuffleFileGroup = null
+
+      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+        fileGroup = getUnusedFileGroup()
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-          if (consolidateShuffleFiles) {
-            val filename = physicalFileName(shuffleId, bucketId, fileId)
-            blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
-          } else {
-            blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+        }
+      } else {
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+          val blockFile = blockManager.diskBlockManager.getFile(blockId)
+          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+        }
+      }
+
+      override def releaseWriters(success: Boolean) {
+        if (consolidateShuffleFiles) {
+          if (success) {
+            val offsets = writers.map(_.fileSegment().offset)
+            fileGroup.recordMapOutput(mapId, offsets)
           }
+          recycleFileGroup(fileGroup)
         }
-        new ShuffleWriterGroup(mapId, fileId, writers)
       }
 
-      override def releaseWriters(group: ShuffleWriterGroup) {
-        recycleFileId(group.fileId)
+      private def getUnusedFileGroup(): ShuffleFileGroup = {
+        val fileGroup = shuffleState.unusedFileGroups.poll()
+        if (fileGroup != null) fileGroup else newFileGroup()
+      }
+
+      private def newFileGroup(): ShuffleFileGroup = {
+        val fileId = shuffleState.nextFileId.getAndIncrement()
+        val files = Array.tabulate[File](numBuckets) { bucketId =>
+          val filename = physicalFileName(shuffleId, bucketId, fileId)
+          blockManager.diskBlockManager.getFile(filename)
+        }
+        val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+        shuffleState.allFileGroups.add(fileGroup)
+        fileGroup
       }
-    }
-  }
 
-  private def getUnusedFileId(): Int = {
-    val fileId = unusedFileIds.poll()
-    if (fileId == null) nextFileId.getAndIncrement() else fileId
+      private def recycleFileGroup(group: ShuffleFileGroup) {
+        shuffleState.unusedFileGroups.add(group)
+      }
+    }
   }
 
-  private def recycleFileId(fileId: Int) {
-    if (consolidateShuffleFiles) {
-      unusedFileIds.add(fileId)
+  /**
+   * Returns the physical file segment in which the given BlockId is located.
+   * This function should only be called if shuffle file consolidation is enabled, as it is
+   * an error condition if we don't find the expected block.
+   */
+  def getBlockLocation(id: ShuffleBlockId): FileSegment = {
+    // Search all file groups associated with this shuffle.
+    val shuffleState = shuffleStates(id.shuffleId)
+    for (fileGroup <- shuffleState.allFileGroups) {
+      val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+      if (segment.isDefined) { return segment.get }
    }
+    throw new IllegalStateException("Failed to find shuffle block: " + id)
   }
 
   private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
     "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
   }
-}
+
+  private def cleanup(cleanupTime: Long) {
+    shuffleStates.clearOldValues(cleanupTime)
+  }
+}
+
+private[spark]
+object ShuffleBlockManager {
+  /**
+   * A group of shuffle files, one per reducer.
+   * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
+   */
+  private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+    /**
+     * Stores the absolute index of each mapId in the files of this group. For instance,
+     * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+     */
+    private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+
+    /**
+     * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+     * This ordering allows us to compute block lengths by examining the following block offset.
+     * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every
+     * reducer.
+     */
+    private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) {
+      new PrimitiveVector[Long]()
+    }
+
+    def numBlocks = mapIdToIndex.size
+
+    def apply(bucketId: Int) = files(bucketId)
+
+    def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+      mapIdToIndex(mapId) = numBlocks
+      for (i <- 0 until offsets.length) {
+        blockOffsetsByReducer(i) += offsets(i)
+      }
+    }
+
+    /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+    def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+      val file = files(reducerId)
+      val blockOffsets = blockOffsetsByReducer(reducerId)
+      val index = mapIdToIndex.getOrElse(mapId, -1)
+      if (index >= 0) {
+        val offset = blockOffsets(index)
+        val length =
+          if (index + 1 < numBlocks) {
+            blockOffsets(index + 1) - offset
+          } else {
+            file.length() - offset
+          }
+        assert(length >= 0)
+        Some(new FileSegment(file, offset, length))
+      } else {
+        None
+      }
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 3f963727d98ddd3aca90c8bb327e410dceb6f546..67a7f87a5ca6e40bdb254ebee8c61b6e459c856e 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -59,7 +59,7 @@ object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext
   "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
 
   val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
-    SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
+    SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
 
   type MetadataCleanerType = Value
 
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 4adf9cfb7611204a5c0d63e2ebc1a3d57dbc931c..d76143e45aa58117f8d64fdef20b4c9de28c01ee 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -53,6 +53,12 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
     _values(pos)
   }
 
+  /** Get the value for a given key, or returns elseValue if it doesn't exist. */
+  def getOrElse(k: K, elseValue: V): V = {
+    val pos = _keySet.getPos(k)
+    if (pos >= 0) _values(pos) else elseValue
+  }
+
   /** Set the value for a key */
   def update(k: K, v: V) {
     val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
new file mode 100644
index 0000000000000000000000000000000000000000..369519c5595de8b12a646bea9b26d1704d42f32d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */
+private[spark]
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
+  private var numElements = 0
+  private var array: Array[V] = _
+
+  // NB: This must be separate from the declaration, otherwise the specialized parent class
+  // will get its own array with the same initial size. TODO: Figure out why...
+  array = new Array[V](initialSize)
+
+  def apply(index: Int): V = {
+    require(index < numElements)
+    array(index)
+  }
+
+  def +=(value: V) {
+    if (numElements == array.length) { resize(array.length * 2) }
+    array(numElements) = value
+    numElements += 1
+  }
+
+  def length = numElements
+
+  def getUnderlyingArray = array
+
+  /** Resizes the array, dropping elements if the total length decreases. */
+  def resize(newLength: Int) {
+    val newArray = new Array[V](newLength)
+    array.copyToArray(newArray)
+    array = newArray
+  }
+}
diff --git a/core/src/main/scala/spark/storage/StoragePerfTester.scala b/core/src/main/scala/spark/storage/StoragePerfTester.scala
index 1b074e5ec72951f48f24251b30dd5049111c439d..68893a2bf49b1648d220abe94fae5bf09ea0aa2e 100644
--- a/core/src/main/scala/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/spark/storage/StoragePerfTester.scala
@@ -36,19 +36,19 @@ object StoragePerfTester {
     val blockManager = sc.env.blockManager
 
     def writeOutputBytes(mapId: Int, total: AtomicLong) = {
-      val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+      val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
        new KryoSerializer())
-      val buckets = shuffle.acquireWriters(mapId)
+      val writers = shuffle.writers
       for (i <- 1 to recordsPerMap) {
-        buckets.writers(i % numOutputSplits).write(writeData)
+        writers(i % numOutputSplits).write(writeData)
       }
 
-      buckets.writers.map {w =>
+      writers.map {w =>
        w.commit()
        total.addAndGet(w.fileSegment().length)
        w.close()
      }
-      shuffle.releaseWriters(buckets)
+      shuffle.releaseWriters(true)
    }
 
    val start = System.currentTimeMillis()
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0b9056344c1dd6d4b1b77bf7b2afb2b22a64e84c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -0,0 +1,84 @@
+package org.apache.spark.storage
+
+import java.io.{FileWriter, File}
+
+import scala.collection.mutable
+
+import com.google.common.io.Files
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
+
+  val rootDir0 = Files.createTempDir()
+  rootDir0.deleteOnExit()
+  val rootDir1 = Files.createTempDir()
+  rootDir1.deleteOnExit()
+  val rootDirs = rootDir0.getName + "," + rootDir1.getName
+  println("Created root dirs: " + rootDirs)
+
+  val shuffleBlockManager = new ShuffleBlockManager(null) {
+    var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
+    override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
+  }
+
+  var diskBlockManager: DiskBlockManager = _
+
+  override def beforeEach() {
+    diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
+    shuffleBlockManager.idToSegmentMap.clear()
+  }
+
+  test("basic block creation") {
+    val blockId = new TestBlockId("test")
+    assertSegmentEquals(blockId, blockId.name, 0, 0)
+
+    val newFile = diskBlockManager.getFile(blockId)
+    writeToFile(newFile, 10)
+    assertSegmentEquals(blockId, blockId.name, 0, 10)
+
+    newFile.delete()
+  }
+
+  test("block appending") {
+    val blockId = new TestBlockId("test")
+    val newFile = diskBlockManager.getFile(blockId)
+    writeToFile(newFile, 15)
+    assertSegmentEquals(blockId, blockId.name, 0, 15)
+    val newFile2 = diskBlockManager.getFile(blockId)
+    assert(newFile === newFile2)
+    writeToFile(newFile2, 12)
+    assertSegmentEquals(blockId, blockId.name, 0, 27)
+    newFile.delete()
+  }
+
+  test("block remapping") {
+    val filename = "test"
+    val blockId0 = new ShuffleBlockId(1, 2, 3)
+    val newFile = diskBlockManager.getFile(filename)
+    writeToFile(newFile, 15)
+    shuffleBlockManager.idToSegmentMap(blockId0) = new FileSegment(newFile, 0, 15)
+    assertSegmentEquals(blockId0, filename, 0, 15)
+
+    val blockId1 = new ShuffleBlockId(1, 2, 4)
+    val newFile2 = diskBlockManager.getFile(filename)
+    writeToFile(newFile2, 12)
+    shuffleBlockManager.idToSegmentMap(blockId1) = new FileSegment(newFile, 15, 12)
+    assertSegmentEquals(blockId1, filename, 15, 12)
+
+    assert(newFile === newFile2)
+    newFile.delete()
+  }
+
+  def assertSegmentEquals(blockId: BlockId, filename: String, offset: Int, length: Int) {
+    val segment = diskBlockManager.getBlockLocation(blockId)
+    assert(segment.file.getName === filename)
+    assert(segment.offset === offset)
+    assert(segment.length === length)
+  }
+
+  def writeToFile(file: File, numBytes: Int) {
+    val writer = new FileWriter(file, true)
+    for (i <- 0 until numBytes) writer.write(i)
+    writer.close()
+  }
+}
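For readers following the ShuffleMapTask changes above, the calling pattern expected of a map task under the new ShuffleWriterGroup interface is condensed below. This is a sketch assuming this patch is applied, not code from the patch itself: the helper writeMapOutput, its parameters, and the partition function are illustrative stand-ins for what ShuffleMapTask.runTask does with dep.partitioner and rdd.iterator.

// Sketch of the write path against the new API introduced in this patch. Because the types
// involved are private[spark], this would need to live under the org.apache.spark package.
package org.apache.spark.example

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.ShuffleWriterGroup

object ShuffleWritePathSketch {
  /** Writes `records` for map task `mapId` of shuffle `shuffleId` and returns per-bucket sizes. */
  def writeMapOutput(
      shuffleId: Int,
      mapId: Int,
      numBuckets: Int,
      ser: Serializer,
      partition: Any => Int,
      records: Iterator[Product2[Any, Any]]): Array[Long] = {
    val shuffleBlockManager = SparkEnv.get.blockManager.shuffleBlockManager
    // One ShuffleWriterGroup per map task: numBuckets writers, one per reducer.
    val shuffle: ShuffleWriterGroup =
      shuffleBlockManager.forMapTask(shuffleId, mapId, numBuckets, ser)
    var success = false
    try {
      for (pair <- records) {
        shuffle.writers(partition(pair._1)).write(pair)
      }
      // Commit each writer and record the size of the segment it produced.
      val sizes = shuffle.writers.map { writer =>
        writer.commit()
        writer.fileSegment().length
      }
      success = true
      sizes
    } catch {
      case e: Exception =>
        // Roll back partial output before propagating the failure.
        shuffle.writers.foreach(_.revertPartialWrites())
        throw e
    } finally {
      shuffle.writers.foreach(_.close())
      // `success` tells the manager whether to record this task's block offsets in its file group.
      shuffle.releaseWriters(success)
    }
  }
}

Passing success into releaseWriters is what keeps a failed or reverted task from contributing bogus offsets to the shared ShuffleFileGroup before the group is recycled to the pool.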
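The space savings in the consolidated metadata come from storing a single Long offset per (map task, reducer file) pair and deriving each block's length from the next block's offset, falling back to the file's current length for the last block written. The self-contained sketch below illustrates just that derivation; it mirrors the logic of ShuffleFileGroup.getFileSegmentFor, but the object and method names here are illustrative, not part of the patch.

// Standalone illustration of recovering a block's (offset, length) from the per-file
// list of block offsets that a ShuffleFileGroup keeps. Names are illustrative.
object OffsetBookkeepingSketch {
  /**
   * Given the offsets of consecutive blocks appended to one reducer file (in write order)
   * and the file's current length, return (offset, length) for the block at `index`.
   */
  def segmentFor(offsets: IndexedSeq[Long], fileLength: Long, index: Int): (Long, Long) = {
    val offset = offsets(index)
    val length =
      if (index + 1 < offsets.length) offsets(index + 1) - offset
      else fileLength - offset
    require(length >= 0, "block offsets must be non-decreasing and within the file")
    (offset, length)
  }

  def main(args: Array[String]) {
    // Three map tasks appended blocks of 100, 250 and 75 bytes to the same reducer file.
    val offsets = IndexedSeq(0L, 100L, 350L)
    val fileLength = 425L
    assert(segmentFor(offsets, fileLength, 0) == (0L, 100L))
    assert(segmentFor(offsets, fileLength, 1) == (100L, 250L))
    assert(segmentFor(offsets, fileLength, 2) == (350L, 75L))
    println("all segments recovered")
  }
}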
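The ShuffleBlockManager scaladoc notes that a task releases its file group as soon as it finishes, so the number of physical shuffle files scales with the number of concurrently executing map tasks rather than with the total number of map tasks. The toy simulation below makes that concrete under the simplifying assumption that tasks run in fixed-size waves; it is not part of the patch, and all names in it are illustrative.

import scala.collection.mutable

// Toy model of the per-shuffle file-group pool: tasks run in "waves" of at most
// `maxConcurrent`, each checking a group out and returning it when it finishes.
object FileGroupPoolSketch {
  def filesCreated(numMapTasks: Int, numReducers: Int, maxConcurrent: Int): Int = {
    val pool = mutable.Queue[Int]()     // recycled file-group ids
    var nextGroupId = 0
    var groupsEverCreated = 0

    for (wave <- 0 until numMapTasks by maxConcurrent) {
      val tasksInWave = math.min(maxConcurrent, numMapTasks - wave)
      // Each running task needs its own group; reuse pooled ones before creating new ones.
      val groups = (0 until tasksInWave).map { _ =>
        if (pool.nonEmpty) pool.dequeue()
        else { val id = nextGroupId; nextGroupId += 1; groupsEverCreated += 1; id }
      }
      // Tasks finish and release their groups back to the pool.
      groups.foreach(id => pool.enqueue(id))
    }
    groupsEverCreated * numReducers     // one physical file per (group, reducer)
  }

  def main(args: Array[String]) {
    // 1000 map tasks, 200 reducers, 8 concurrent tasks per executor:
    // 8 * 200 = 1600 files instead of 1000 * 200 = 200000 without consolidation.
    println(filesCreated(numMapTasks = 1000, numReducers = 200, maxConcurrent = 8))
  }
}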