From fd1d255821bde844af28e897fabd59a715659038 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Tue, 17 May 2011 12:41:13 -0700
Subject: [PATCH] Stop objectifying various trackers, caches, etc.

---
 .../scala/spark/BitTorrentBroadcast.scala     |  2 +-
 core/src/main/scala/spark/Cache.scala         | 24 --------
 .../{RDDCache.scala => CacheTracker.scala}    | 60 +++++++++----------
 .../main/scala/spark/ChainedBroadcast.scala   |  2 +-
 core/src/main/scala/spark/CoGroupedRDD.scala  |  4 +-
 core/src/main/scala/spark/DAGScheduler.scala  | 11 ++--
 core/src/main/scala/spark/DfsBroadcast.scala  |  2 +-
 .../main/scala/spark/DiskSpillingCache.scala  |  4 +-
 core/src/main/scala/spark/Executor.scala      | 14 +++--
 .../src/main/scala/spark/JavaSerializer.scala |  6 +-
 .../main/scala/spark/KryoSerialization.scala  | 10 ++--
 .../main/scala/spark/LocalFileShuffle.scala   |  4 +-
 .../src/main/scala/spark/LocalScheduler.scala |  6 +-
 .../main/scala/spark/MapOutputTracker.scala   | 29 ++++-----
 core/src/main/scala/spark/RDD.scala           |  2 +-
 core/src/main/scala/spark/Serializer.scala    | 49 ++++++++-------
 .../main/scala/spark/SerializingCache.scala   |  4 +-
 core/src/main/scala/spark/ShuffledRDD.scala   |  4 +-
 core/src/main/scala/spark/SparkContext.scala  | 20 +++----
 core/src/main/scala/spark/SparkEnv.scala      | 36 +++++++++++
 20 files changed, 153 insertions(+), 140 deletions(-)
 rename core/src/main/scala/spark/{RDDCache.scala => CacheTracker.scala} (80%)
 create mode 100644 core/src/main/scala/spark/SparkEnv.scala

