diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index df88936b3a283d117be78f3af767f9fd54c37d3d..8ed5d7ba7193a7a667fb6b5b9b698f4eda467c52 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -65,8 +65,9 @@ private trait DAGScheduler extends Scheduler with Logging {
 
   var cacheLocs = new HashMap[Int, Array[List[String]]]
 
-  val cacheTracker = SparkEnv.get.cacheTracker
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val env = SparkEnv.get
+  val cacheTracker = env.cacheTracker
+  val mapOutputTracker = env.mapOutputTracker
 
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     cacheLocs(rdd.id)
@@ -166,6 +167,8 @@ private trait DAGScheduler extends Scheduler with Logging {
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
     var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits
 
+    SparkEnv.set(env)
+
     updateCacheLocs()
 
     logInfo("Final stage: " + finalStage)
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala
deleted file mode 100644
index 6c7f3dede28ab29ba64ab8478ed9769755cfd5d2..0000000000000000000000000000000000000000
--- a/core/src/main/scala/spark/LocalFileShuffle.scala
+++ /dev/null
@@ -1,90 +0,0 @@
-package spark
-
-import java.io._
-import java.net.URL
-import java.util.UUID
-import java.util.concurrent.atomic.AtomicLong
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-import spark._
-
-object LocalFileShuffle extends Logging {
-  private var initialized = false
-  private var nextShuffleId = new AtomicLong(0)
-
-  // Variables initialized by initializeIfNeeded()
-  private var shuffleDir: File = null
-  private var server: HttpServer = null
-  private var serverUri: String = null
-
-  private def initializeIfNeeded() = synchronized {
-    if (!initialized) {
-      // TODO: localDir should be created by some mechanism common to Spark
-      // so that it can be shared among shuffle, broadcast, etc
-      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
-      var tries = 0
-      var foundLocalDir = false
-      var localDir: File = null
-      var localDirUuid: UUID = null
-      while (!foundLocalDir && tries < 10) {
-        tries += 1
-        try {
-          localDirUuid = UUID.randomUUID
-          localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
-          if (!localDir.exists) {
-            localDir.mkdirs()
-            foundLocalDir = true
-          }
-        } catch {
-          case e: Exception =>
-            logWarning("Attempt " + tries + " to create local dir failed", e)
-        }
-      }
-      if (!foundLocalDir) {
-        logError("Failed 10 attempts to create local dir in " + localDirRoot)
-        System.exit(1)
-      }
-      shuffleDir = new File(localDir, "shuffle")
-      shuffleDir.mkdirs()
-      logInfo("Shuffle dir: " + shuffleDir)
-
-      val extServerPort = System.getProperty(
-        "spark.localFileShuffle.external.server.port", "-1").toInt
-      if (extServerPort != -1) {
-        // We're using an external HTTP server; set URI relative to its root
-        var extServerPath = System.getProperty(
-          "spark.localFileShuffle.external.server.path", "")
-        if (extServerPath != "" && !extServerPath.endsWith("/")) {
-          extServerPath += "/"
-        }
-        serverUri = "http://%s:%d/%s/spark-local-%s".format(
-          Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
-      } else {
-        // Create our own server
-        server = new HttpServer(localDir)
-        server.start()
-        serverUri = server.uri
-      }
-      initialized = true
-      logInfo("Local URI: " + serverUri)
-    }
-  }
-
-  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
-    initializeIfNeeded()
-    val dir = new File(shuffleDir, shuffleId + "/" + inputId)
-    dir.mkdirs()
-    val file = new File(dir, "" + outputId)
-    return file
-  }
-
-  def getServerUri(): String = {
-    initializeIfNeeded()
-    serverUri
-  }
-
-  def newShuffleId(): Long = {
-    nextShuffleId.getAndIncrement()
-  }
-}
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 6485da0b51498f0797380f4f52ea756913dc2497..34f06b747d539167987d7b0fe72ea6201a1ce3fb 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -12,8 +12,6 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
   var attemptId = new AtomicInteger(0)
   var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
 
-  val env = SparkEnv.get
-
   override def start() {}
 
   override def waitForRegister() {}
diff --git a/core/src/main/scala/spark/ShuffleManager.scala b/core/src/main/scala/spark/ShuffleManager.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c1d3af2729479c7b8846a61f64e8adf3f0a412cd
--- /dev/null
+++ b/core/src/main/scala/spark/ShuffleManager.scala
@@ -0,0 +1,91 @@
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import spark._
+
+class ShuffleManager extends Logging {
+  private var nextShuffleId = new AtomicLong(0)
+
+  private var shuffleDir: File = null
+  private var server: HttpServer = null
+  private var serverUri: String = null
+
+  initialize()
+
+  private def initialize() {
+    // TODO: localDir should be created by some mechanism common to Spark
+    // so that it can be shared among shuffle, broadcast, etc
+    val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+    var tries = 0
+    var foundLocalDir = false
+    var localDir: File = null
+    var localDirUuid: UUID = null
+    while (!foundLocalDir && tries < 10) {
+      tries += 1
+      try {
+        localDirUuid = UUID.randomUUID
+        localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+        if (!localDir.exists) {
+          localDir.mkdirs()
+          foundLocalDir = true
+        }
+      } catch {
+        case e: Exception =>
+          logWarning("Attempt " + tries + " to create local dir failed", e)
+      }
+    }
+    if (!foundLocalDir) {
+      logError("Failed 10 attempts to create local dir in " + localDirRoot)
+      System.exit(1)
+    }
+    shuffleDir = new File(localDir, "shuffle")
+    shuffleDir.mkdirs()
+    logInfo("Shuffle dir: " + shuffleDir)
+
+    val extServerPort = System.getProperty(
+      "spark.localFileShuffle.external.server.port", "-1").toInt
+    if (extServerPort != -1) {
+      // We're using an external HTTP server; set URI relative to its root
+      var extServerPath = System.getProperty(
+        "spark.localFileShuffle.external.server.path", "")
+      if (extServerPath != "" && !extServerPath.endsWith("/")) {
+        extServerPath += "/"
+      }
+      serverUri = "http://%s:%d/%s/spark-local-%s".format(
+        Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+    } else {
+      // Create our own server
+      server = new HttpServer(localDir)
+      server.start()
+      serverUri = server.uri
+    }
+    logInfo("Local URI: " + serverUri)
+  }
+
+  def stop() {
+    if (server != null) {
+      server.stop()
+    }
+  }
+
+  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+    val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+    dir.mkdirs()
+    val file = new File(dir, "" + outputId)
+    return file
+  }
+
+  def getServerUri(): String = {
+    serverUri
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+}
diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala
index 93a93d57509583586efbb2b022ba39c9ae2bc9e6..d68c6cf4efea269d67ad47099b8084460614a603 100644
--- a/core/src/main/scala/spark/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/ShuffleMapTask.scala
@@ -30,7 +30,7 @@ extends DAGTask[String](stageId) with Logging {
     }
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
-      val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
+      val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
       val iter = buckets(i).entrySet().iterator()
       while (iter.hasNext()) {
@@ -40,7 +40,7 @@ extends DAGTask[String](stageId) with Logging {
       // TODO: have some kind of EOF marker
       out.close()
     }
-    return LocalFileShuffle.getServerUri
+    return SparkEnv.get.shuffleManager.getServerUri
   }
 
   override def preferredLocations: Seq[String] = locs
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index b0cc0e64547a8bb0fa0cc81278607e6cb4367e12..9f1a49f853fec728d35c00df4e0559ea7fcba58d 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -222,6 +222,7 @@ extends Logging {
     env.mapOutputTracker.stop()
     env.cacheTracker.stop()
    env.shuffleFetcher.stop()
+    env.shuffleManager.stop()
     SparkEnv.set(null)
   }
 
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index ad6d54d9051510d46cbc67db120304c0a0f533e0..81caf7cff05928ea8c653d57f2bfcfdac5132fe6 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -5,7 +5,8 @@ class SparkEnv (
   val serializer: Serializer,
   val cacheTracker: CacheTracker,
   val mapOutputTracker: MapOutputTracker,
-  val shuffleFetcher: ShuffleFetcher
+  val shuffleFetcher: ShuffleFetcher,
+  val shuffleManager: ShuffleManager
 )
 
 object SparkEnv {
@@ -33,6 +34,8 @@
     val shuffleFetcherClass = System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
     val shuffleFetcher = Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
 
-    new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher)
+    val shuffleMgr = new ShuffleManager()
+
+    new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr)
   }
 }
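
Reviewer note (not part of the patch): the core change is that the LocalFileShuffle singleton becomes a per-environment ShuffleManager instance hung off SparkEnv, which is why DAGScheduler captures `val env = SparkEnv.get` at construction time and re-publishes it with `SparkEnv.set(env)` on the thread that runs the job. Below is a minimal usage sketch of the new call path (cf. ShuffleMapTask.run); the `writeMapOutput` helper and its parameters are hypothetical, and only the SparkEnv and ShuffleManager calls come from this patch:

    // Hypothetical helper illustrating how map-side code reaches the
    // shuffle manager after this change.
    def writeMapOutput(shuffleId: Long, partition: Int, numOutputSplits: Int): String = {
      // SparkEnv.get reads a thread-local, so the calling thread must have
      // had SparkEnv.set(...) invoked on it -- hence the new call added to
      // DAGScheduler in this patch.
      val shuffleManager = SparkEnv.get.shuffleManager
      for (i <- 0 until numOutputSplits) {
        // One file per (map partition, reduce split) pair, created under
        // <spark.local.dir>/spark-local-<uuid>/shuffle/<shuffleId>/<partition>/
        val file: java.io.File = shuffleManager.getOutputFile(shuffleId, partition, i)
        // ... serialize bucket i into `file` (elided) ...
      }
      // Reducers fetch the output files over HTTP from this URI.
      shuffleManager.getServerUri
    }

Unlike the old object, each ShuffleManager owns its HttpServer, so SparkContext.stop() can now shut the server down via env.shuffleManager.stop() instead of leaving it running for the lifetime of the JVM.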