From 9d3f05a990beacadea00c68f9cf7ff82f93b0a44 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Mon, 8 Nov 2010 00:45:02 -0800
Subject: [PATCH] Made shuffle algorithm pluggable and added LocalFileShuffle.

---
 src/scala/spark/DfsShuffle.scala       |  71 ++++--------
 src/scala/spark/LocalFileShuffle.scala | 152 +++++++++++++++++++++++++
 src/scala/spark/NumberedSplitRDD.scala |  42 +++++++
 src/scala/spark/RDD.scala              |   8 +-
 src/scala/spark/Shuffle.scala          |  15 +++
 src/scala/spark/Utils.scala            |   5 +-
 6 files changed, 237 insertions(+), 56 deletions(-)
 create mode 100644 src/scala/spark/LocalFileShuffle.scala
 create mode 100644 src/scala/spark/NumberedSplitRDD.scala
 create mode 100644 src/scala/spark/Shuffle.scala

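Usage note: with this patch the shuffle implementation is chosen at runtime from
the spark.shuffle.class system property (default spark.DfsShuffle; see the
RDD.scala hunk below). A minimal sketch of switching to the new local-file
shuffle; the job itself is a made-up example and assumes the usual
import spark.SparkContext._ for the pair-RDD conversions:

    // Pick the shuffle implementation before any shuffled RDD is computed.
    System.setProperty("spark.shuffle.class", "spark.LocalFileShuffle")

    val sc = new SparkContext("local", "shuffleDemo")
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
    // reduceByKey goes through combineByKey, which instantiates the configured
    // Shuffle[K, V, C] reflectively and calls its compute() method.
    val counts = pairs.reduceByKey(_ + _, 2)
    counts.foreach(kv => println(kv))
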
diff --git a/src/scala/spark/DfsShuffle.scala b/src/scala/spark/DfsShuffle.scala
index 2ef0321a63..7a42bf2d06 100644
--- a/src/scala/spark/DfsShuffle.scala
+++ b/src/scala/spark/DfsShuffle.scala
@@ -9,8 +9,6 @@ import scala.collection.mutable.HashMap
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
 
-import mesos.SlaveOffer
-
 
 /**
  * A simple implementation of shuffle using a distributed file system.
@@ -18,20 +16,19 @@ import mesos.SlaveOffer
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class DfsShuffle[K, V, C](
-  rdd: RDD[(K, V)],
-  numOutputSplits: Int,
-  createCombiner: V => C,
-  mergeValue: (C, V) => C,
-  mergeCombiners: (C, C) => C)
-extends Logging
-{
-  def compute(): RDD[(K, C)] = {
-    val sc = rdd.sparkContext
+class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
     val dir = DfsShuffle.newTempDirectory()
     logInfo("Intermediate data directory: " + dir)
 
-    val numberedSplitRdd = new NumberedSplitRDD(rdd)
+    val numberedSplitRdd = new NumberedSplitRDD(input)
     val numInputSplits = numberedSplitRdd.splits.size
 
     // Run a parallel foreach to write the intermediate data files
@@ -78,6 +75,7 @@ extends Logging
         } catch {
           case e: EOFException => {}
         }
+        inputStream.close()
       }
       combiners
     })
@@ -85,9 +83,14 @@ extends Logging
 }
 
 
+/**
+ * Companion object of DfsShuffle; responsible for initializing a Hadoop
+ * FileSystem object based on the spark.dfs property and generating names
+ * for temporary directories.
+ */
 object DfsShuffle {
-  var initialized = false
-  var fileSystem: FileSystem = null
+  private var initialized = false
+  private var fileSystem: FileSystem = null
 
   private def initializeIfNeeded() = synchronized {
     if (!initialized) {
@@ -97,8 +100,8 @@ object DfsShuffle {
       conf.setInt("io.file.buffer.size", bufferSize)
       conf.setInt("dfs.replication", 1)
       fileSystem = FileSystem.get(new URI(dfs), conf)
+      initialized = true
     }
-    initialized = true
   }
 
   def getFileSystem(): FileSystem = {
@@ -115,39 +118,3 @@ object DfsShuffle {
     return path
   }
 }
-
-
-/**
- * An RDD that captures the splits of a parent RDD and gives them unique indexes.
- * This is useful for a variety of shuffle implementations.
- */
-class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
-extends RDD[(Int, Iterator[T])](prev.sparkContext) {
-  @transient val splits_ = {
-    prev.splits.zipWithIndex.map {
-      case (s, i) => new NumberedSplitRDDSplit(s, i): Split
-    }.toArray
-  }
-
-  override def splits = splits_
-
-  override def preferredLocations(split: Split) = {
-    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
-    prev.preferredLocations(nsplit.prev)
-  }
-
-  override def iterator(split: Split) = {
-    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
-    Iterator((nsplit.index, prev.iterator(nsplit.prev)))
-  }
-
-  override def taskStarted(split: Split, slot: SlaveOffer) = {
-    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
-    prev.taskStarted(nsplit.prev, slot)
-  }
-}
-
-
-class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
-  override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
-}
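
Configuration note: as the new companion-object comment above says, DfsShuffle
resolves its Hadoop FileSystem from the spark.dfs property, so the intermediate
shuffle files land wherever that URI points. A minimal sketch (the namenode URI
is a made-up example):

    // Must be set before the first shuffle so initializeIfNeeded() picks it up.
    System.setProperty("spark.dfs", "hdfs://namenode:9000/")
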
diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/LocalFileShuffle.scala
new file mode 100644
index 0000000000..db6ae322f1
--- /dev/null
+++ b/src/scala/spark/LocalFileShuffle.scala
@@ -0,0 +1,152 @@
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import spark.SparkContext._
+
+
+/**
+ * A simple implementation of shuffle using local files served through HTTP.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
+    val shuffleId = LocalFileShuffle.newShuffleId()
+    logInfo("Shuffle ID: " + shuffleId)
+
+    val splitRdd = new NumberedSplitRDD(input)
+    val numInputSplits = splitRdd.splits.size
+
+    // Run a parallel map and collect to write the intermediate data files,
+    // returning a hash table of inputSplitId -> serverUri pairs
+    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+      val myIndex = pair._1
+      val myIterator = pair._2
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+      for ((k, v) <- myIterator) {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
+          case Some(c) => mergeValue(c, v)
+          case None => createCombiner(v)
+        }
+      }
+      for (i <- 0 until numOutputSplits) {
+        val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+        val out = new ObjectOutputStream(new FileOutputStream(file))
+        buckets(i).foreach(pair => out.writeObject(pair))
+        out.close()
+      }
+      (myIndex, LocalFileShuffle.getServerUri())
+    }).collectAsMap()
+
+    // Build a hashmap from server URI to list of splits (to facilitate
+    // fetching all the splits on a server within a single connection)
+    val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
+    for ((inputId, serverUri) <- outputLocs) {
+      splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
+    }
+
+    // TODO: Could broadcast splitsByUri
+
+    // Return an RDD that does each of the merges for a given partition
+    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+    return indexes.flatMap((myId: Int) => {
+      val combiners = new HashMap[K, C]
+      for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
+        for (i <- inputIds) {
+          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
+          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          try {
+            while (true) {
+              val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+              combiners(k) = combiners.get(k) match {
+                case Some(oldC) => mergeCombiners(oldC, c)
+                case None => c
+              }
+            }
+          } catch {
+            case e: EOFException => {}
+          }
+          inputStream.close()
+        }
+      }
+      combiners
+    })
+  }
+}
+
+
+object LocalFileShuffle extends Logging {
+  private var initialized = false
+  private val nextShuffleId = new AtomicLong(0)
+
+  // Variables initialized by initializeIfNeeded()
+  private var localDir: File = null
+  private var server: HttpServer = null
+  private var serverUri: String = null
+
+  private def initializeIfNeeded() = synchronized {
+    if (!initialized) {
+      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+      var tries = 0
+      var foundLocalDir = false
+      while (!foundLocalDir && tries < 10) {
+        tries += 1
+        try {
+          localDir = new File(localDirRoot, "spark-local-" + UUID.randomUUID())
+          if (!localDir.exists()) {
+            localDir.mkdirs()
+            foundLocalDir = true
+          }
+        } catch {
+          case e: Exception =>
+            logWarning("Attempt " + tries + " to create local dir failed", e)
+        }
+      }
+      if (!foundLocalDir) {
+        logError("Failed 10 attempts to create local dir in " + localDirRoot)
+        System.exit(1)
+      }
+      logInfo("Local dir: " + localDir)
+      server = new HttpServer(localDir)
+      server.start()
+      serverUri = server.uri
+      initialized = true
+    }
+  }
+
+  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+    initializeIfNeeded()
+    val dir = new File(localDir, "shuffle/" + shuffleId + "/" + inputId)
+    dir.mkdirs()
+    val file = new File(dir, "" + outputId)
+    return file
+  }
+
+  def getServerUri(): String = {
+    initializeIfNeeded()
+    serverUri
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+}
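
Layout note: the map and reduce sides above agree on a simple naming scheme.
The map task for input split i writes bucket j under the local dir as
shuffle/<shuffleId>/<i>/<j>, and reducer j later fetches the same path as
<serverUri>/shuffle/<shuffleId>/<i>/<j> over HTTP. A standalone sketch of the
two small pieces of arithmetic involved, mirroring the code above (the helper
names are made up for illustration):

    // Which output split a key belongs to: hashCode can be negative, so fold
    // it back into [0, numOutputSplits), as compute() does.
    def bucketOf(key: Any, numOutputSplits: Int): Int = {
      val h = key.hashCode % numOutputSplits
      if (h < 0) h + numOutputSplits else h
    }

    // The URL a reducer requests for one map output, following the same
    // format string as compute().
    def fetchUrl(serverUri: String, shuffleId: Long, inputId: Int, outputId: Int): String =
      "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, inputId, outputId)
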
diff --git a/src/scala/spark/NumberedSplitRDD.scala b/src/scala/spark/NumberedSplitRDD.scala
new file mode 100644
index 0000000000..7b12210d84
--- /dev/null
+++ b/src/scala/spark/NumberedSplitRDD.scala
@@ -0,0 +1,42 @@
+package spark
+
+import mesos.SlaveOffer
+
+
+/**
+ * An RDD that takes the splits of a parent RDD and gives them unique indexes.
+ * This is useful for a variety of shuffle implementations.
+ */
+class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
+extends RDD[(Int, Iterator[T])](prev.sparkContext) {
+  @transient val splits_ = {
+    prev.splits.zipWithIndex.map {
+      case (s, i) => new NumberedSplitRDDSplit(s, i): Split
+    }.toArray
+  }
+
+  override def splits = splits_
+
+  override def preferredLocations(split: Split) = {
+    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+    prev.preferredLocations(nsplit.prev)
+  }
+
+  override def iterator(split: Split) = {
+    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+    Iterator((nsplit.index, prev.iterator(nsplit.prev)))
+  }
+
+  override def taskStarted(split: Split, slot: SlaveOffer) = {
+    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
+    prev.taskStarted(nsplit.prev, slot)
+  }
+}
+
+
+/**
+ * A split in a NumberedSplitRDD.
+ */
+class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
+  override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
+}
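
Usage note: NumberedSplitRDD is what lets a shuffle's map side know which input
split it is processing; both shuffles in this patch pattern-match on the
(index, iterator) pairs it yields and use the index to name their output files.
A minimal sketch of that contract (the input RDD is a made-up example):

    val sc = new SparkContext("local", "numberedSplitDemo")
    val input = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)), 2)
    val numbered = new NumberedSplitRDD(input)
    // One element per parent split: a stable index plus an iterator over
    // that split's records.
    numbered.foreach(pair => {
      val (splitIndex, records) = pair
      records.foreach(kv => println("split " + splitIndex + ": " + kv))
    })
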
diff --git a/src/scala/spark/RDD.scala b/src/scala/spark/RDD.scala
index 9650ea9d8b..bac59319a0 100644
--- a/src/scala/spark/RDD.scala
+++ b/src/scala/spark/RDD.scala
@@ -356,8 +356,12 @@ extends RDD[Pair[T, U]](sc) {
                       mergeValue: (C, V) => C,
                       mergeCombiners: (C, C) => C,
                       numSplits: Int)
-  : RDD[(K, C)] = {
-    new DfsShuffle(self, numSplits, createCombiner, mergeValue, mergeCombiners).compute()
+  : RDD[(K, C)] =
+  {
+    val shufClass = Class.forName(System.getProperty(
+      "spark.shuffle.class", "spark.DfsShuffle"))
+    val shuf = shufClass.newInstance().asInstanceOf[Shuffle[K, V, C]]
+    shuf.compute(self, numSplits, createCombiner, mergeValue, mergeCombiners)
   }
 
   def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
diff --git a/src/scala/spark/Shuffle.scala b/src/scala/spark/Shuffle.scala
new file mode 100644
index 0000000000..4c5649b537
--- /dev/null
+++ b/src/scala/spark/Shuffle.scala
@@ -0,0 +1,15 @@
+package spark
+
+/**
+ * A trait for shuffle systems. Given an input RDD and combiner functions
+ * for PairRDDExtras.combineByKey(), returns an output RDD.
+ */
+@serializable
+trait Shuffle[K, V, C] {
+  def compute(input: RDD[(K, V)],
+              numOutputSplits: Int,
+              createCombiner: V => C,
+              mergeValue: (C, V) => C,
+              mergeCombiners: (C, C) => C)
+  : RDD[(K, C)]
+}
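
Extension note: the trait above is the whole contract for plugging in an
alternative shuffle; the concrete class just needs a no-arg constructor, since
RDD.scala instantiates it reflectively with newInstance(). A deliberately thin
sketch of a custom implementation that only adds logging and delegates the real
work (the class name is made up; it is not part of this patch):

    package spark

    @serializable
    class LoggingShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
      override def compute(input: RDD[(K, V)],
                           numOutputSplits: Int,
                           createCombiner: V => C,
                           mergeValue: (C, V) => C,
                           mergeCombiners: (C, C) => C): RDD[(K, C)] = {
        logInfo("Shuffling into " + numOutputSplits + " output splits")
        // Delegate to one of the bundled implementations.
        new DfsShuffle[K, V, C]().compute(
          input, numOutputSplits, createCombiner, mergeValue, mergeCombiners)
      }
    }

Selecting it is then just a matter of -Dspark.shuffle.class=spark.LoggingShuffle
(or the equivalent System.setProperty call) before the job runs.
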
diff --git a/src/scala/spark/Utils.scala b/src/scala/spark/Utils.scala
index 1b2fe50c0e..025472633b 100644
--- a/src/scala/spark/Utils.scala
+++ b/src/scala/spark/Utils.scala
@@ -100,8 +100,9 @@ object Utils {
   // Shuffle the elements of a collection into a random order, returning the
   // result in a new collection. Unlike scala.util.Random.shuffle, this method
   // uses a local random number generator, avoiding inter-thread contention.
-  def shuffle[T](seq: Seq[T]): Seq[T] = {
-    val buf = ArrayBuffer(seq: _*)
+  def shuffle[T](seq: TraversableOnce[T]): Seq[T] = {
+    val buf = new ArrayBuffer[T]()
+    buf ++= seq
     val rand = new Random()
     for (i <- (buf.size - 1) to 1 by -1) {
       val j = rand.nextInt(i)
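
Note: widening the parameter from Seq[T] to TraversableOnce[T] is what lets
LocalFileShuffle pass its HashMap of serverUri -> splits straight in as
Utils.shuffle(splitsByUri), so each reducer visits the map-output servers in a
randomized order instead of every reducer hitting the same server first. A
made-up example of that call shape:

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    val splitsByUri = HashMap(
      "http://host1:33000" -> ArrayBuffer(0, 2),
      "http://host2:33000" -> ArrayBuffer(1))
    for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
      // ...fetch each input split's bucket from serverUri...
    }
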
-- 
GitLab