diff --git a/core/src/main/scala/spark/BitTorrentBroadcast.scala b/core/src/main/scala/spark/BitTorrentBroadcast.scala
index 96d3643ffd..2f5d063438 100644
--- a/core/src/main/scala/spark/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/BitTorrentBroadcast.scala
@@ -1037,7 +1037,7 @@ extends BroadcastFactory {
 
 private object BitTorrentBroadcast
 extends Logging {
-  val values = Cache.newKeySpace()
+  val values = SparkEnv.get.cache.newKeySpace()
 
   var valueToGuideMap = Map[UUID, SourceInfo] ()
 
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index 9887520758..89befae1a4 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -37,27 +37,3 @@ class KeySpace(cache: Cache, id: Long) {
   def get(key: Any): Any = cache.get((id, key))
   def put(key: Any, value: Any): Unit = cache.put((id, key), value)
 }
-
-
-/**
- * The Cache object maintains a global Cache instance, of the type specified
- * by the spark.cache.class property.
- */
-object Cache {
-  private var instance: Cache = null
-
-  def initialize() {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
-    instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
-  }
-
-  def getInstance(): Cache = {
-    if (instance == null) {
-      throw new SparkException("Cache.getInstance called before initialize")
-    }
-    instance
-  }
-
-  def newKeySpace(): KeySpace = getInstance().newKeySpace()
-}
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/CacheTracker.scala
similarity index 80%
rename from core/src/main/scala/spark/RDDCache.scala
rename to core/src/main/scala/spark/CacheTracker.scala
index c5557159a6..8b5c99cf3c 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -6,22 +6,22 @@ import scala.actors.remote._
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 
-sealed trait CacheMessage
-case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
-case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
-case class MemoryCacheLost(host: String) extends CacheMessage
-case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
-case object GetCacheLocations extends CacheMessage
-case object StopCacheTracker extends CacheMessage
+sealed trait CacheTrackerMessage
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class MemoryCacheLost(host: String) extends CacheTrackerMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
+case object GetCacheLocations extends CacheTrackerMessage
+case object StopCacheTracker extends CacheTrackerMessage
 
-class RDDCacheTracker extends DaemonActor with Logging {
+class CacheTrackerActor extends DaemonActor with Logging {
   val locs = new HashMap[Int, Array[List[String]]]
   // TODO: Should probably store (String, CacheType) tuples
 
   def act() {
-    val port = System.getProperty("spark.master.port", "50501").toInt
+    val port = System.getProperty("spark.master.port").toInt
     RemoteActor.alive(port)
-    RemoteActor.register('RDDCacheTracker, self)
+    RemoteActor.register('CacheTracker, self)
     logInfo("Registered actor on port " + port)
 
     loop {
@@ -60,31 +60,27 @@ class RDDCacheTracker extends DaemonActor with Logging {
   }
 }
 
-private object RDDCache extends Logging {
-  // Stores map results for various splits locally
-  var cache: KeySpace = null
-
-  // Remembers which splits are currently being loaded (on worker nodes)
-  val loading = new HashSet[(Int, Int)]
-
+class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
 
-  var registeredRddIds: HashSet[Int] = null
-
-  def initialize(isMaster: Boolean) {
-    if (isMaster) {
-      val tracker = new RDDCacheTracker
-      tracker.start
-      trackerActor = tracker
-    } else {
-      val host = System.getProperty("spark.master.host")
-      val port = System.getProperty("spark.master.port").toInt
-      trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker)
-    }
-    registeredRddIds = new HashSet[Int]
-    cache = Cache.newKeySpace()
+  if (isMaster) {
+    val tracker = new CacheTrackerActor
+    tracker.start
+    trackerActor = tracker
+  } else {
+    val host = System.getProperty("spark.master.host")
+    val port = System.getProperty("spark.master.port").toInt
+    trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
   }
+
+  val registeredRddIds = new HashSet[Int]
+
+  // Stores map results for various splits locally
+  val cache = theCache.newKeySpace()
+
+  // Remembers which splits are currently being loaded (on worker nodes)
+  val loading = new HashSet[(Int, Int)]
 
   // Registers an RDD (on master only)
   def registerRDD(rddId: Int, numPartitions: Int) {
@@ -102,7 +98,7 @@ private object RDDCache extends Logging {
     (trackerActor !? GetCacheLocations) match {
       case h: HashMap[Int, Array[List[String]]] => h
       case _ => throw new SparkException(
-        "Internal error: RDDCache did not reply with a HashMap")
+        "Internal error: CacheTrackerActor did not reply with a HashMap")
     }
   }
 
diff --git a/core/src/main/scala/spark/ChainedBroadcast.scala b/core/src/main/scala/spark/ChainedBroadcast.scala
index afd3c0293c..63c79c693e 100644
--- a/core/src/main/scala/spark/ChainedBroadcast.scala
+++ b/core/src/main/scala/spark/ChainedBroadcast.scala
@@ -719,7 +719,7 @@ extends BroadcastFactory {
 
 private object ChainedBroadcast
 extends Logging {
-  val values = Cache.newKeySpace()
+  val values = SparkEnv.get.cache.newKeySpace()
 
   var valueToGuidePortMap = Map[UUID, Int] ()
 
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 4c427bd67c..53cae76e3a 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -83,7 +83,7 @@ extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
         // Read map outputs of shuffle
         logInfo("Grabbing map outputs for shuffle ID " + shuffleId)
         val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
-        val serverUris = MapOutputTracker.getServerUris(shuffleId)
+        val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
         for ((serverUri, index) <- serverUris.zipWithIndex) {
           splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
         }
@@ -109,4 +109,4 @@ extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
     }
     map.iterator
   }
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 2e427dcb0c..048a0faf2f 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -35,12 +35,15 @@ private trait DAGScheduler extends Scheduler with Logging {
 
   var cacheLocs = new HashMap[Int, Array[List[String]]]
 
+  val cacheTracker = SparkEnv.get.cacheTracker
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     cacheLocs(rdd.id)
   }
 
   def updateCacheLocs() {
-    cacheLocs = RDDCache.getLocationsSnapshot()
+    cacheLocs = cacheTracker.getLocationsSnapshot()
   }
 
   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
@@ -56,7 +59,7 @@ private trait DAGScheduler extends Scheduler with Logging {
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
     // Kind of ugly: need to register RDDs with the cache here since
     // we can't do it in its constructor because # of splits is unknown
-    RDDCache.registerRDD(rdd.id, rdd.splits.size)
+    cacheTracker.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
@@ -71,7 +74,7 @@ private trait DAGScheduler extends Scheduler with Logging {
         visited += r
         // Kind of ugly: need to register RDDs with the cache here since
         // we can't do it in its constructor because # of splits is unknown
-        RDDCache.registerRDD(r.id, r.splits.size)
+        cacheTracker.registerRDD(r.id, r.splits.size)
         for (dep <- r.dependencies) {
           dep match {
             case shufDep: ShuffleDependency[_,_,_] =>
@@ -187,7 +190,7 @@ private trait DAGScheduler extends Scheduler with Logging {
           logInfo(stage + " finished; looking for newly runnable stages")
           running -= stage
           if (stage.shuffleDep != None) {
-            MapOutputTracker.registerMapOutputs(
+            mapOutputTracker.registerMapOutputs(
               stage.shuffleDep.get.shuffleId,
               stage.outputLocs.map(_.first).toArray)
           }
diff --git a/core/src/main/scala/spark/DfsBroadcast.scala b/core/src/main/scala/spark/DfsBroadcast.scala
index 480d6dd9b1..895f55ca22 100644
--- a/core/src/main/scala/spark/DfsBroadcast.scala
+++ b/core/src/main/scala/spark/DfsBroadcast.scala
@@ -61,7 +61,7 @@ extends BroadcastFactory {
 
 private object DfsBroadcast
 extends Logging {
-  val values = Cache.newKeySpace()
+  val values = SparkEnv.get.cache.newKeySpace()
 
   private var initialized = false
 
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
index 9e52fee69e..80e13a2519 100644
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -14,7 +14,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
 
   override def get(key: Any): Any = {
     synchronized {
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       super.get(key) match {
         case bytes: Any => // found in memory
           ser.deserialize(bytes.asInstanceOf[Array[Byte]])
@@ -46,7 +46,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
   }
 
   override def put(key: Any, value: Any) {
-    var ser = Serializer.newInstance()
+    var ser = SparkEnv.get.serializer.newInstance()
     super.put(key, ser.serialize(value))
   }
 
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index 98d757e116..a3666fdbae 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -15,6 +15,7 @@ import mesos.{TaskDescription, TaskState, TaskStatus}
 class Executor extends mesos.Executor with Logging {
   var classLoader: ClassLoader = null
   var threadPool: ExecutorService = null
+  var env: SparkEnv = null
 
   override def init(d: ExecutorDriver, args: ExecutorArgs) {
     // Read spark.* system properties from executor arg
@@ -22,19 +23,19 @@ class Executor extends mesos.Executor with Logging {
     for ((key, value) <- props)
       System.setProperty(key, value)
 
-    // Initialize cache and broadcast system (uses some properties read above)
-    Cache.initialize()
-    Serializer.initialize()
+    // Initialize Spark environment (using system properties read above)
+    env = SparkEnv.createFromSystemProperties(false)
+    SparkEnv.set(env)
+    // Old stuff that isn't yet using env
     Broadcast.initialize(false)
-    MapOutputTracker.initialize(false)
-    RDDCache.initialize(false)
 
     // Create our ClassLoader (using spark properties) and set it on this thread
     classLoader = createClassLoader()
     Thread.currentThread.setContextClassLoader(classLoader)
 
     // Start worker thread pool (they will inherit our context ClassLoader)
-    threadPool = new ThreadPoolExecutor(1, 128, 600, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable])
+    threadPool = new ThreadPoolExecutor(
+      1, 128, 600, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable])
   }
 
   override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
@@ -46,6 +47,7 @@ class Executor extends mesos.Executor with Logging {
       def run() = {
         logInfo("Running task ID " + taskId)
         try {
+          SparkEnv.set(env)
           Accumulators.clear
           val task = Utils.deserialize[Task[Any]](arg, classLoader)
           val value = task.run
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 8ee3044058..af390d55d8 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -19,7 +19,7 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
   def close() { objIn.close() }
 }
 
-class JavaSerializer extends Serializer {
+class JavaSerializerInstance extends SerializerInstance {
   def serialize[T](t: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val out = outputStream(bos)
@@ -43,6 +43,6 @@ class JavaSerializer extends Serializer {
   }
 }
 
-class JavaSerialization extends SerializationStrategy {
-  def newSerializer(): Serializer = new JavaSerializer
+class JavaSerializer extends Serializer {
+  def newInstance(): SerializerInstance = new JavaSerializerInstance
 }
diff --git a/core/src/main/scala/spark/KryoSerialization.scala b/core/src/main/scala/spark/KryoSerialization.scala
index 54427ecf71..ba34a5452a 100644
--- a/core/src/main/scala/spark/KryoSerialization.scala
+++ b/core/src/main/scala/spark/KryoSerialization.scala
@@ -82,8 +82,8 @@ extends DeserializationStream {
   def close() { in.close() }
 }
 
-class KryoSerializer(strat: KryoSerialization) extends Serializer {
-  val buf = strat.threadBuf.get()
+class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
+  val buf = ks.threadBuf.get()
 
   def serialize[T](t: T): Array[Byte] = {
     buf.writeClassAndObject(t)
@@ -94,7 +94,7 @@ class KryoSerializer(strat: KryoSerialization) extends Serializer {
   }
 
   def outputStream(s: OutputStream): SerializationStream = {
-    new KryoSerializationStream(strat.kryo, strat.threadByteBuf.get(), s)
+    new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
   }
 
   def inputStream(s: InputStream): DeserializationStream = {
@@ -107,7 +107,7 @@ trait KryoRegistrator {
   def registerClasses(kryo: Kryo): Unit
 }
 
-class KryoSerialization extends SerializationStrategy with Logging {
+class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
   val threadBuf = new ThreadLocal[ObjectBuffer] {
@@ -162,5 +162,5 @@ class KryoSerialization extends SerializationStrategy with Logging {
     kryo
   }
 
-  def newSerializer(): Serializer = new KryoSerializer(this)
+  def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
 }
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala
index 057a7ff43d..ee57ddbf61 100644
--- a/core/src/main/scala/spark/LocalFileShuffle.scala
+++ b/core/src/main/scala/spark/LocalFileShuffle.scala
@@ -47,7 +47,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
           case None => createCombiner(v)
         }
       }
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       for (i <- 0 until numOutputSplits) {
         val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
         val out = ser.outputStream(new FileOutputStream(file))
@@ -70,7 +70,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
     val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
     return indexes.flatMap((myId: Int) => {
       val combiners = new HashMap[K, C]
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
         for (i <- inputIds) {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 0287082687..832ab8cca8 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -8,6 +8,8 @@ import java.util.concurrent._
 private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
   var threadPool: ExecutorService =
     Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+
+  val env = SparkEnv.get
 
   override def start() {}
 
@@ -18,6 +20,8 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
       threadPool.submit(new Runnable {
         def run() {
           logInfo("Running task " + i)
+          // Set the Spark execution environment for the worker thread
+          SparkEnv.set(env)
           try {
             // Serialize and deserialize the task so that accumulators are
             // changed to thread-local ones; this adds a bit of unnecessary
@@ -47,4 +51,4 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
   override def stop() {}
 
   override def numCores() = threads
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 4334034ecb..d36fbc7703 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -11,10 +11,10 @@ sealed trait MapOutputTrackerMessage
 case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
 case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
+class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
 extends DaemonActor with Logging {
   def act() {
-    val port = System.getProperty("spark.master.port", "50501").toInt
+    val port = System.getProperty("spark.master.port").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
     logInfo("Registered actor on port " + port)
@@ -32,22 +32,20 @@
   }
 }
 
-object MapOutputTracker extends Logging {
+class MapOutputTracker(isMaster: Boolean) extends Logging {
   var trackerActor: AbstractActor = null
 
-  private val serverUris = new ConcurrentHashMap[Int, Array[String]]
-
-  def initialize(isMaster: Boolean) {
-    if (isMaster) {
-      val tracker = new MapOutputTracker(serverUris)
-      tracker.start
-      trackerActor = tracker
-    } else {
-      val host = System.getProperty("spark.master.host")
-      val port = System.getProperty("spark.master.port").toInt
-      trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
-    }
+  if (isMaster) {
+    val tracker = new MapOutputTrackerActor(serverUris)
+    tracker.start
+    trackerActor = tracker
+  } else {
+    val host = System.getProperty("spark.master.host")
+    val port = System.getProperty("spark.master.port").toInt
+    trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
   }
+
+  private val serverUris = new ConcurrentHashMap[Int, Array[String]]
 
   def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
     var array = serverUris.get(shuffleId)
@@ -62,7 +60,7 @@
     serverUris.put(shuffleId, Array[String]() ++ locs)
   }
 
-
   // Remembers which map output locations are currently being fetched
   val fetching = new HashSet[Int]
 
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 40eb7967ec..6accd5e356 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -43,7 +43,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
   // Read this RDD; will read from cache if applicable, or otherwise compute
   final def iterator(split: Split): Iterator[T] = {
     if (shouldCache) {
-      RDDCache.getOrCompute[T](this, split)
+      SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
     } else {
       compute(split)
     }
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index a182f6bddc..cfc6d978bc 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -2,39 +2,38 @@ package spark
 
 import java.io.{InputStream, OutputStream}
 
-trait SerializationStream {
-  def writeObject[T](t: T): Unit
-  def flush(): Unit
-  def close(): Unit
-}
-
-trait DeserializationStream {
-  def readObject[T](): T
-  def close(): Unit
+/**
+ * A serializer. Because some serialization libraries are not thread safe,
+ * this class is used to create SerializerInstances that do the actual
+ * serialization.
+ */
+trait Serializer {
+  def newInstance(): SerializerInstance
 }
 
-trait Serializer {
+/**
+ * An instance of the serializer, for use by one thread at a time.
+ */
+trait SerializerInstance {
   def serialize[T](t: T): Array[Byte]
   def deserialize[T](bytes: Array[Byte]): T
   def outputStream(s: OutputStream): SerializationStream
   def inputStream(s: InputStream): DeserializationStream
 }
 
-trait SerializationStrategy {
-  def newSerializer(): Serializer
+/**
+ * A stream for writing serialized objects.
+ */
+trait SerializationStream {
+  def writeObject[T](t: T): Unit
+  def flush(): Unit
+  def close(): Unit
 }
 
-object Serializer {
-  var strat: SerializationStrategy = null
-
-  def initialize() {
-    val cls = System.getProperty("spark.serialization",
-      "spark.JavaSerialization")
-    strat = Class.forName(cls).newInstance().asInstanceOf[SerializationStrategy]
-  }
-
-  // Return a serializer ** for use by a single thread **
-  def newInstance(): Serializer = {
-    strat.newSerializer()
-  }
+/**
+ * A stream for reading serialized objects.
+ */
+trait DeserializationStream {
+  def readObject[T](): T
+  def close(): Unit
 }
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
index cbc64736e6..2c1f96a700 100644
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -10,14 +10,14 @@ class SerializingCache extends Cache with Logging {
   val bmc = new BoundedMemoryCache
 
   override def put(key: Any, value: Any) {
-    val ser = Serializer.newInstance()
+    val ser = SparkEnv.get.serializer.newInstance()
     bmc.put(key, ser.serialize(value))
   }
 
   override def get(key: Any): Any = {
     val bytes = bmc.get(key)
     if (bytes != null) {
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
     } else {
       return null
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 683df12019..f730f0580e 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -33,7 +33,7 @@ extends RDD[(K, C)](parent.context) {
     val shuffleId = dep.shuffleId
     val splitId = split.index
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
-    val serverUris = MapOutputTracker.getServerUris(shuffleId)
+    val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
       splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
     }
@@ -58,4 +58,4 @@ extends RDD[(K, C)](parent.context) {
     }
     combiners.iterator
   }
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index dc6964e14b..c1807de0ef 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -20,7 +20,13 @@ extends Logging {
     System.setProperty("spark.master.host", Utils.localHostName)
   if (System.getProperty("spark.master.port") == null)
     System.setProperty("spark.master.port", "50501")
+
+  // Create the Spark execution environment (cache, map output tracker, etc)
+  val env = SparkEnv.createFromSystemProperties(true)
+  SparkEnv.set(env)
+  Broadcast.initialize(true)
 
+  // Create and start the scheduler
   private var scheduler: Scheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
@@ -34,16 +40,9 @@ extends Logging {
       new MesosScheduler(this, master, frameworkName)
     }
   }
+  scheduler.start()
 
   private val isLocal = scheduler.isInstanceOf[LocalScheduler]
-
-  // Start the scheduler, the cache and the broadcast system
-  scheduler.start()
-  Cache.initialize()
-  Serializer.initialize()
-  Broadcast.initialize(true)
-  MapOutputTracker.initialize(true)
-  RDDCache.initialize(true)
 
   // Methods for creating RDDs
 
@@ -122,8 +121,9 @@ extends Logging {
     scheduler.stop()
     scheduler = null
     // TODO: Broadcast.stop(), Cache.stop()?
-    MapOutputTracker.stop()
-    RDDCache.stop()
+    env.mapOutputTracker.stop()
+    env.cacheTracker.stop()
+    SparkEnv.set(null)
   }
 
   // Wait for the scheduler to be registered
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
new file mode 100644
index 0000000000..1bfd0172d7
--- /dev/null
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -0,0 +1,36 @@
+package spark
+
+class SparkEnv (
+  val cache: Cache,
+  val serializer: Serializer,
+  val cacheTracker: CacheTracker,
+  val mapOutputTracker: MapOutputTracker
+)
+
+object SparkEnv {
+  private val env = new ThreadLocal[SparkEnv]
+
+  def set(e: SparkEnv) {
+    env.set(e)
+  }
+
+  def get: SparkEnv = {
+    env.get()
+  }
+
+  def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
+    val cacheClass = System.getProperty("spark.cache.class",
+      "spark.SoftReferenceCache")
+    val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+
+    val serClass = System.getProperty("spark.serializer",
+      "spark.JavaSerializer")
+    val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
+
+    val cacheTracker = new CacheTracker(isMaster, cache)
+
+    val mapOutputTracker = new MapOutputTracker(isMaster)
+
+    new SparkEnv(cache, ser, cacheTracker, mapOutputTracker)
+  }
+}
-- 
GitLab