From 6201e5e2493b0f9addba57f60d6ddb88e572b858 Mon Sep 17 00:00:00 2001
From: Aaron Davidson <aaron@databricks.com>
Date: Mon, 4 Nov 2013 09:41:04 -0800
Subject: [PATCH] Refactor ShuffleBlockManager to reduce public interface

- ShuffleBlocks has been removed and replaced by ShuffleWriterGroup.
- ShuffleWriterGroup no longer contains a reference to a ShuffleFileGroup.
- ShuffleFile has been removed and its contents are now within ShuffleFileGroup.
- ShuffleBlockManager.forShuffle has been replaced by a more stateful forMapTask.
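
The net effect for callers is that all writer-group state now lives behind a single per-map-task handle. Below is a rough sketch of the new call pattern, for illustration only: the helper writeMapOutput and its parameters are hypothetical (and ShuffleBlockManager is private[spark], so this only compiles inside the spark package); the real call sites are in the ShuffleMapTask and StoragePerfTester hunks that follow.

    import org.apache.spark.serializer.Serializer
    import org.apache.spark.storage.ShuffleBlockManager

    // Hypothetical helper illustrating the new per-map-task flow; not part of this patch.
    def writeMapOutput(
        shuffleBlockManager: ShuffleBlockManager,
        shuffleId: Int,
        mapId: Int,
        numBuckets: Int,
        ser: Serializer,
        records: Iterator[Product2[Any, Any]],
        partition: Any => Int): Array[Long] = {
      // One ShuffleWriterGroup per map task, one writer per reducer bucket.
      val shuffle = shuffleBlockManager.forMapTask(shuffleId, mapId, numBuckets, ser)
      var success = false
      try {
        records.foreach { pair => shuffle.writers(partition(pair._1)).write(pair) }
        // Commit each bucket and collect the written block sizes.
        val sizes = shuffle.writers.map { w => w.commit(); w.fileSegment().length }
        success = true
        sizes
      } catch { case e: Exception =>
        // Revert partial writes so a failed task leaves no stray data behind.
        shuffle.writers.foreach(_.revertPartialWrites())
        throw e
      } finally {
        // Writers must be closed and handed back so their file group can be reused.
        shuffle.writers.foreach(_.close())
        shuffle.releaseWriters(success)
      }
    }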
---
 .../spark/scheduler/ShuffleMapTask.scala      |  21 +-
 .../spark/storage/ShuffleBlockManager.scala   | 270 +++++++-----------
 .../spark/storage/StoragePerfTester.scala     |  10 +-
 3 files changed, 123 insertions(+), 178 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index c502f8f91a..1dc71a0428 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -146,27 +146,26 @@ private[spark] class ShuffleMapTask(
     metrics = Some(context.taskMetrics)
 
     val blockManager = SparkEnv.get.blockManager
-    var shuffle: ShuffleBlocks = null
-    var buckets: ShuffleWriterGroup = null
+    val shuffleBlockManager = blockManager.shuffleBlockManager
+    var shuffle: ShuffleWriterGroup = null
     var success = false
 
     try {
       // Obtain all the block writers for shuffle blocks.
       val ser = SparkEnv.get.serializerManager.get(dep.serializerClass)
-      shuffle = blockManager.shuffleBlockManager.forShuffle(dep.shuffleId, numOutputSplits, ser)
-      buckets = shuffle.acquireWriters(partitionId)
+      shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
 
       // Write the map output to its associated buckets.
       for (elem <- rdd.iterator(split, context)) {
         val pair = elem.asInstanceOf[Product2[Any, Any]]
         val bucketId = dep.partitioner.getPartition(pair._1)
-        buckets.writers(bucketId).write(pair)
+        shuffle.writers(bucketId).write(pair)
       }
 
       // Commit the writes. Get the size of each bucket block (total block size).
       var totalBytes = 0L
       var totalTime = 0L
-      val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
+      val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
         writer.commit()
         val size = writer.fileSegment().length
         totalBytes += size
@@ -185,15 +184,15 @@ private[spark] class ShuffleMapTask(
     } catch { case e: Exception =>
       // If there is an exception from running the task, revert the partial writes
       // and throw the exception upstream to Spark.
-      if (buckets != null) {
-        buckets.writers.foreach(_.revertPartialWrites())
+      if (shuffle != null) {
+        shuffle.writers.foreach(_.revertPartialWrites())
       }
       throw e
     } finally {
       // Release the writers back to the shuffle block manager.
-      if (shuffle != null && buckets != null) {
-        buckets.writers.foreach(_.close())
-        shuffle.releaseWriters(buckets, success)
+      if (shuffle != null && shuffle.writers != null) {
+        shuffle.writers.foreach(_.close())
+        shuffle.releaseWriters(success)
       }
       // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index a3bb425208..6346db3894 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -18,31 +18,23 @@
 package org.apache.spark.storage
 
 import java.io.File
-import java.util
 import java.util.concurrent.ConcurrentLinkedQueue
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConversions._
-import scala.collection.mutable
 
 import org.apache.spark.Logging
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 
-private[spark]
-class ShuffleWriterGroup(
-   val mapId: Int,
-   val fileGroup: ShuffleFileGroup,
-   val writers: Array[BlockObjectWriter])
-
-private[spark]
-trait ShuffleBlocks {
-  /** Get a group of writers for this map task. */
-  def acquireWriters(mapId: Int): ShuffleWriterGroup
+/** A group of writers for a ShuffleMapTask, one writer per reducer. */
+private[spark] trait ShuffleWriterGroup {
+  val writers: Array[BlockObjectWriter]
 
   /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */
-  def releaseWriters(group: ShuffleWriterGroup, success: Boolean)
+  def releaseWriters(success: Boolean)
 }
 
 /**
@@ -50,9 +42,9 @@ trait ShuffleBlocks {
  * per reducer (this set of files is called a ShuffleFileGroup).
  *
  * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
- * blocks are aggregated into the same file. There is one "combined shuffle file" (ShuffleFile) per
- * reducer per concurrently executing shuffle task. As soon as a task finishes writing to its
- * shuffle files, it releases them for another task.
+ * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
+ * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
+ * files, it releases them for another task.
  * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
  *   - shuffleId: The unique id given to the entire shuffle stage.
  *   - bucketId: The id of the output partition (i.e., reducer id)
@@ -62,10 +54,9 @@ trait ShuffleBlocks {
  * that specifies where in a given file the actual block data is located.
  *
  * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
- * ShuffleBlockIds directly to FileSegments, each ShuffleFile maintains a list of offsets for each
- * block stored in that file. In order to find the location of a shuffle block, we search all
- * ShuffleFiles destined for the block's reducer.
- *
+ * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
+ * each block stored in each file. In order to find the location of a shuffle block, we search the
+ * files within the ShuffleFileGroups associated with the block's reducer.
  */
 private[spark]
 class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
@@ -74,102 +65,74 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   val consolidateShuffleFiles =
     System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
 
+  private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
+
   /**
-   * Contains a pool of unused ShuffleFileGroups.
-   * One group is needed per concurrent thread (mapper) operating on the same shuffle.
+   * Contains all the state related to a particular shuffle. This includes a pool of unused
+   * ShuffleFileGroups, as well as all ShuffleFileGroups that have been created for the shuffle.
    */
-  private class ShuffleFileGroupPool {
-    private val nextFileId = new AtomicInteger(0)
-    private val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
-
-    def getNextFileId() = nextFileId.getAndIncrement()
-    def getUnusedFileGroup() = unusedFileGroups.poll()
-    def returnFileGroup(group: ShuffleFileGroup) = unusedFileGroups.add(group)
+  private class ShuffleState() {
+    val nextFileId = new AtomicInteger(0)
+    val unusedFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
+    val allFileGroups = new ConcurrentLinkedQueue[ShuffleFileGroup]()
   }
 
   type ShuffleId = Int
-  private val shuffleToFileGroupPoolMap = new TimeStampedHashMap[ShuffleId, ShuffleFileGroupPool]
-
-  /**
-   * Maps reducers (of a particular shuffle) to the set of files that have blocks destined for them.
-   * Each reducer will have one ShuffleFile per concurrent thread that executed during mapping.
-   */
-  private val shuffleToReducerToFilesMap =
-    new TimeStampedHashMap[ShuffleId, Array[ConcurrentLinkedQueue[ShuffleFile]]]
+  private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]
 
   private
   val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup)
 
-  def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
-    initializeShuffleMetadata(shuffleId, numBuckets)
-
-    new ShuffleBlocks {
-      override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
-        val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
-        var fileGroup: ShuffleFileGroup = null
-        val writers = if (consolidateShuffleFiles) {
-          fileGroup = getUnusedFileGroup(shuffleId, mapId, numBuckets)
-          Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
-            val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-            blockManager.getDiskWriter(blockId, fileGroup(bucketId).file, serializer, bufferSize)
-          }
-        } else {
-          Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
-            val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-            val blockFile = blockManager.diskBlockManager.getFile(blockId)
-            blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
-          }
+  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) =
+    new ShuffleWriterGroup {
+      shuffleStates.putIfAbsent(shuffleId, new ShuffleState())
+      private val shuffleState = shuffleStates(shuffleId)
+      private var fileGroup: ShuffleFileGroup = null
+
+      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
+        fileGroup = getUnusedFileGroup()
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+        }
+      } else {
+        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
+          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
+          val blockFile = blockManager.diskBlockManager.getFile(blockId)
+          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
         }
-        new ShuffleWriterGroup(mapId, fileGroup, writers)
       }
 
-      override def releaseWriters(group: ShuffleWriterGroup, success: Boolean) {
+      override def releaseWriters(success: Boolean) {
         if (consolidateShuffleFiles) {
-          val fileGroup = group.fileGroup
           if (success) {
-            fileGroup.addMapper(group.mapId)
-            for ((writer, shuffleFile) <- group.writers.zip(fileGroup.files)) {
-              shuffleFile.recordMapOutput(writer.fileSegment().offset)
-            }
+            val offsets = writers.map(_.fileSegment().offset)
+            fileGroup.recordMapOutput(mapId, offsets)
           }
-          recycleFileGroup(shuffleId, fileGroup)
+          recycleFileGroup(fileGroup)
         }
       }
-    }
-  }
 
-  private def initializeShuffleMetadata(shuffleId: Int, numBuckets: Int) {
-    val prev = shuffleToFileGroupPoolMap.putIfAbsent(shuffleId, new ShuffleFileGroupPool())
-    if (!prev.isDefined) {
-      val reducerToFilesMap = new Array[ConcurrentLinkedQueue[ShuffleFile]](numBuckets)
-      for (reducerId <- 0 until numBuckets) {
-        reducerToFilesMap(reducerId) = new ConcurrentLinkedQueue[ShuffleFile]()
+      private def getUnusedFileGroup(): ShuffleFileGroup = {
+        val fileGroup = shuffleState.unusedFileGroups.poll()
+        if (fileGroup != null) fileGroup else newFileGroup()
+      }
+
+      private def newFileGroup(): ShuffleFileGroup = {
+        val fileId = shuffleState.nextFileId.getAndIncrement()
+        val files = Array.tabulate[File](numBuckets) { bucketId =>
+          val filename = physicalFileName(shuffleId, bucketId, fileId)
+          blockManager.diskBlockManager.getFile(filename)
+        }
+        val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files)
+        shuffleState.allFileGroups.add(fileGroup)
+        fileGroup
       }
-      shuffleToReducerToFilesMap.put(shuffleId, reducerToFilesMap)
-    }
-  }
 
-  private def getUnusedFileGroup(shuffleId: Int, mapId: Int, numBuckets: Int): ShuffleFileGroup = {
-    val pool = shuffleToFileGroupPoolMap(shuffleId)
-    val fileGroup = pool.getUnusedFileGroup()
-    if (fileGroup == null) {
-      val fileId = pool.getNextFileId()
-      val files = Array.tabulate[ShuffleFile](numBuckets) { bucketId =>
-        val filename = physicalFileName(shuffleId, bucketId, fileId)
-        val file = blockManager.diskBlockManager.getFile(filename)
-        val shuffleFile = new ShuffleFile(file)
-        shuffleToReducerToFilesMap(shuffleId)(bucketId).add(shuffleFile)
-        shuffleFile
+      private def recycleFileGroup(group: ShuffleFileGroup) {
+        shuffleState.unusedFileGroups.add(group)
       }
-      new ShuffleFileGroup(shuffleId, fileId, files)
-    } else {
-      fileGroup
     }
-  }
-
-  private def recycleFileGroup(shuffleId: Int, fileGroup: ShuffleFileGroup) {
-    shuffleToFileGroupPoolMap(shuffleId).returnFileGroup(fileGroup)
-  }
 
   /**
    * Returns the physical file segment in which the given BlockId is located.
@@ -177,13 +140,12 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
    * an error condition if we don't find the expected block.
    */
   def getBlockLocation(id: ShuffleBlockId): FileSegment = {
-    // Search all files associated with the given reducer.
-    val filesForReducer = shuffleToReducerToFilesMap(id.shuffleId)(id.reduceId)
-    for (file <- filesForReducer) {
-      val segment = file.getFileSegmentFor(id.mapId)
-      if (segment != None) { return segment.get }
+    // Search all file groups associated with this shuffle.
+    val shuffleState = shuffleStates(id.shuffleId)
+    for (fileGroup <- shuffleState.allFileGroups) {
+      val segment = fileGroup.getFileSegmentFor(id.mapId, id.reduceId)
+      if (segment.isDefined) { return segment.get }
     }
-
     throw new IllegalStateException("Failed to find shuffle block: " + id)
   }
 
@@ -192,78 +154,62 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   }
 
   private def cleanup(cleanupTime: Long) {
-    shuffleToFileGroupPoolMap.clearOldValues(cleanupTime)
-    shuffleToReducerToFilesMap.clearOldValues(cleanupTime)
+    shuffleStates.clearOldValues(cleanupTime)
   }
 }
 
-/**
- * A group of shuffle files, one per reducer.
- * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
- */
-private[spark]
-class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[ShuffleFile]) {
-  /**
-   * Stores the absolute index of each mapId in the files of this group. For instance,
-   * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
-   */
-  private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
-
-  files.foreach(_.setShuffleFileGroup(this))
-
-  def apply(bucketId: Int) = files(bucketId)
-
-  def addMapper(mapId: Int) {
-    mapIdToIndex(mapId) = mapIdToIndex.size
-  }
-
-  def indexOf(mapId: Int): Int = mapIdToIndex.getOrElse(mapId, -1)
-}
-
-/**
- * A single, consolidated shuffle file that may contain many actual blocks. All blocks are destined
- * to the same reducer.
- */
 private[spark]
-class ShuffleFile(val file: File) {
+object ShuffleBlockManager {
   /**
-   * Consecutive offsets of blocks into the file, ordered by position in the file.
-   * This ordering allows us to compute block lengths by examining the following block offset.
-   * Note: shuffleFileGroup.indexOf(mapId) returns the index of the mapper into this array.
+   * A group of shuffle files, one per reducer.
+   * A particular mapper will be assigned a single ShuffleFileGroup to write its output to.
    */
-  private val blockOffsets = new PrimitiveVector[Long]()
+  private class ShuffleFileGroup(val shuffleId: Int, val fileId: Int, val files: Array[File]) {
+    /**
+     * Stores the absolute index of each mapId in the files of this group. For instance,
+     * if mapId 5 is the first block in each file, mapIdToIndex(5) = 0.
+     */
+    private val mapIdToIndex = new PrimitiveKeyOpenHashMap[Int, Int]()
+
+    /**
+     * Stores consecutive offsets of blocks into each reducer file, ordered by position in the file.
+     * This ordering allows us to compute block lengths by examining the following block offset.
+     * Note: mapIdToIndex(mapId) gives the index of the mapper's block within every
+     * reducer's offset vector.
+     */
+    private val blockOffsetsByReducer = Array.tabulate[PrimitiveVector[Long]](files.length) { _ =>
+      new PrimitiveVector[Long]()
+    }
 
-  /** Back pointer to whichever ShuffleFileGroup this file is a part of. */
-  private var shuffleFileGroup : ShuffleFileGroup = _
+    def numBlocks = mapIdToIndex.size
 
-  // Required due to circular dependency between ShuffleFileGroup and ShuffleFile.
-  def setShuffleFileGroup(group: ShuffleFileGroup) {
-    assert(shuffleFileGroup == null)
-    shuffleFileGroup = group
-  }
+    def apply(bucketId: Int) = files(bucketId)
 
-  def recordMapOutput(offset: Long) {
-    blockOffsets += offset
-  }
+    def recordMapOutput(mapId: Int, offsets: Array[Long]) {
+      mapIdToIndex(mapId) = numBlocks
+      for (i <- 0 until offsets.length) {
+        blockOffsetsByReducer(i) += offsets(i)
+      }
+    }
 
-  /**
-   * Returns the FileSegment associated with the given map task, or
-   * None if this ShuffleFile does not have an entry for it.
-   */
-  def getFileSegmentFor(mapId: Int): Option[FileSegment] = {
-    val index = shuffleFileGroup.indexOf(mapId)
-    if (index >= 0) {
-      val offset = blockOffsets(index)
-      val length =
-        if (index + 1 < blockOffsets.length) {
-          blockOffsets(index + 1) - offset
-        } else {
-          file.length() - offset
-        }
-      assert(length >= 0)
-      Some(new FileSegment(file, offset, length))
-    } else {
-      None
+    /** Returns the FileSegment associated with the given map task, or None if no entry exists. */
+    def getFileSegmentFor(mapId: Int, reducerId: Int): Option[FileSegment] = {
+      val file = files(reducerId)
+      val blockOffsets = blockOffsetsByReducer(reducerId)
+      val index = mapIdToIndex.getOrElse(mapId, -1)
+      if (index >= 0) {
+        val offset = blockOffsets(index)
+        val length =
+          if (index + 1 < numBlocks) {
+            blockOffsets(index + 1) - offset
+          } else {
+            file.length() - offset
+          }
+        assert(length >= 0)
+        Some(new FileSegment(file, offset, length))
+      } else {
+        None
+      }
     }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 021f6f6688..1e4db4f66b 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -38,19 +38,19 @@ object StoragePerfTester {
     val blockManager = sc.env.blockManager
 
     def writeOutputBytes(mapId: Int, total: AtomicLong) = {
-      val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
+      val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
         new KryoSerializer())
-      val buckets = shuffle.acquireWriters(mapId)
+      val writers = shuffle.writers
       for (i <- 1 to recordsPerMap) {
-        buckets.writers(i % numOutputSplits).write(writeData)
+        writers(i % numOutputSplits).write(writeData)
       }
-      buckets.writers.map {w =>
+      writers.map {w =>
         w.commit()
         total.addAndGet(w.fileSegment().length)
         w.close()
       }
 
-      shuffle.releaseWriters(buckets, true)
+      shuffle.releaseWriters(true)
     }
 
     val start = System.currentTimeMillis()
-- 
GitLab