Commit 0e93891d authored by Matei Zaharia

Replaced LocalFileShuffle with a non-singleton ShuffleManager class

and made DAGScheduler automatically set SparkEnv.
parent e02dc83a
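
The heart of the change: shuffle state that used to live in the JVM-wide singleton `object LocalFileShuffle` moves into a plain `ShuffleManager` class, constructed once per `SparkEnv`. A minimal sketch of the pattern, using hypothetical names (`SingletonShuffle`, `ShuffleManagerSketch`, `Env`), not the real Spark classes:

import java.io.File

// Before: one implicit instance per JVM; every caller shares its state.
object SingletonShuffle {
  private val root = new File("/tmp/shuffle")
  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File =
    new File(root, shuffleId + "/" + inputId + "/" + outputId)
}

// After: an ordinary class, so each environment owns its own directories
// and HTTP server, and two environments in one JVM no longer collide.
class ShuffleManagerSketch(root: File) {
  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File =
    new File(root, shuffleId + "/" + inputId + "/" + outputId)
}

class Env(val shuffleManager: ShuffleManagerSketch)

Call sites change accordingly, from `LocalFileShuffle.getOutputFile(...)` to `SparkEnv.get.shuffleManager.getOutputFile(...)`, as the ShuffleMapTask hunks below show.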
@@ -65,8 +65,9 @@ private trait DAGScheduler extends Scheduler with Logging {
   var cacheLocs = new HashMap[Int, Array[List[String]]]
 
-  val cacheTracker = SparkEnv.get.cacheTracker
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val env = SparkEnv.get
+  val cacheTracker = env.cacheTracker
+  val mapOutputTracker = env.mapOutputTracker
 
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     cacheLocs(rdd.id)
   }
@@ -166,6 +167,8 @@ private trait DAGScheduler extends Scheduler with Logging {
     val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage
     var lastFetchFailureTime: Long = 0  // used to wait a bit to avoid repeated resubmits
 
+    SparkEnv.set(env)
+
     updateCacheLocs()
     logInfo("Final stage: " + finalStage)
 
-package spark
-
-import java.io._
-import java.net.URL
-import java.util.UUID
-import java.util.concurrent.atomic.AtomicLong
-
-import scala.collection.mutable.{ArrayBuffer, HashMap}
-
-import spark._
-
-object LocalFileShuffle extends Logging {
-  private var initialized = false
-  private var nextShuffleId = new AtomicLong(0)
-
-  // Variables initialized by initializeIfNeeded()
-  private var shuffleDir: File = null
-  private var server: HttpServer = null
-  private var serverUri: String = null
-
-  private def initializeIfNeeded() = synchronized {
-    if (!initialized) {
-      // TODO: localDir should be created by some mechanism common to Spark
-      // so that it can be shared among shuffle, broadcast, etc
-      val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
-      var tries = 0
-      var foundLocalDir = false
-      var localDir: File = null
-      var localDirUuid: UUID = null
-      while (!foundLocalDir && tries < 10) {
-        tries += 1
-        try {
-          localDirUuid = UUID.randomUUID
-          localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
-          if (!localDir.exists) {
-            localDir.mkdirs()
-            foundLocalDir = true
-          }
-        } catch {
-          case e: Exception =>
-            logWarning("Attempt " + tries + " to create local dir failed", e)
-        }
-      }
-      if (!foundLocalDir) {
-        logError("Failed 10 attempts to create local dir in " + localDirRoot)
-        System.exit(1)
-      }
-      shuffleDir = new File(localDir, "shuffle")
-      shuffleDir.mkdirs()
-      logInfo("Shuffle dir: " + shuffleDir)
-      val extServerPort = System.getProperty(
-        "spark.localFileShuffle.external.server.port", "-1").toInt
-      if (extServerPort != -1) {
-        // We're using an external HTTP server; set URI relative to its root
-        var extServerPath = System.getProperty(
-          "spark.localFileShuffle.external.server.path", "")
-        if (extServerPath != "" && !extServerPath.endsWith("/")) {
-          extServerPath += "/"
-        }
-        serverUri = "http://%s:%d/%s/spark-local-%s".format(
-          Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
-      } else {
-        // Create our own server
-        server = new HttpServer(localDir)
-        server.start()
-        serverUri = server.uri
-      }
-      initialized = true
-      logInfo("Local URI: " + serverUri)
-    }
-  }
-
-  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
-    initializeIfNeeded()
-    val dir = new File(shuffleDir, shuffleId + "/" + inputId)
-    dir.mkdirs()
-    val file = new File(dir, "" + outputId)
-    return file
-  }
-
-  def getServerUri(): String = {
-    initializeIfNeeded()
-    serverUri
-  }
-
-  def newShuffleId(): Long = {
-    nextShuffleId.getAndIncrement()
-  }
-}
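
For reference, `getOutputFile` builds paths as `<localDir>/shuffle/<shuffleId>/<inputId>/<outputId>`. A small illustration with made-up values, mirroring the path logic above:

import java.io.File

// Hypothetical: spark.local.dir = /tmp, random dir spark-local-<uuid>.
val shuffleDir = new File("/tmp/spark-local-<uuid>", "shuffle")
val dir = new File(shuffleDir, 0 + "/" + 3)  // shuffleId = 0, inputId = 3
val out = new File(dir, "" + 7)              // outputId = 7
// out.getPath == "/tmp/spark-local-<uuid>/shuffle/0/3/7"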
@@ -12,8 +12,6 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule
   var attemptId = new AtomicInteger(0)
   var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory)
 
-  val env = SparkEnv.get
-
   override def start() {}
 
   override def waitForRegister() {}
+package spark
+
+import java.io._
+import java.net.URL
+import java.util.UUID
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
+import spark._
+
+class ShuffleManager extends Logging {
+  private var nextShuffleId = new AtomicLong(0)
+
+  private var shuffleDir: File = null
+  private var server: HttpServer = null
+  private var serverUri: String = null
+
+  initialize()
+
+  private def initialize() {
+    // TODO: localDir should be created by some mechanism common to Spark
+    // so that it can be shared among shuffle, broadcast, etc
+    val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
+    var tries = 0
+    var foundLocalDir = false
+    var localDir: File = null
+    var localDirUuid: UUID = null
+    while (!foundLocalDir && tries < 10) {
+      tries += 1
+      try {
+        localDirUuid = UUID.randomUUID
+        localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
+        if (!localDir.exists) {
+          localDir.mkdirs()
+          foundLocalDir = true
+        }
+      } catch {
+        case e: Exception =>
+          logWarning("Attempt " + tries + " to create local dir failed", e)
+      }
+    }
+    if (!foundLocalDir) {
+      logError("Failed 10 attempts to create local dir in " + localDirRoot)
+      System.exit(1)
+    }
+    shuffleDir = new File(localDir, "shuffle")
+    shuffleDir.mkdirs()
+    logInfo("Shuffle dir: " + shuffleDir)
+    val extServerPort = System.getProperty(
+      "spark.localFileShuffle.external.server.port", "-1").toInt
+    if (extServerPort != -1) {
+      // We're using an external HTTP server; set URI relative to its root
+      var extServerPath = System.getProperty(
+        "spark.localFileShuffle.external.server.path", "")
+      if (extServerPath != "" && !extServerPath.endsWith("/")) {
+        extServerPath += "/"
+      }
+      serverUri = "http://%s:%d/%s/spark-local-%s".format(
+        Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
+    } else {
+      // Create our own server
+      server = new HttpServer(localDir)
+      server.start()
+      serverUri = server.uri
+    }
+    logInfo("Local URI: " + serverUri)
+  }
+
+  def stop() {
+    if (server != null) {
+      server.stop()
+    }
+  }
+
+  def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
+    val dir = new File(shuffleDir, shuffleId + "/" + inputId)
+    dir.mkdirs()
+    val file = new File(dir, "" + outputId)
+    return file
+  }
+
+  def getServerUri(): String = {
+    serverUri
+  }
+
+  def newShuffleId(): Long = {
+    nextShuffleId.getAndIncrement()
+  }
+}
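
The new class keeps the old object's two configuration knobs for serving shuffle output through an already-running HTTP server rather than starting one per environment. A hedged configuration sketch; the port and path values are made up:

// Assumes an external HTTP server on port 8080 that serves the
// spark.local.dir root under the path "spark". With the port set,
// initialize() skips new HttpServer(...) and derives serverUri from
// Utils.localIpAddress, the port, the path, and spark-local-<uuid>.
System.setProperty("spark.localFileShuffle.external.server.port", "8080")
System.setProperty("spark.localFileShuffle.external.server.path", "spark")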
@@ -30,7 +30,7 @@ extends DAGTask[String](stageId) with Logging {
     }
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
-      val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
+      val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file)))
       val iter = buckets(i).entrySet().iterator()
       while (iter.hasNext()) {
@@ -40,7 +40,7 @@ extends DAGTask[String](stageId) with Logging {
       // TODO: have some kind of EOF marker
       out.close()
     }
-    return LocalFileShuffle.getServerUri
+    return SparkEnv.get.shuffleManager.getServerUri
   }
 
   override def preferredLocations: Seq[String] = locs
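
The URI a `ShuffleMapTask` returns tells reducers where the map output can be fetched over HTTP; combined with the `shuffle/<shuffleId>/<inputId>/<outputId>` layout, a fetcher can derive each file's URL. A hedged sketch of that lookup (the real fetcher configured by `spark.shuffle.fetcher` may differ in detail):

import java.net.URL

// Hypothetical reduce-side URL construction, assuming the server is
// rooted at the spark-local-<uuid> dir, so files sit under "shuffle/".
def shuffleFileUrl(serverUri: String, shuffleId: Long,
                   mapId: Int, reduceId: Int): URL =
  new URL("%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, mapId, reduceId))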
@@ -222,6 +222,7 @@ extends Logging {
     env.mapOutputTracker.stop()
     env.cacheTracker.stop()
     env.shuffleFetcher.stop()
+    env.shuffleManager.stop()
     SparkEnv.set(null)
   }
 
@@ -5,7 +5,8 @@ class SparkEnv (
     val serializer: Serializer,
     val cacheTracker: CacheTracker,
     val mapOutputTracker: MapOutputTracker,
-    val shuffleFetcher: ShuffleFetcher
+    val shuffleFetcher: ShuffleFetcher,
+    val shuffleManager: ShuffleManager
   )
 
 object SparkEnv {
@@ -33,6 +34,8 @@ object SparkEnv {
     val shuffleFetcherClass = System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher")
     val shuffleFetcher = Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
 
-    new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher)
+    val shuffleMgr = new ShuffleManager()
+
+    new SparkEnv(cache, serializer, cacheTracker, mapOutputTracker, shuffleFetcher, shuffleMgr)
   }
 }