Commit 9d3f05a9 authored by Matei Zaharia

Made shuffle algorithm pluggable and added LocalFileShuffle.

parent d9ea6d69
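Usage note: with this change the shuffle backend is chosen at runtime from the spark.shuffle.class system property (defaulting to spark.DfsShuffle; see the RDD.scala hunk below). A minimal sketch of opting into the new LocalFileShuffle follows; the "local" master and job name are placeholder assumptions, not part of this commit:

import spark.SparkContext
import spark.SparkContext._

// Pick the shuffle implementation before running any shuffle-based operation.
System.setProperty("spark.shuffle.class", "spark.LocalFileShuffle")

val sc = new SparkContext("local", "shuffleDemo")
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)), 2)
// reduceByKey goes through PairRDDExtras.combineByKey, which now instantiates
// the configured Shuffle implementation via reflection.
val sums = pairs.reduceByKey(_ + _, 2).collect()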
@@ -9,8 +9,6 @@ import scala.collection.mutable.HashMap
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path, RawLocalFileSystem}
-import mesos.SlaveOffer
 /**
  * A simple implementation of shuffle using a distributed file system.
@@ -18,20 +16,19 @@ import mesos.SlaveOffer
  * TODO: Add support for compression when spark.compress is set to true.
  */
 @serializable
-class DfsShuffle[K, V, C](
-    rdd: RDD[(K, V)],
-    numOutputSplits: Int,
-    createCombiner: V => C,
-    mergeValue: (C, V) => C,
-    mergeCombiners: (C, C) => C)
-extends Logging
-{
-  def compute(): RDD[(K, C)] = {
-    val sc = rdd.sparkContext
+class DfsShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
+  override def compute(input: RDD[(K, V)],
+                       numOutputSplits: Int,
+                       createCombiner: V => C,
+                       mergeValue: (C, V) => C,
+                       mergeCombiners: (C, C) => C)
+  : RDD[(K, C)] =
+  {
+    val sc = input.sparkContext
     val dir = DfsShuffle.newTempDirectory()
     logInfo("Intermediate data directory: " + dir)
-    val numberedSplitRdd = new NumberedSplitRDD(rdd)
+    val numberedSplitRdd = new NumberedSplitRDD(input)
     val numInputSplits = numberedSplitRdd.splits.size
     // Run a parallel foreach to write the intermediate data files
@@ -78,6 +75,7 @@ extends Logging
         } catch {
           case e: EOFException => {}
         }
+        inputStream.close()
       }
       combiners
     })
@@ -85,9 +83,14 @@ extends Logging
 }
+/**
+ * Companion object of DfsShuffle; responsible for initializing a Hadoop
+ * FileSystem object based on the spark.dfs property and generating names
+ * for temporary directories.
+ */
 object DfsShuffle {
-  var initialized = false
-  var fileSystem: FileSystem = null
+  private var initialized = false
+  private var fileSystem: FileSystem = null
   private def initializeIfNeeded() = synchronized {
     if (!initialized) {
@@ -97,8 +100,8 @@ object DfsShuffle {
       conf.setInt("io.file.buffer.size", bufferSize)
       conf.setInt("dfs.replication", 1)
       fileSystem = FileSystem.get(new URI(dfs), conf)
-      initialized = true
     }
+    initialized = true
   }
   def getFileSystem(): FileSystem = {
@@ -115,39 +118,3 @@ object DfsShuffle {
     return path
   }
 }
-/**
- * An RDD that captures the splits of a parent RDD and gives them unique indexes.
- * This is useful for a variety of shuffle implementations.
- */
-class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
-extends RDD[(Int, Iterator[T])](prev.sparkContext) {
-  @transient val splits_ = {
-    prev.splits.zipWithIndex.map {
-      case (s, i) => new NumberedSplitRDDSplit(s, i): Split
-    }.toArray
-  }
-  override def splits = splits_
-  override def preferredLocations(split: Split) = {
-    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
-    prev.preferredLocations(nsplit.prev)
-  }
-  override def iterator(split: Split) = {
-    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
-    Iterator((nsplit.index, prev.iterator(nsplit.prev)))
-  }
-  override def taskStarted(split: Split, slot: SlaveOffer) = {
-    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
-    prev.taskStarted(nsplit.prev, slot)
-  }
-}
-class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
-  override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
-}
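For reference, both DfsShuffle and the new LocalFileShuffle assign each record to a reduce bucket by key hash, wrapping negative hash codes back into range. The helper below is only an illustration of that bucketing rule, not code from this commit:

// Illustration only: mirrors the bucketId computation used by the map-side
// writers in DfsShuffle and LocalFileShuffle.
def bucketOf(key: Any, numOutputSplits: Int): Int = {
  var bucketId = key.hashCode % numOutputSplits
  if (bucketId < 0) { // hashCode may be negative; shift into [0, numOutputSplits)
    bucketId += numOutputSplits
  }
  bucketId
}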
package spark

import java.io._
import java.net.URL
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable.{ArrayBuffer, HashMap}

import spark.SparkContext._

/**
 * A simple implementation of shuffle using local files served through HTTP.
 *
 * TODO: Add support for compression when spark.compress is set to true.
 */
@serializable
class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
  override def compute(input: RDD[(K, V)],
                       numOutputSplits: Int,
                       createCombiner: V => C,
                       mergeValue: (C, V) => C,
                       mergeCombiners: (C, C) => C)
  : RDD[(K, C)] =
  {
    val sc = input.sparkContext
    val shuffleId = LocalFileShuffle.newShuffleId()
    logInfo("Shuffle ID: " + shuffleId)

    val splitRdd = new NumberedSplitRDD(input)
    val numInputSplits = splitRdd.splits.size

    // Run a parallel map and collect to write the intermediate data files,
    // returning a hash table of inputSplitId -> serverUri pairs
    val outputLocs = splitRdd.map((pair: (Int, Iterator[(K, V)])) => {
      val myIndex = pair._1
      val myIterator = pair._2
      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
      for ((k, v) <- myIterator) {
        var bucketId = k.hashCode % numOutputSplits
        if (bucketId < 0) { // Fix bucket ID if hash code was negative
          bucketId += numOutputSplits
        }
        val bucket = buckets(bucketId)
        bucket(k) = bucket.get(k) match {
          case Some(c) => mergeValue(c, v)
          case None => createCombiner(v)
        }
      }
      for (i <- 0 until numOutputSplits) {
        val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
        val out = new ObjectOutputStream(new FileOutputStream(file))
        buckets(i).foreach(pair => out.writeObject(pair))
        out.close()
      }
      (myIndex, LocalFileShuffle.serverUri)
    }).collectAsMap()

    // Build a hashmap from server URI to list of splits (to facilitate
    // fetching all the URIs on a server within a single connection)
    val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
    for ((inputId, serverUri) <- outputLocs) {
      splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += inputId
    }

    // TODO: Could broadcast splitsByUri

    // Return an RDD that does each of the merges for a given partition
    val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
    return indexes.flatMap((myId: Int) => {
      val combiners = new HashMap[K, C]
      for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
        for (i <- inputIds) {
          val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
          val inputStream = new ObjectInputStream(new URL(url).openStream())
          try {
            while (true) {
              val (k, c) = inputStream.readObject().asInstanceOf[(K, C)]
              combiners(k) = combiners.get(k) match {
                case Some(oldC) => mergeCombiners(oldC, c)
                case None => c
              }
            }
          } catch {
            case e: EOFException => {}
          }
          inputStream.close()
        }
      }
      combiners
    })
  }
}

object LocalFileShuffle extends Logging {
  private var initialized = false
  private var nextShuffleId = new AtomicLong(0)

  // Variables initialized by initializeIfNeeded()
  private var localDir: File = null
  private var server: HttpServer = null
  private var serverUri: String = null

  private def initializeIfNeeded() = synchronized {
    if (!initialized) {
      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
      var tries = 0
      var foundLocalDir = false
      while (!foundLocalDir && tries < 10) {
        tries += 1
        try {
          localDir = new File(localDirRoot, "spark-local-" + UUID.randomUUID())
          if (!localDir.exists()) {
            localDir.mkdirs()
            foundLocalDir = true
          }
        } catch {
          case e: Exception =>
            logWarning("Attempt " + tries + " to create local dir failed", e)
        }
      }
      if (!foundLocalDir) {
        logError("Failed 10 attempts to create local dir in " + localDirRoot)
        System.exit(1)
      }
      logInfo("Local dir: " + localDir)
      server = new HttpServer(localDir)
      server.start()
      serverUri = server.uri
      initialized = true
    }
  }

  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
    initializeIfNeeded()
    val dir = new File(localDir, "shuffle/" + shuffleId + "/" + inputId)
    dir.mkdirs()
    val file = new File(dir, "" + outputId)
    return file
  }

  def getServerUri(): String = {
    initializeIfNeeded()
    serverUri
  }

  def newShuffleId(): Long = {
    nextShuffleId.getAndIncrement()
  }
}
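To make the on-disk and over-the-wire layout concrete: getOutputFile stores the combined output of input split inputId destined for reduce bucket outputId under the per-JVM local directory, and the reduce side fetches the same triple of ids from that node's HTTP server. The ids and server address below are made-up examples, not values from this commit:

// Hypothetical ids for illustration.
val shuffleId = 0L
val inputId = 3   // map-side (input) split that wrote the file
val outputId = 1  // reduce-side bucket the file belongs to

// Written by getOutputFile(shuffleId, inputId, outputId) as roughly:
//   <spark.local.dir>/spark-local-<uuid>/shuffle/0/3/1
// Fetched by reducers with the URL format used in compute():
val url = "%s/shuffle/%d/%d/%d".format("http://somehost:1234", shuffleId, inputId, outputId)
// => http://somehost:1234/shuffle/0/3/1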
package spark

import mesos.SlaveOffer

/**
 * An RDD that takes the splits of a parent RDD and gives them unique indexes.
 * This is useful for a variety of shuffle implementations.
 */
class NumberedSplitRDD[T: ClassManifest](prev: RDD[T])
extends RDD[(Int, Iterator[T])](prev.sparkContext) {
  @transient val splits_ = {
    prev.splits.zipWithIndex.map {
      case (s, i) => new NumberedSplitRDDSplit(s, i): Split
    }.toArray
  }

  override def splits = splits_

  override def preferredLocations(split: Split) = {
    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
    prev.preferredLocations(nsplit.prev)
  }

  override def iterator(split: Split) = {
    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
    Iterator((nsplit.index, prev.iterator(nsplit.prev)))
  }

  override def taskStarted(split: Split, slot: SlaveOffer) = {
    val nsplit = split.asInstanceOf[NumberedSplitRDDSplit]
    prev.taskStarted(nsplit.prev, slot)
  }
}

/**
 * A split in a NumberedSplitRDD.
 */
class NumberedSplitRDDSplit(val prev: Split, val index: Int) extends Split {
  override def getId() = "NumberedSplitRDDSplit(%d)".format(index)
}
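NumberedSplitRDD simply pairs each parent split's index with an iterator over that split, which is what both shuffle implementations use to name their map-side output by a stable input-split id. A small usage sketch; someRdd stands in for any RDD available in the job:

// Sketch: count the elements in each split of a parent RDD.
// `someRdd` is a placeholder, e.g. sc.parallelize(1 to 100, 4).
val sizesBySplit = new NumberedSplitRDD(someRdd)
  .map { case (index, iter) => (index, iter.size) }
  .collect()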
@@ -356,8 +356,12 @@ extends RDD[Pair[T, U]](sc) {
       mergeValue: (C, V) => C,
       mergeCombiners: (C, C) => C,
       numSplits: Int)
-  : RDD[(K, C)] = {
-    new DfsShuffle(self, numSplits, createCombiner, mergeValue, mergeCombiners).compute()
+  : RDD[(K, C)] =
+  {
+    val shufClass = Class.forName(System.getProperty(
+      "spark.shuffle.class", "spark.DfsShuffle"))
+    val shuf = shufClass.newInstance().asInstanceOf[Shuffle[K, V, C]]
+    shuf.compute(self, numSplits, createCombiner, mergeValue, mergeCombiners)
   }
   def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
...
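For completeness, this is the entry point the new trait plugs into: the three combiner functions are passed straight through to whichever Shuffle implementation spark.shuffle.class names. A hedged sketch of calling it directly for a per-key sum, assuming pairs is an RDD[(String, Int)] and SparkContext._ is imported:

// Sketch: per-key sum via combineByKey, which now dispatches through the
// pluggable Shuffle trait (spark.DfsShuffle unless overridden).
val sums = pairs.combineByKey(
  (v: Int) => v,                  // createCombiner
  (c: Int, v: Int) => c + v,      // mergeValue
  (c1: Int, c2: Int) => c1 + c2,  // mergeCombiners
  4)                              // numSplits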
package spark

/**
 * A trait for shuffle systems. Given an input RDD and combiner functions
 * for PairRDDExtras.combineByKey(), returns an output RDD.
 */
@serializable
trait Shuffle[K, V, C] {
  def compute(input: RDD[(K, V)],
              numOutputSplits: Int,
              createCombiner: V => C,
              mergeValue: (C, V) => C,
              mergeCombiners: (C, C) => C)
  : RDD[(K, C)]
}
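Since combineByKey instantiates the configured class through a no-argument constructor, any implementation of this trait can be dropped in by name. A hypothetical example (not part of this commit) that just wraps DfsShuffle with extra logging:

// Hypothetical: select it by setting spark.shuffle.class to this class's
// fully-qualified name; it must keep a no-arg constructor for newInstance().
@serializable
class LoggingShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
  override def compute(input: RDD[(K, V)],
                       numOutputSplits: Int,
                       createCombiner: V => C,
                       mergeValue: (C, V) => C,
                       mergeCombiners: (C, C) => C)
  : RDD[(K, C)] =
  {
    logInfo("Shuffling into " + numOutputSplits + " output splits")
    new DfsShuffle[K, V, C]().compute(
      input, numOutputSplits, createCombiner, mergeValue, mergeCombiners)
  }
}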
@@ -100,8 +100,9 @@ object Utils {
   // Shuffle the elements of a collection into a random order, returning the
   // result in a new collection. Unlike scala.util.Random.shuffle, this method
   // uses a local random number generator, avoiding inter-thread contention.
-  def shuffle[T](seq: Seq[T]): Seq[T] = {
-    val buf = ArrayBuffer(seq: _*)
+  def shuffle[T](seq: TraversableOnce[T]): Seq[T] = {
+    val buf = new ArrayBuffer[T]()
+    buf ++= seq
     val rand = new Random()
     for (i <- (buf.size - 1) to 1 by -1) {
       val j = rand.nextInt(i)
...
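The widened parameter type is what allows LocalFileShuffle to pass its splitsByUri HashMap straight in; the result is a new, randomly ordered Seq of its entries. A small sketch with made-up server URIs:

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Sketch: randomize the order in which servers are visited, as
// LocalFileShuffle does with Utils.shuffle(splitsByUri).
val splitsByUri = HashMap(
  "http://host1:1234" -> ArrayBuffer(0, 2),
  "http://host2:1234" -> ArrayBuffer(1))
for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
  println(serverUri + " -> " + inputIds.mkString(","))
}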