diff --git a/src/scala/spark/BasicLocalFileShuffle.scala b/src/scala/spark/BasicLocalFileShuffle.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6d8b42e58bbc2638b71db6c50ba2a584bbdec555
--- /dev/null
+++ b/src/scala/spark/BasicLocalFileShuffle.scala
@@ -0,0 +1,182 @@
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+
+/**
+ * A simple implementation of shuffle using local files served through HTTP.
+ *
+ * TODO: Add support for compression when spark.compress is set to true.
+ */
+@serializable
+class BasicLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
+    val shuffleId = BasicLocalFileShuffle.newShuffleId()
+    logInfo("Shuffle ID: " + shuffleId)
+
+    val splitRdd = new NumberedSplitRDD(input)
+    val numInputSplits = splitRdd.splits.size
+
+    // Run a parallel map and collect to write the intermediate data files,
+    // returning a list of inputSplitId -> serverUri pairs
+    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
+      val myIndex = pair._1
+      val myIterator = pair._2
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
+      for ((k, v) <- myIterator) {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
+          case Some(c) => mergeValue(c, v)
+          case None => createCombiner(v)
+        }
+      }
+      for (i <- 0 until numOutputSplits) {
+        val file = BasicLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+        val writeStartTime = System.currentTimeMillis
+        logInfo ("BEGIN WRITE: " + file)
+        val out = new ObjectOutputStream(new FileOutputStream(file))
+        buckets(i).foreach(pair => out.writeObject(pair))
+        out.close()
+        logInfo ("END WRITE: " + file)
+        val writeTime = (System.currentTimeMillis - writeStartTime)
+        logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
+      }
+      (myIndex, BasicLocalFileShuffle.serverUri)
+    }).collect()
+
+    // Build a hashmap from server URI to list of splits (to facillitate
+    // fetching all the URIs on a server within a single connection)
+    val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
+    for ((inputId, serverUri) <- outputLocs) {
+      splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
+    }
+
+    // TODO: Could broadcast splitsByUri
+
+    // Return an RDD that does each of the merges for a given partition
+    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
+    return indexes.flatMap((myId: Int) => {
+      val combiners = new HashMap[K, C]
+      for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
+        for (i <- inputIds) {
+          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
+          val readStartTime = System.currentTimeMillis
+          logInfo ("BEGIN READ: " + url)
+          val inputStream = new ObjectInputStream(new URL(url).openStream())
+          try {
+            while (true) {
+              val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
+              combiners(k) = combiners.get(k) match {
+                case Some(oldC) => mergeCombiners(oldC, c)
+                case None => c
+              }
+            }
+          } catch {
+            case e: EOFException => {}
+          }
+          inputStream.close()
+          logInfo ("END READ: " + url)
+          val readTime = (System.currentTimeMillis - readStartTime)
+          logInfo ("Reading " + url + " took " + readTime + " millis.")
+        }
+      }
+      combiners
+    })
+  }
+}
+
+
+object BasicLocalFileShuffle extends Logging {
+  private var initialized = false
+  private var nextShuffleId = new AtomicLong(0)
+
+  // Variables initialized by initializeIfNeeded()
+  private var shuffleDir: File = null
+  private var server: HttpServer = null
+  private var serverUri: String = null
+
+  private def initializeIfNeeded() = synchronized {
+    if (!initialized) {
+      // TODO: localDir should be created by some mechanism common to Spark
+      // so that it can be shared among shuffle, broadcast, etc
+      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+      var tries = 0
+      var foundLocalDir = false
+      var localDir: File = null
+      var localDirUuid: UUID = null
+      while (!foundLocalDir && tries < 10) {
+        tries += 1
+        try {
+          localDirUuid = UUID.randomUUID()
+          localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+          if (!localDir.exists()) {
+            localDir.mkdirs()
+            foundLocalDir = true
+          }
+        } catch {
+          case e: Exception =>
+            logWarning("Attempt " + tries + " to create local dir failed", e)
+        }
+      }
+      if (!foundLocalDir) {
+        logError("Failed 10 attempts to create local dir in " + localDirRoot)
+        System.exit(1)
+      }
+      shuffleDir = new File(localDir, "shuffle")
+      shuffleDir.mkdirs()
+      logInfo("Shuffle dir: " + shuffleDir)
+      val extServerPort = System.getProperty(
+        "spark.localFileShuffle.external.server.port", "-1").toInt
+      if (extServerPort != -1) {
+        // We're using an external HTTP server; set URI relative to its root
+        var extServerPath = System.getProperty(
+          "spark.localFileShuffle.external.server.path", "")
+        if (extServerPath != "" && !extServerPath.endsWith("/")) {
+          extServerPath += "/"
+        }
+        serverUri = "http://%s:%d/%s/spark-local-%s".format(
+          Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+      } else {
+        // Create our own server
+        server = new HttpServer(localDir)
+        server.start()
+        serverUri = server.uri
+      }
+      initialized = true
+      logInfo ("Local URI: " + serverUri)
+    }
+  }
+
+  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+    initializeIfNeeded()
+    val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+    dir.mkdirs()
+    val file = new File(dir, "" + outputId)
+    return file
+  }
+
+  def getServerUri(): String = {
+    initializeIfNeeded()
+    serverUri
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+}
diff --git a/src/scala/spark/LocalFileShuffle.scala b/src/scala/spark/ParallelLocalFileShuffle.scala
similarity index 94%
rename from src/scala/spark/LocalFileShuffle.scala
rename to src/scala/spark/ParallelLocalFileShuffle.scala
index b70315deffb64f75144a12fb01f63f43c36dc668..208fad10739815174fc6649300db2cab7535d629 100644
--- a/src/scala/spark/LocalFileShuffle.scala
+++ b/src/scala/spark/ParallelLocalFileShuffle.scala
@@ -15,7 +15,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap}
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+class ParallelLocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
   @transient var totalSplits = 0
   @transient var hasSplits = 0
   @transient var hasSplitsBitVector: BitSet = null
@@ -31,7 +31,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
   : RDD[(K, C)] =
   {
     val sc = input.sparkContext
-    val shuffleId = LocalFileShuffle.newShuffleId()
+    val shuffleId = ParallelLocalFileShuffle.newShuffleId()
     logInfo("Shuffle ID: " + shuffleId)
 
     val splitRdd = new NumberedSplitRDD(input)
@@ -55,7 +55,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
         }
       }
       for (i <- 0 until numOutputSplits) {
-        val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
+        val file = ParallelLocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
         val writeStartTime = System.currentTimeMillis
         logInfo ("BEGIN WRITE: " + file)
         val out = new ObjectOutputStream(new FileOutputStream(file))
@@ -65,7 +65,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
         val writeTime = (System.currentTimeMillis - writeStartTime)
         logInfo ("Writing " + file + " of size " + file.length + " bytes took " + writeTime + " millis.")
       }
-      (myIndex, LocalFileShuffle.serverUri)
+      (myIndex, ParallelLocalFileShuffle.serverUri)
     }).collect()
 
     // Load config option to decide whether or not to use HTTP pipelining
@@ -102,12 +102,12 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
       splitsInRequestBitVector = new BitSet (totalSplits)
       combiners = new HashMap[K, C]
 
-      var threadPool = LocalFileShuffle.newDaemonFixedThreadPool (
-        LocalFileShuffle.MaxConnections)
+      var threadPool = ParallelLocalFileShuffle.newDaemonFixedThreadPool (
+        ParallelLocalFileShuffle.MaxConnections)
 
       while (hasSplits < totalSplits) {
         var numThreadsToCreate =
-          Math.min (totalSplits, LocalFileShuffle.MaxConnections) -
+          Math.min (totalSplits, ParallelLocalFileShuffle.MaxConnections) -
           threadPool.getActiveCount
 
         while (hasSplits < totalSplits && numThreadsToCreate > 0) {
@@ -130,7 +130,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
         }
 
         // Sleep for a while before creating new threads
-        Thread.sleep (LocalFileShuffle.MinKnockInterval)
+        Thread.sleep (ParallelLocalFileShuffle.MinKnockInterval)
       }
       combiners
     })
@@ -148,7 +148,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
     }
 
     if (requiredSplits.size > 0) {
-      requiredSplits(LocalFileShuffle.ranGen.nextInt (requiredSplits.size))
+      requiredSplits(ParallelLocalFileShuffle.ranGen.nextInt (requiredSplits.size))
     } else {
       -1
     }
@@ -222,7 +222,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
 }
 
 
-object LocalFileShuffle extends Logging {
+object ParallelLocalFileShuffle extends Logging {
   // Used thoughout the code for small and large waits/timeouts
   private var MinKnockInterval_ = 1000
   private var MaxKnockInterval_ = 5000
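
For readers skimming the patch, the heart of the map side in both shuffle implementations is the hash-partition-and-combine loop in `compute`. The self-contained sketch below is not part of the patch; it replays that loop on a plain Scala collection, with made-up word-count records and combiner functions, so the bucketing and merge behavior can be run and inspected in isolation.

```scala
import scala.collection.mutable.HashMap

// Standalone sketch (illustrative only) of the map-side logic in compute():
// hash each key into one of numOutputSplits buckets, fixing up negative hash
// codes, and fold values into a per-bucket combiner (here a running count).
object BucketingSketch {
  def main(args: Array[String]): Unit = {
    val numOutputSplits = 4
    val records = Seq(("the", 1), ("quick", 1), ("the", 1), ("fox", 1))

    // One combiner map per output split, mirroring the buckets array in compute()
    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[String, Int])
    for ((k, v) <- records) {
      var bucketId = k.hashCode % numOutputSplits
      if (bucketId < 0) {  // hashCode may be negative; wrap into a valid bucket
        bucketId += numOutputSplits
      }
      val bucket = buckets(bucketId)
      bucket(k) = bucket.get(k) match {
        case Some(c) => c + v  // mergeValue
        case None => v         // createCombiner
      }
    }
    for (i <- 0 until numOutputSplits) {
      println("bucket " + i + ": " + buckets(i))
    }
  }
}
```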
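And a minimal sketch of how either implementation might be invoked, assuming a pair RDD `pairs: RDD[(String, Int)]` has already been built elsewhere with this codebase's SparkContext; the word-count combiner functions are only illustrative and are not something this patch defines.

```scala
// Hypothetical word-count style invocation (not part of the patch), assuming
// an existing pairs: RDD[(String, Int)] from this codebase's SparkContext.
val counts: RDD[(String, Int)] =
  new BasicLocalFileShuffle[String, Int, Int]().compute(
    pairs,
    4,                              // numOutputSplits: number of reduce partitions
    (v: Int) => v,                  // createCombiner: first value seeds the combiner
    (c: Int, v: Int) => c + v,      // mergeValue: fold further map-side values in
    (c1: Int, c2: Int) => c1 + c2)  // mergeCombiners: merge combiners across input splits
```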