diff --git a/core/src/main/scala/spark/BitTorrentBroadcast.scala b/core/src/main/scala/spark/BitTorrentBroadcast.scala
index 96d3643ffd9b237b8c6136e11148f4e12ac3a68a..2f5d063438fc5afe46f35597bde54e1b8b7737da 100644
--- a/core/src/main/scala/spark/BitTorrentBroadcast.scala
+++ b/core/src/main/scala/spark/BitTorrentBroadcast.scala
@@ -1037,7 +1037,7 @@ extends BroadcastFactory {
 
 private object BitTorrentBroadcast
 extends Logging {
-  val values = Cache.newKeySpace()
+  val values = SparkEnv.get.cache.newKeySpace()
 
   var valueToGuideMap = Map[UUID, SourceInfo] ()
 
diff --git a/core/src/main/scala/spark/Cache.scala b/core/src/main/scala/spark/Cache.scala
index 9887520758e0761315113957495de4f1004f95e7..89befae1a4df6e8d582b73ebd518b3a379e23778 100644
--- a/core/src/main/scala/spark/Cache.scala
+++ b/core/src/main/scala/spark/Cache.scala
@@ -37,27 +37,3 @@ class KeySpace(cache: Cache, id: Long) {
   def get(key: Any): Any = cache.get((id, key))
   def put(key: Any, value: Any): Unit = cache.put((id, key), value)
 }
-
-
-/**
- * The Cache object maintains a global Cache instance, of the type specified
- * by the spark.cache.class property.
- */
-object Cache {
-  private var instance: Cache = null
-
-  def initialize() {
-    val cacheClass = System.getProperty("spark.cache.class",
-      "spark.SoftReferenceCache")
-    instance = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
-  }
-
-  def getInstance(): Cache = {
-    if (instance == null) {
-      throw new SparkException("Cache.getInstance called before initialize")
-    }
-    instance
-  }
-
-  def newKeySpace(): KeySpace = getInstance().newKeySpace()
-}
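
With the global Cache singleton gone, the only way to reach a cache is now through the SparkEnv introduced later in this patch. A minimal sketch of the new lookup path, using only the KeySpace API kept above; the MyComponent object and its memoize helper are hypothetical and not part of the patch:

    // Hypothetical component: where code previously called Cache.newKeySpace(),
    // it now asks the environment bound to the current thread for the cache.
    object MyComponent {
      // Initialized on first use, on a thread where SparkEnv has been set
      val values: KeySpace = SparkEnv.get.cache.newKeySpace()

      def memoize(key: Any, compute: => Any): Any = {
        val cached = values.get(key)
        if (cached != null) {
          cached
        } else {
          val result = compute
          values.put(key, result)  // KeySpace scopes the key with its own id
          result
        }
      }
    }

BitTorrentBroadcast above, and ChainedBroadcast and DfsBroadcast below, switch to exactly this lookup.
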
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/CacheTracker.scala
similarity index 80%
rename from core/src/main/scala/spark/RDDCache.scala
rename to core/src/main/scala/spark/CacheTracker.scala
index c5557159a6a02216600263f0d5c29277e50ea820..8b5c99cf3c1eb3a73806213fe5e9ab66a888402b 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/CacheTracker.scala
@@ -6,22 +6,22 @@ import scala.actors.remote._
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 
-sealed trait CacheMessage
-case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheMessage
-case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheMessage
-case class MemoryCacheLost(host: String) extends CacheMessage
-case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
-case object GetCacheLocations extends CacheMessage
-case object StopCacheTracker extends CacheMessage
+sealed trait CacheTrackerMessage
+case class AddedToCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends CacheTrackerMessage
+case class MemoryCacheLost(host: String) extends CacheTrackerMessage
+case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
+case object GetCacheLocations extends CacheTrackerMessage
+case object StopCacheTracker extends CacheTrackerMessage
 
-class RDDCacheTracker extends DaemonActor with Logging {
+class CacheTrackerActor extends DaemonActor with Logging {
   val locs = new HashMap[Int, Array[List[String]]]
 
   // TODO: Should probably store (String, CacheType) tuples
 
   def act() {
-    val port = System.getProperty("spark.master.port", "50501").toInt
+    val port = System.getProperty("spark.master.port").toInt
     RemoteActor.alive(port)
-    RemoteActor.register('RDDCacheTracker, self)
+    RemoteActor.register('CacheTracker, self)
     logInfo("Registered actor on port " + port)
 
     loop {
@@ -60,31 +60,27 @@ class RDDCacheTracker extends DaemonActor with Logging {
   }
 }
 
-private object RDDCache extends Logging {
-  // Stores map results for various splits locally
-  var cache: KeySpace = null
-
-  // Remembers which splits are currently being loaded (on worker nodes)
-  val loading = new HashSet[(Int, Int)]
-
+class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging {
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
 
-  var registeredRddIds: HashSet[Int] = null
-
-  def initialize(isMaster: Boolean) {
-    if (isMaster) {
-      val tracker = new RDDCacheTracker
-      tracker.start
-      trackerActor = tracker
-    } else {
-      val host = System.getProperty("spark.master.host")
-      val port = System.getProperty("spark.master.port").toInt
-      trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker)
-    }
-    registeredRddIds = new HashSet[Int]
-    cache = Cache.newKeySpace()
+  if (isMaster) {
+    val tracker = new CacheTrackerActor
+    tracker.start
+    trackerActor = tracker
+  } else {
+    val host = System.getProperty("spark.master.host")
+    val port = System.getProperty("spark.master.port").toInt
+    trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker)
   }
+
+  val registeredRddIds = new HashSet[Int]
+
+  // Stores map results for various splits locally
+  val cache = theCache.newKeySpace()
+
+  // Remembers which splits are currently being loaded (on worker nodes)
+  val loading = new HashSet[(Int, Int)]
 
   // Registers an RDD (on master only)
   def registerRDD(rddId: Int, numPartitions: Int) {
@@ -102,7 +98,7 @@ private object RDDCache extends Logging {
       (trackerActor !? GetCacheLocations) match {
         case h: HashMap[Int, Array[List[String]]] => h
         case _ => throw new SparkException(
-          "Internal error: RDDCache did not reply with a HashMap")
+          "Internal error: CacheTrackerActor did not reply with a HashMap")
       }
     }
 
diff --git a/core/src/main/scala/spark/ChainedBroadcast.scala b/core/src/main/scala/spark/ChainedBroadcast.scala
index afd3c0293c3c31f9b37a85023f12e263bc034a92..63c79c693e23d7457ad81858f28fa4a9031f7ad6 100644
--- a/core/src/main/scala/spark/ChainedBroadcast.scala
+++ b/core/src/main/scala/spark/ChainedBroadcast.scala
@@ -719,7 +719,7 @@ extends BroadcastFactory {
 
 private object ChainedBroadcast
 extends Logging {
-  val values = Cache.newKeySpace()
+  val values = SparkEnv.get.cache.newKeySpace()
 
   var valueToGuidePortMap = Map[UUID, Int] ()
 
diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala
index 4c427bd67cc8944c9d45e70b3975bc78d28414ad..53cae76e3a6d37f0942eacaf34a9d6c189c8d96d 100644
--- a/core/src/main/scala/spark/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/CoGroupedRDD.scala
@@ -83,7 +83,7 @@ extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
     // Read map outputs of shuffle
     logInfo("Grabbing map outputs for shuffle ID " + shuffleId)
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
-    val serverUris = MapOutputTracker.getServerUris(shuffleId)
+    val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
       splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
     }
@@ -109,4 +109,4 @@ extends RDD[(K, Seq[Seq[_]])](rdds.first.context) with Logging {
     }
     map.iterator
   }
-}
\ No newline at end of file
+}
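
For orientation, here is roughly how the renamed tracker is driven, pieced together from the calls visible elsewhere in this patch (SparkEnv constructs it, the DAGScheduler registers RDDs and takes location snapshots). The property values and IDs below are placeholders:

    // Assumed setup; SparkContext normally sets these properties itself.
    System.setProperty("spark.master.host", "localhost")
    System.setProperty("spark.master.port", "50501")

    val cache = new SoftReferenceCache            // the default Cache implementation
    val tracker = new CacheTracker(true, cache)   // isMaster = true starts the actor

    // The scheduler registers every RDD before running stages over it...
    tracker.registerRDD(0, 4)                     // rddId = 0, numPartitions = 4
    // ...and refreshes its view of which hosts cache which partitions.
    val locs = tracker.getLocationsSnapshot()     // HashMap[Int, Array[List[String]]]
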
diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala
index 2e427dcb0cbf21a9562cac3e597d68d41af917a0..048a0faf2ffd54b4e8af48afad24b52ca5723f7a 100644
--- a/core/src/main/scala/spark/DAGScheduler.scala
+++ b/core/src/main/scala/spark/DAGScheduler.scala
@@ -35,12 +35,15 @@ private trait DAGScheduler extends Scheduler with Logging {
 
   var cacheLocs = new HashMap[Int, Array[List[String]]]
 
+  val cacheTracker = SparkEnv.get.cacheTracker
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+
   def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
     cacheLocs(rdd.id)
   }
 
   def updateCacheLocs() {
-    cacheLocs = RDDCache.getLocationsSnapshot()
+    cacheLocs = cacheTracker.getLocationsSnapshot()
   }
 
   def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = {
@@ -56,7 +59,7 @@ private trait DAGScheduler extends Scheduler with Logging {
   def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = {
     // Kind of ugly: need to register RDDs with the cache here since
     // we can't do it in its constructor because # of splits is unknown
-    RDDCache.registerRDD(rdd.id, rdd.splits.size)
+    cacheTracker.registerRDD(rdd.id, rdd.splits.size)
     val id = newStageId()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd))
     idToStage(id) = stage
@@ -71,7 +74,7 @@ private trait DAGScheduler extends Scheduler with Logging {
         visited += r
         // Kind of ugly: need to register RDDs with the cache here since
         // we can't do it in its constructor because # of splits is unknown
-        RDDCache.registerRDD(r.id, r.splits.size)
+        cacheTracker.registerRDD(r.id, r.splits.size)
         for (dep <- r.dependencies) {
           dep match {
             case shufDep: ShuffleDependency[_,_,_] =>
@@ -187,7 +190,7 @@ private trait DAGScheduler extends Scheduler with Logging {
             logInfo(stage + " finished; looking for newly runnable stages")
             running -= stage
             if (stage.shuffleDep != None) {
-              MapOutputTracker.registerMapOutputs(
+              mapOutputTracker.registerMapOutputs(
                 stage.shuffleDep.get.shuffleId,
                 stage.outputLocs.map(_.first).toArray)
             }
diff --git a/core/src/main/scala/spark/DfsBroadcast.scala b/core/src/main/scala/spark/DfsBroadcast.scala
index 480d6dd9b1fe72fae626be6ea91a97bd55723ffc..895f55ca22ae202f561a4d190f90ef0f5e803933 100644
--- a/core/src/main/scala/spark/DfsBroadcast.scala
+++ b/core/src/main/scala/spark/DfsBroadcast.scala
@@ -61,7 +61,7 @@ extends BroadcastFactory {
 
 private object DfsBroadcast
 extends Logging {
-  val values = Cache.newKeySpace()
+  val values = SparkEnv.get.cache.newKeySpace()
 
   private var initialized = false
 
diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala
index 9e52fee69e44858c5fbdf869d068d0319d39f46b..80e13a25196b79d2695d06688b7a3da1c345d080 100644
--- a/core/src/main/scala/spark/DiskSpillingCache.scala
+++ b/core/src/main/scala/spark/DiskSpillingCache.scala
@@ -14,7 +14,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
 
   override def get(key: Any): Any = {
     synchronized {
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
      super.get(key) match {
        case bytes: Any => // found in memory
          ser.deserialize(bytes.asInstanceOf[Array[Byte]])
@@ -46,7 +46,7 @@ class DiskSpillingCache extends BoundedMemoryCache {
   }
 
   override def put(key: Any, value: Any) {
-    var ser = Serializer.newInstance()
+    var ser = SparkEnv.get.serializer.newInstance()
     super.put(key, ser.serialize(value))
   }
 
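
Both cache changes above follow the same rule: take a fresh SerializerInstance from the environment for every operation instead of sharing one, because an instance is only meant for single-threaded use (see Serializer.scala further down). A small sketch of that round trip; the SerializedStore helper is illustrative and not part of the patch:

    import java.util.concurrent.ConcurrentHashMap

    // Illustrative helper mirroring DiskSpillingCache and SerializingCache:
    // serialize on put, deserialize on get, one SerializerInstance per call.
    object SerializedStore {
      private val bytesByKey = new ConcurrentHashMap[Any, Array[Byte]]

      def put(key: Any, value: Any) {
        val ser = SparkEnv.get.serializer.newInstance()
        bytesByKey.put(key, ser.serialize(value))
      }

      def get(key: Any): Any = {
        val bytes = bytesByKey.get(key)
        if (bytes == null) null
        else SparkEnv.get.serializer.newInstance().deserialize[Any](bytes)
      }
    }
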
diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala
index 98d757e116dcc4b74c603f82c889697d0e6a7846..a3666fdbaed3eeabe244a076168f6f241d3acfdc 100644
--- a/core/src/main/scala/spark/Executor.scala
+++ b/core/src/main/scala/spark/Executor.scala
@@ -15,6 +15,7 @@ import mesos.{TaskDescription, TaskState, TaskStatus}
 class Executor extends mesos.Executor with Logging {
   var classLoader: ClassLoader = null
   var threadPool: ExecutorService = null
+  var env: SparkEnv = null
 
   override def init(d: ExecutorDriver, args: ExecutorArgs) {
     // Read spark.* system properties from executor arg
@@ -22,19 +23,19 @@ class Executor extends mesos.Executor with Logging {
     for ((key, value) <- props)
       System.setProperty(key, value)
 
-    // Initialize cache and broadcast system (uses some properties read above)
-    Cache.initialize()
-    Serializer.initialize()
+    // Initialize Spark environment (using system properties read above)
+    env = SparkEnv.createFromSystemProperties(false)
+    SparkEnv.set(env)
+    // Old stuff that isn't yet using env
     Broadcast.initialize(false)
-    MapOutputTracker.initialize(false)
-    RDDCache.initialize(false)
 
     // Create our ClassLoader (using spark properties) and set it on this thread
     classLoader = createClassLoader()
     Thread.currentThread.setContextClassLoader(classLoader)
 
     // Start worker thread pool (they will inherit our context ClassLoader)
-    threadPool = new ThreadPoolExecutor(1, 128, 600, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable])
+    threadPool = new ThreadPoolExecutor(
+      1, 128, 600, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable])
   }
 
   override def launchTask(d: ExecutorDriver, desc: TaskDescription) {
@@ -46,6 +47,7 @@ class Executor extends mesos.Executor with Logging {
       def run() = {
         logInfo("Running task ID " + taskId)
         try {
+          SparkEnv.set(env)
           Accumulators.clear
           val task = Utils.deserialize[Task[Any]](arg, classLoader)
           val value = task.run
diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala
index 8ee3044058f4ee07291b0cbd1ec09c64746715e0..af390d55d863f598c0bfcade83388a30b0de1228 100644
--- a/core/src/main/scala/spark/JavaSerializer.scala
+++ b/core/src/main/scala/spark/JavaSerializer.scala
@@ -19,7 +19,7 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream {
   def close() { objIn.close() }
 }
 
-class JavaSerializer extends Serializer {
+class JavaSerializerInstance extends SerializerInstance {
   def serialize[T](t: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val out = outputStream(bos)
@@ -43,6 +43,6 @@
   }
 }
 
-class JavaSerialization extends SerializationStrategy {
-  def newSerializer(): Serializer = new JavaSerializer
+class JavaSerializer extends Serializer {
+  def newInstance(): SerializerInstance = new JavaSerializerInstance
 }
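
JavaSerializer is now the reference implementation of the two-level API: a Serializer acts as a thread-safe factory, and the SerializerInstances it hands out do the per-thread work. Because SparkEnv (added at the end of this patch) instantiates the class named by the spark.serializer property reflectively, any Serializer with a no-arg constructor can be dropped in. A bare-bones sketch; CountingSerializer is purely illustrative and would have to live on the executor classpath:

    // Illustrative: delegates to JavaSerializer, but shows the factory shape
    // that SparkEnv.createFromSystemProperties expects to instantiate.
    class CountingSerializer extends Serializer {
      val instancesCreated = new java.util.concurrent.atomic.AtomicInteger(0)
      private val delegate = new JavaSerializer

      def newInstance(): SerializerInstance = {
        instancesCreated.incrementAndGet()
        delegate.newInstance()
      }
    }

    // Selected the same way the Kryo implementation below is, e.g.
    // System.setProperty("spark.serializer", "spark.CountingSerializer")
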
diff --git a/core/src/main/scala/spark/KryoSerialization.scala b/core/src/main/scala/spark/KryoSerialization.scala
index 54427ecf71b42d0b228789ac3d9d199d30ab825b..ba34a5452adc453c846da134edd8030681083113 100644
--- a/core/src/main/scala/spark/KryoSerialization.scala
+++ b/core/src/main/scala/spark/KryoSerialization.scala
@@ -82,8 +82,8 @@ extends DeserializationStream {
   def close() { in.close() }
 }
 
-class KryoSerializer(strat: KryoSerialization) extends Serializer {
-  val buf = strat.threadBuf.get()
+class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
+  val buf = ks.threadBuf.get()
 
   def serialize[T](t: T): Array[Byte] = {
     buf.writeClassAndObject(t)
@@ -94,7 +94,7 @@ class KryoSerializer(strat: KryoSerialization) extends Serializer {
   }
 
   def outputStream(s: OutputStream): SerializationStream = {
-    new KryoSerializationStream(strat.kryo, strat.threadByteBuf.get(), s)
+    new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s)
   }
 
   def inputStream(s: InputStream): DeserializationStream = {
@@ -107,7 +107,7 @@ trait KryoRegistrator {
   def registerClasses(kryo: Kryo): Unit
 }
 
-class KryoSerialization extends SerializationStrategy with Logging {
+class KryoSerializer extends Serializer with Logging {
   val kryo = createKryo()
 
   val threadBuf = new ThreadLocal[ObjectBuffer] {
@@ -162,5 +162,5 @@ class KryoSerialization extends SerializationStrategy with Logging {
     kryo
   }
 
-  def newSerializer(): Serializer = new KryoSerializer(this)
+  def newInstance(): SerializerInstance = new KryoSerializerInstance(this)
 }
diff --git a/core/src/main/scala/spark/LocalFileShuffle.scala b/core/src/main/scala/spark/LocalFileShuffle.scala
index 057a7ff43de2836311c4bc952d9da879f32947c9..ee57ddbf610155eb2d16f0d71a2b5d6e99fd6a03 100644
--- a/core/src/main/scala/spark/LocalFileShuffle.scala
+++ b/core/src/main/scala/spark/LocalFileShuffle.scala
@@ -47,7 +47,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
           case None => createCombiner(v)
         }
       }
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       for (i <- 0 until numOutputSplits) {
         val file = LocalFileShuffle.getOutputFile(shuffleId, myIndex, i)
         val out = ser.outputStream(new FileOutputStream(file))
@@ -70,7 +70,7 @@ class LocalFileShuffle[K, V, C] extends Shuffle[K, V, C] with Logging {
     val indexes = sc.parallelize(0 until numOutputSplits, numOutputSplits)
     return indexes.flatMap((myId: Int) => {
       val combiners = new HashMap[K, C]
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       for ((serverUri, inputIds) <- Utils.shuffle(splitsByUri)) {
         for (i <- inputIds) {
           val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, myId)
diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala
index 0287082687cb32d34dd1330341b2842a4ef14b1e..832ab8cca8b4148811aac1a765b55c70a59ff6c1 100644
--- a/core/src/main/scala/spark/LocalScheduler.scala
+++ b/core/src/main/scala/spark/LocalScheduler.scala
@@ -8,6 +8,8 @@ import java.util.concurrent._
 private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
   var threadPool: ExecutorService =
     Executors.newFixedThreadPool(threads, DaemonThreadFactory)
+
+  val env = SparkEnv.get
 
   override def start() {}
 
@@ -18,6 +20,8 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
     threadPool.submit(new Runnable {
       def run() {
         logInfo("Running task " + i)
+        // Set the Spark execution environment for the worker thread
+        SparkEnv.set(env)
         try {
           // Serialize and deserialize the task so that accumulators are
           // changed to thread-local ones; this adds a bit of unnecessary
@@ -47,4 +51,4 @@ private class LocalScheduler(threads: Int) extends DAGScheduler with Logging {
   override def stop() {}
 
   override def numCores() = threads
-}
\ No newline at end of file
+}
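
The LocalScheduler change above is the general recipe for crossing a thread boundary: read SparkEnv.get on the submitting thread, then call SparkEnv.set inside the pooled task, since the environment lives in a ThreadLocal and is not inherited by pool threads. The same hand-off appears in Executor for Mesos tasks. A generic sketch; the EnvAwarePool helper is not part of the patch:

    import java.util.concurrent.Executors

    object EnvAwarePool {
      private val pool = Executors.newFixedThreadPool(4)

      def runInPool(body: => Unit) {
        val env = SparkEnv.get       // captured on the caller's thread
        pool.submit(new Runnable {
          def run() {
            SparkEnv.set(env)        // re-bound on the worker thread
            body                     // now SparkEnv.get works inside the task
          }
        })
      }
    }
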
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 4334034ecbc1c93af79ae1d616267de3b517b221..d36fbc77039e6eed5a3d74e1db2b8a62427dff2e 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -11,10 +11,10 @@ sealed trait MapOutputTrackerMessage
 case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
 case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
+class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]])
 extends DaemonActor with Logging {
   def act() {
-    val port = System.getProperty("spark.master.port", "50501").toInt
+    val port = System.getProperty("spark.master.port").toInt
     RemoteActor.alive(port)
     RemoteActor.register('MapOutputTracker, self)
     logInfo("Registered actor on port " + port)
@@ -32,22 +32,20 @@ extends DaemonActor with Logging {
   }
 }
 
-object MapOutputTracker extends Logging {
+class MapOutputTracker(isMaster: Boolean) extends Logging {
   var trackerActor: AbstractActor = null
 
-  private val serverUris = new ConcurrentHashMap[Int, Array[String]]
-
-  def initialize(isMaster: Boolean) {
-    if (isMaster) {
-      val tracker = new MapOutputTracker(serverUris)
-      tracker.start
-      trackerActor = tracker
-    } else {
-      val host = System.getProperty("spark.master.host")
-      val port = System.getProperty("spark.master.port").toInt
-      trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
-    }
+  // Must be initialized before the constructor logic below, which hands it
+  // to MapOutputTrackerActor on the master
+  private val serverUris = new ConcurrentHashMap[Int, Array[String]]
+
+  if (isMaster) {
+    val tracker = new MapOutputTrackerActor(serverUris)
+    tracker.start
+    trackerActor = tracker
+  } else {
+    val host = System.getProperty("spark.master.host")
+    val port = System.getProperty("spark.master.port").toInt
+    trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker)
   }
 
   def registerMapOutput(shuffleId: Int, numMaps: Int, mapId: Int, serverUri: String) {
     var array = serverUris.get(shuffleId)
@@ -62,7 +60,6 @@ object MapOutputTracker extends Logging {
     serverUris.put(shuffleId, Array[String]() ++ locs)
   }
 
-
   // Remembers which map output locations are currently being fetched
   val fetching = new HashSet[Int]
 
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 40eb7967ec6defef8612c684001cea6982cc6490..6accd5e356d329d264c07d35f656a15d98063b1a 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -43,7 +43,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) {
   // Read this RDD; will read from cache if applicable, or otherwise compute
   final def iterator(split: Split): Iterator[T] = {
     if (shouldCache) {
-      RDDCache.getOrCompute[T](this, split)
+      SparkEnv.get.cacheTracker.getOrCompute[T](this, split)
     } else {
       compute(split)
     }
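
Pieced together from its callers in this patch (the DAGScheduler registers completed map stages, CoGroupedRDD above and ShuffledRDD below ask where to fetch from), a MapOutputTracker instance is used roughly as follows; the host names and IDs are placeholders:

    val tracker = SparkEnv.get.mapOutputTracker

    // Master side: record where the two map tasks of shuffle 0 wrote output.
    tracker.registerMapOutput(0, 2, 0, "http://host-a:33333")  // shuffleId, numMaps, mapId, serverUri
    tracker.registerMapOutput(0, 2, 1, "http://host-b:33333")

    // Reduce side: look up the servers and build fetch URLs, in the same
    // shape LocalFileShuffle uses above.
    val reduceId = 0
    for ((serverUri, mapId) <- tracker.getServerUris(0).zipWithIndex) {
      val url = "%s/shuffle/%d/%d/%d".format(serverUri, 0, mapId, reduceId)
      // ...fetch and merge the records served at this URL...
    }
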
diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala
index a182f6bddc75c1a806dc0c1686378d17f4caa869..cfc6d978bce818d467fd851e8dbe35152c3c99dc 100644
--- a/core/src/main/scala/spark/Serializer.scala
+++ b/core/src/main/scala/spark/Serializer.scala
@@ -2,39 +2,38 @@ package spark
 
 import java.io.{InputStream, OutputStream}
 
-trait SerializationStream {
-  def writeObject[T](t: T): Unit
-  def flush(): Unit
-  def close(): Unit
-}
-
-trait DeserializationStream {
-  def readObject[T](): T
-  def close(): Unit
+/**
+ * A serializer. Because some serialization libraries are not thread safe,
+ * this class is used to create SerializerInstances that do the actual
+ * serialization.
+ */
+trait Serializer {
+  def newInstance(): SerializerInstance
 }
 
-trait Serializer {
+/**
+ * An instance of the serializer, for use by one thread at a time.
+ */
+trait SerializerInstance {
   def serialize[T](t: T): Array[Byte]
   def deserialize[T](bytes: Array[Byte]): T
   def outputStream(s: OutputStream): SerializationStream
   def inputStream(s: InputStream): DeserializationStream
 }
 
-trait SerializationStrategy {
-  def newSerializer(): Serializer
+/**
+ * A stream for writing serialized objects.
+ */
+trait SerializationStream {
+  def writeObject[T](t: T): Unit
+  def flush(): Unit
+  def close(): Unit
 }
 
-object Serializer {
-  var strat: SerializationStrategy = null
-
-  def initialize() {
-    val cls = System.getProperty("spark.serialization",
-      "spark.JavaSerialization")
-    strat = Class.forName(cls).newInstance().asInstanceOf[SerializationStrategy]
-  }
-
-  // Return a serializer ** for use by a single thread **
-  def newInstance(): Serializer = {
-    strat.newSerializer()
-  }
+/**
+ * A stream for reading serialized objects.
+ */
+trait DeserializationStream {
+  def readObject[T](): T
+  def close(): Unit
 }
diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala
index cbc64736e6f494ac757279f87cfcb259b976b489..2c1f96a7001dd49108a0669554f96eb1fc445f07 100644
--- a/core/src/main/scala/spark/SerializingCache.scala
+++ b/core/src/main/scala/spark/SerializingCache.scala
@@ -10,14 +10,14 @@ class SerializingCache extends Cache with Logging {
   val bmc = new BoundedMemoryCache
 
   override def put(key: Any, value: Any) {
-    val ser = Serializer.newInstance()
+    val ser = SparkEnv.get.serializer.newInstance()
     bmc.put(key, ser.serialize(value))
   }
 
   override def get(key: Any): Any = {
     val bytes = bmc.get(key)
     if (bytes != null) {
-      val ser = Serializer.newInstance()
+      val ser = SparkEnv.get.serializer.newInstance()
       return ser.deserialize(bytes.asInstanceOf[Array[Byte]])
     } else {
       return null
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 683df12019199af0ca6b75ba422118efa00e4a44..f730f0580e5f87d534417b4393a11dd47f9d97cb 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -33,7 +33,7 @@ extends RDD[(K, C)](parent.context) {
     val shuffleId = dep.shuffleId
     val splitId = split.index
     val splitsByUri = new HashMap[String, ArrayBuffer[Int]]
-    val serverUris = MapOutputTracker.getServerUris(shuffleId)
+    val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId)
     for ((serverUri, index) <- serverUris.zipWithIndex) {
       splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index
     }
@@ -58,4 +58,4 @@ extends RDD[(K, C)](parent.context) {
     }
     combiners.iterator
   }
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index dc6964e14bd28443508fc667b7fbd8652a6ea154..c1807de0ef91d7dd6c588876fd5ac78a5b992f62 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -20,7 +20,13 @@ extends Logging {
     System.setProperty("spark.master.host", Utils.localHostName)
   if (System.getProperty("spark.master.port") == null)
     System.setProperty("spark.master.port", "50501")
+
+  // Create the Spark execution environment (cache, map output tracker, etc)
+  val env = SparkEnv.createFromSystemProperties(true)
+  SparkEnv.set(env)
+  Broadcast.initialize(true)
 
+  // Create and start the scheduler
   private var scheduler: Scheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
@@ -34,16 +40,9 @@ extends Logging {
       new MesosScheduler(this, master, frameworkName)
     }
   }
+  scheduler.start()
 
   private val isLocal = scheduler.isInstanceOf[LocalScheduler]
-
-  // Start the scheduler, the cache and the broadcast system
-  scheduler.start()
-  Cache.initialize()
-  Serializer.initialize()
-  Broadcast.initialize(true)
-  MapOutputTracker.initialize(true)
-  RDDCache.initialize(true)
 
   // Methods for creating RDDs
 
@@ -122,8 +121,9 @@ extends Logging {
     scheduler.stop()
     scheduler = null
     // TODO: Broadcast.stop(), Cache.stop()?
-    MapOutputTracker.stop()
-    RDDCache.stop()
+    env.mapOutputTracker.stop()
+    env.cacheTracker.stop()
+    SparkEnv.set(null)
   }
 
   // Wait for the scheduler to be registered
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
new file mode 100644
index 0000000000000000000000000000000000000000..1bfd0172d7438b79acd5707272d89ed6674317f5
--- /dev/null
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -0,0 +1,36 @@
+package spark
+
+class SparkEnv (
+  val cache: Cache,
+  val serializer: Serializer,
+  val cacheTracker: CacheTracker,
+  val mapOutputTracker: MapOutputTracker
+)
+
+object SparkEnv {
+  private val env = new ThreadLocal[SparkEnv]
+
+  def set(e: SparkEnv) {
+    env.set(e)
+  }
+
+  def get: SparkEnv = {
+    env.get()
+  }
+
+  def createFromSystemProperties(isMaster: Boolean): SparkEnv = {
+    val cacheClass = System.getProperty("spark.cache.class",
+      "spark.SoftReferenceCache")
+    val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
+
+    val serClass = System.getProperty("spark.serializer",
+      "spark.JavaSerializer")
+    val ser = Class.forName(serClass).newInstance().asInstanceOf[Serializer]
+
+    val cacheTracker = new CacheTracker(isMaster, cache)
+
+    val mapOutputTracker = new MapOutputTracker(isMaster)
+
+    new SparkEnv(cache, ser, cacheTracker, mapOutputTracker)
+  }
+}
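
Taken together, the new class gives every thread a single handle on the four shared services that used to be global singletons. A condensed view of the intended lifecycle, assembled from the SparkContext, Executor and LocalScheduler changes above; it assumes spark.master.host and spark.master.port are already set, as SparkContext ensures:

    // Driver side, as SparkContext now does at construction time:
    val env = SparkEnv.createFromSystemProperties(true)  // isMaster = true
    SparkEnv.set(env)                                    // bind to this thread

    // Downstream code looks services up instead of touching old singletons:
    val ser          = SparkEnv.get.serializer.newInstance()
    val keySpace     = SparkEnv.get.cache.newKeySpace()
    val cacheTracker = SparkEnv.get.cacheTracker
    val mapTracker   = SparkEnv.get.mapOutputTracker

    // Executor side: the same call with isMaster = false makes the trackers
    // connect to the master's actors instead of starting their own, and each
    // task thread re-binds the environment (SparkEnv.set(env)) before running.

    // Shutdown, as SparkContext.stop now does:
    env.mapOutputTracker.stop()
    env.cacheTracker.stop()
    SparkEnv.set(null)
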