diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala index ed8ace3a572ce092b45ea304986c67e54a22edee..8ce7abd03f6bd2fd5fa684cb0fea9b3cb4fd4b61 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -11,7 +11,6 @@ import scala.xml.{XML,NodeSeq} import scala.collection.mutable.ArrayBuffer import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} -import java.nio.ByteBuffer object WikipediaPageRankStandalone { def main(args: Array[String]) { @@ -118,23 +117,23 @@ class WPRSerializer extends spark.Serializer { } class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { + def serialize[T](t: T): Array[Byte] = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: ByteBuffer): T = { + def deserialize[T](bytes: Array[Byte]): T = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { throw new UnsupportedOperationException() } - def serializeStream(s: OutputStream): SerializationStream = { + def outputStream(s: OutputStream): SerializationStream = { new WPRSerializationStream(s) } - def deserializeStream(s: InputStream): DeserializationStream = { + def inputStream(s: InputStream): DeserializationStream = { new WPRDeserializationStream(s) } } diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala deleted file mode 100644 index e00a0d80fa25a15e4bf884912613566acba5ab63..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,70 +0,0 @@ -package spark - -import java.io.EOFException -import java.net.URL - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import spark.storage.BlockException -import spark.storage.BlockManagerId - -import it.unimi.dsi.fastutil.io.FastBufferedInputStream - - -class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { - def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val ser = SparkEnv.get.serializer.newInstance() - val blockManager = SparkEnv.get.blockManager - - val startTime = System.currentTimeMillis - val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]] - for ((address, index) <- addresses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) - } - - try { - val blockOptions = blockManager.get(blocksByAddress) - logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - blockOptions.foreach(x => { - val (blockId, blockOption) = x - blockOption match { - case Some(block) => { - val values = 
block.asInstanceOf[Iterator[Any]] - for(value <- values) { - val v = value.asInstanceOf[(K, V)] - func(v._1, v._2) - } - } - case None => { - throw new BlockException(blockId, "Did not get block " + blockId) - } - } - }) - } catch { - case be: BlockException => { - val regex = "shuffledid_([0-9]*)_([0-9]*)_([0-9]]*)".r - be.blockId match { - case regex(sId, mId, rId) => { - val address = addresses(mId.toInt) - throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) - } - case _ => { - throw be - } - } - } - } - } -} diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala index fa5dcee7bbf0c4cd66a1d2f0bd363799e4c9eaff..1162e34ab03340c763e943b696a611ba9cb5d8d8 100644 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ b/core/src/main/scala/spark/BoundedMemoryCache.scala @@ -90,8 +90,7 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - // TODO: remove BoundedMemoryCache - SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition) + SparkEnv.get.cacheTracker.dropEntry(datasetId, partition) } } diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 64b4af0ae20e327b90abd36df6ea9a33969a64ed..4867829c17ac6519dba55aed2a47af78f54fe85f 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -1,17 +1,11 @@ package spark -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ - -import scala.collection.mutable.ArrayBuffer +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import spark.storage.BlockManager -import spark.storage.StorageLevel - sealed trait CacheTrackerMessage case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) extends CacheTrackerMessage @@ -24,8 +18,8 @@ case object GetCacheStatus extends CacheTrackerMessage case object GetCacheLocations extends CacheTrackerMessage case object StopCacheTracker extends CacheTrackerMessage -class CacheTrackerActor extends Actor with Logging { - // TODO: Should probably store (String, CacheType) tuples + +class CacheTrackerActor extends DaemonActor with Logging { private val locs = new HashMap[Int, Array[List[String]]] /** @@ -34,93 +28,109 @@ class CacheTrackerActor extends Actor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] + // TODO: Should probably store (String, CacheType) tuples + private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - def receive = { - case SlaveCacheStarted(host: String, size: Long) => - logInfo("Started slave cache (size %s) on %s".format( - Utils.memoryBytesToString(size), host)) - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - self.reply(true) - - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = 
Array.fill[List[String]](numPartitions)(Nil) - self.reply(true) - - case AddedToCache(rddId, partition, host, size) => - slaveUsage.put(host, getCacheUsage(host) + size) - logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format( - rddId, partition, host, Utils.memoryBytesToString(size), - Utils.memoryBytesToString(getCacheAvailable(host)))) - locs(rddId)(partition) = host :: locs(rddId)(partition) - self.reply(true) + def act() { + val port = System.getProperty("spark.master.port").toInt + RemoteActor.alive(port) + RemoteActor.register('CacheTracker, self) + logInfo("Registered actor on port " + port) + + loop { + react { + case SlaveCacheStarted(host: String, size: Long) => + logInfo("Started slave cache (size %s) on %s".format( + Utils.memoryBytesToString(size), host)) + slaveCapacity.put(host, size) + slaveUsage.put(host, 0) + reply('OK) + + case RegisterRDD(rddId: Int, numPartitions: Int) => + logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") + locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) + reply('OK) + + case AddedToCache(rddId, partition, host, size) => + if (size > 0) { + slaveUsage.put(host, getCacheUsage(host) + size) + logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format( + rddId, partition, host, Utils.memoryBytesToString(size), + Utils.memoryBytesToString(getCacheAvailable(host)))) + } else { + logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host)) + } + locs(rddId)(partition) = host :: locs(rddId)(partition) + reply('OK) + + case DroppedFromCache(rddId, partition, host, size) => + if (size > 0) { + logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( + rddId, partition, host, Utils.memoryBytesToString(size), + Utils.memoryBytesToString(getCacheAvailable(host)))) + slaveUsage.put(host, getCacheUsage(host) - size) + + // Do a sanity check to make sure usage is greater than 0. + val usage = getCacheUsage(host) + if (usage < 0) { + logError("Cache usage on %s is negative (%d)".format(host, usage)) + } + } else { + logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host)) + } + locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) + reply('OK) - case DroppedFromCache(rddId, partition, host, size) => - logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( - rddId, partition, host, Utils.memoryBytesToString(size), - Utils.memoryBytesToString(getCacheAvailable(host)))) - slaveUsage.put(host, getCacheUsage(host) - size) - // Do a sanity check to make sure usage is greater than 0. 
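The actor above keeps per-host accounting in two mutable maps: capacity is set once per SlaveCacheStarted, usage is incremented on AddedToCache and decremented (with a negative-usage sanity check) on DroppedFromCache. A minimal standalone sketch of that bookkeeping, with hypothetical names and detached from the actor loop:

    import scala.collection.mutable.HashMap

    // Illustration only: the same per-host accounting, outside the actor.
    class CacheUsageBook {
      private val capacity = new HashMap[String, Long]
      private val usage = new HashMap[String, Long]

      def slaveStarted(host: String, size: Long): Unit = {
        capacity(host) = size
        usage(host) = 0L
      }

      def added(host: String, size: Long): Unit = {
        usage(host) = usage.getOrElse(host, 0L) + size
      }

      def dropped(host: String, size: Long): Unit = {
        usage(host) = usage.getOrElse(host, 0L) - size
        if (usage(host) < 0) {
          // mirrors the sanity check logged by CacheTrackerActor
          println("Cache usage on %s is negative (%d)".format(host, usage(host)))
        }
      }

      def available(host: String): Long =
        capacity.getOrElse(host, 0L) - usage.getOrElse(host, 0L)
    }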
- val usage = getCacheUsage(host) - if (usage < 0) { - logError("Cache usage on %s is negative (%d)".format(host, usage)) + case MemoryCacheLost(host) => + logInfo("Memory cache lost on " + host) + // TODO: Drop host from the memory locations list of all RDDs + + case GetCacheLocations => + logInfo("Asked for current cache locations") + reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())}) + + case GetCacheStatus => + val status = slaveCapacity.map { case (host,capacity) => + (host, capacity, getCacheUsage(host)) + }.toSeq + reply(status) + + case StopCacheTracker => + reply('OK) + exit() } - locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - self.reply(true) + } + } +} - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - for ((id, locations) <- locs) { - for (i <- 0 until locations.length) { - locations(i) = locations(i).filterNot(_ == host) - } - } - self.reply(true) - case GetCacheLocations => - logInfo("Asked for current cache locations") - self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())}) +class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { + // Tracker actor on the master, or remote reference to it on workers + var trackerActor: AbstractActor = null - case GetCacheStatus => - val status = slaveCapacity.map { case (host, capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - self.reply(status) + val registeredRddIds = new HashSet[Int] - case StopCacheTracker => - logInfo("CacheTrackerActor Server stopped!") - self.reply(true) - self.exit() - } -} + // Stores map results for various splits locally + val cache = theCache.newKeySpace() -class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging { - // Tracker actor on the master, or remote reference to it on workers - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val aName: String = "CacheTracker" - if (isMaster) { - } - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorOf(new CacheTrackerActor) - remote.register(aName, actor) - actor.start() - logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port) - actor + val tracker = new CacheTrackerActor + tracker.start() + trackerActor = tracker } else { - remote.actorFor(aName, ip, port) + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker) } - val registeredRddIds = new HashSet[Int] + // Report the cache being started. + trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity) // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[String] + val loading = new HashSet[(Int, Int)] // Registers an RDD (on master only) def registerRDD(rddId: Int, numPartitions: Int) { @@ -128,33 +138,24 @@ class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Loggin if (!registeredRddIds.contains(rddId)) { logInfo("Registering RDD ID " + rddId + " with cache") registeredRddIds += rddId - (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match { - case Some(true) => - logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.") - case Some(oops) => - logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops) - case None => - logError("CacheTracker registerRDD None. 
" + RegisterRDD(rddId, numPartitions)) - throw new SparkException("Internal error: CacheTracker registerRDD None.") - } + trackerActor !? RegisterRDD(rddId, numPartitions) } } } - - // For BlockManager.scala only - def cacheLost(host: String) { - (trackerActor ? MemoryCacheLost(host)).as[Any] match { - case Some(true) => - logInfo("CacheTracker successfully removed entries on " + host) - case _ => - logError("CacheTracker did not reply to MemoryCacheLost") + + // Get a snapshot of the currently known locations + def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { + (trackerActor !? GetCacheLocations) match { + case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]] + + case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") } } // Get the usage status of slave caches. Each tuple in the returned sequence // is in the form of (host name, capacity, usage). def getCacheStatus(): Seq[(String, Long, Long)] = { - (trackerActor ? GetCacheStatus) match { + (trackerActor !? GetCacheStatus) match { case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]] case _ => @@ -163,94 +164,75 @@ class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Loggin } } - // For BlockManager.scala only - def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) { - (trackerActor ? t).as[Any] match { - case Some(true) => - logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.") - case Some(oops) => - logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops) - case None => - logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.") - } - } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - (trackerActor ? GetCacheLocations).as[Any] match { - case Some(h: HashMap[_, _]) => - h.asInstanceOf[HashMap[Int, Array[List[String]]]] - - case _ => - throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") - } - } - // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = { - val key = "rdd:%d:%d".format(rdd.id, split.index) - logInfo("Cache key is " + key) - blockManager.get(key) match { - case Some(cachedValues) => - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedValues.asInstanceOf[Iterator[T]] - - case None => - // Mark the split as loading (unless someone else marks it first) + def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = { + logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index)) + val cachedVal = cache.get(rdd.id, split.index) + if (cachedVal != null) { + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedVal.asInstanceOf[Array[T]].iterator + } else { + // Mark the split as loading (unless someone else marks it first) + val key = (rdd.id, split.index) + loading.synchronized { + while (loading.contains(key)) { + // Someone else is loading it; let's wait for them + try { loading.wait() } catch { case _ => } + } + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. 
However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + val cachedVal = cache.get(rdd.id, split.index) + if (cachedVal != null) { + return cachedVal.asInstanceOf[Array[T]].iterator + } + // Nobody's loading it and it's not in the cache; let's load it ourselves + loading.add(key) + } + // If we got here, we have to load the split + // Tell the master that we're doing so + + // TODO: fetch any remote copy of the split that may be available + logInfo("Computing partition " + split) + var array: Array[T] = null + var putResponse: CachePutResponse = null + try { + array = rdd.compute(split).toArray(m) + putResponse = cache.put(rdd.id, split.index, array) + } finally { + // Tell other threads that we've finished our attempt to load the key (whether or not + // we've actually succeeded to put it in the map) loading.synchronized { - if (loading.contains(key)) { - logInfo("Loading contains " + key + ", waiting...") - while (loading.contains(key)) { - try {loading.wait()} catch {case _ =>} - } - logInfo("Loading no longer contains " + key + ", so returning cached result") - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. - blockManager.get(key) match { - case Some(values) => - return values.asInstanceOf[Iterator[T]] - case None => - logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") - loading.add(key) - } - } else { - loading.add(key) - } + loading.remove(key) + loading.notifyAll() } - // If we got here, we have to load the split - // Tell the master that we're doing so - //val host = System.getProperty("spark.hostname", Utils.localHostName) - //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) - // TODO: fetch any remote copy of the split that may be available - // TODO: also register a listener for when it unloads - logInfo("Computing partition " + split) - try { - val values = new ArrayBuffer[Any] - values ++= rdd.compute(split) - blockManager.put(key, values.iterator, storageLevel, false) - //future.apply() // Wait for the reply from the cache tracker - return values.iterator.asInstanceOf[Iterator[T]] - } finally { - loading.synchronized { - loading.remove(key) - loading.notifyAll() - } + } + + putResponse match { + case CachePutSuccess(size) => { + // Tell the master that we added the entry. Don't return until it + // replies so it can properly schedule future tasks that use this RDD. + trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size) } + case _ => null + } + return array.iterator } } // Called by the Cache to report that an entry has been dropped from it - def dropEntry(rddId: Int, partition: Int) { - //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. - trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName()) + def dropEntry(datasetId: Any, partition: Int) { + datasetId match { + //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. + case (cache.keySpaceId, rddId: Int) => trackerActor !! 
DroppedFromCache(rddId, partition, Utils.getHost) + } } def stop() { - trackerActor !! StopCacheTracker + trackerActor !? StopCacheTracker registeredRddIds.clear() trackerActor = null } diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala index 3543c8afa8a081f201f630df60dcb6f915c01115..93f453bc5e4341bcf74de43ee22d332cdeaf4e1a 100644 --- a/core/src/main/scala/spark/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/CoGroupedRDD.scala @@ -22,12 +22,11 @@ class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) +class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { val aggr = new CoGroupAggregator - @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { @@ -68,10 +67,9 @@ class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] - val numRdds = split.deps.size val map = new HashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) + map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any])) } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..1b4af9d84c6d2159eb05084e2587ddef62a6bed1 --- /dev/null +++ b/core/src/main/scala/spark/DAGScheduler.scala @@ -0,0 +1,374 @@ +package spark + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map} + +/** + * A task created by the DAG scheduler. Knows its stage ID and map ouput tracker generation. + */ +abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] { + val gen = SparkEnv.get.mapOutputTracker.getGeneration + override def generation: Option[Long] = Some(gen) +} + +/** + * A completion event passed by the underlying task scheduler to the DAG scheduler. + */ +case class CompletionEvent( + task: DAGTask[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Map[Long, Any]) + +/** + * Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry + * tasks several times for "ephemeral" failures, and only report back failures that require some + * old stages to be resubmitted, such as shuffle map fetch failures. + */ +sealed trait TaskEndReason +case object Success extends TaskEndReason +case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason +case class ExceptionFailure(exception: Throwable) extends TaskEndReason +case class OtherFailure(message: String) extends TaskEndReason + +/** + * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for + * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal + * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster + * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). 
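The contract that comment describes is small: a concrete scheduler only has to run whatever submitTasks hands it and call taskEnded once per outcome, which the DAG side turns into CompletionEvents. A toy sketch of that callback shape, using simplified stand-in types rather than the real Task and Scheduler classes:

    // Stand-in types for illustration; not the real Spark classes.
    trait ToyTask[T] { def run(attemptId: Int): T }

    trait ToyDAGScheduler {
      // Provided by the DAG side: receives one callback per finished task.
      def taskEnded(task: ToyTask[_], succeeded: Boolean, result: Any): Unit =
        println("task ended (success=" + succeeded + "): " + result)

      // Implemented by the concrete scheduler: actually execute the tasks.
      def submitTasks(tasks: Seq[ToyTask[_]], runId: Int): Unit
    }

    // The simplest possible "cluster": run everything inline on the caller's thread.
    class InlineScheduler extends ToyDAGScheduler {
      def submitTasks(tasks: Seq[ToyTask[_]], runId: Int): Unit = {
        for ((task, i) <- tasks.zipWithIndex) {
          taskEnded(task, succeeded = true, task.run(i))
        }
      }
    }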
+ */ +private trait DAGScheduler extends Scheduler with Logging { + // Must be implemented by subclasses to start running a set of tasks. The subclass should also + // attempt to run different sets of tasks in the order given by runId (lower values first). + def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit + + // Must be called by subclasses to report task completions or failures. + def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) { + lock.synchronized { + val dagTask = task.asInstanceOf[DAGTask[_]] + eventQueues.get(dagTask.runId) match { + case Some(queue) => + queue += CompletionEvent(dagTask, reason, result, accumUpdates) + lock.notifyAll() + case None => + logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone") + } + } + } + + // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; + // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one + // as more failure events come in + val RESUBMIT_TIMEOUT = 2000L + + // The time, in millis, to wake up between polls of the completion queue in order to potentially + // resubmit failed stages + val POLL_TIMEOUT = 500L + + private val lock = new Object // Used for access to the entire DAGScheduler + + private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID + + val nextRunId = new AtomicInteger(0) + + val nextStageId = new AtomicInteger(0) + + val idToStage = new HashMap[Int, Stage] + + val shuffleToMapStage = new HashMap[Int, Stage] + + var cacheLocs = new HashMap[Int, Array[List[String]]] + + val env = SparkEnv.get + val cacheTracker = env.cacheTracker + val mapOutputTracker = env.mapOutputTracker + + def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + cacheLocs(rdd.id) + } + + def updateCacheLocs() { + cacheLocs = cacheTracker.getLocationsSnapshot() + } + + def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = { + shuffleToMapStage.get(shuf.shuffleId) match { + case Some(stage) => stage + case None => + val stage = newStage(shuf.rdd, Some(shuf)) + shuffleToMapStage(shuf.shuffleId) = stage + stage + } + } + + def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of splits is unknown + cacheTracker.registerRDD(rdd.id, rdd.splits.size) + if (shuffleDep != None) { + mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) + } + val id = nextStageId.getAndIncrement() + val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd)) + idToStage(id) = stage + stage + } + + def getParentStages(rdd: RDD[_]): List[Stage] = { + val parents = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + // Kind of ugly: need to register RDDs with the cache here since + // we can't do it in its constructor because # of splits is unknown + cacheTracker.registerRDD(r.id, r.splits.size) + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_,_] => + parents += getShuffleMapStage(shufDep) + case _ => + visit(dep.rdd) + } + } + } + } + visit(rdd) + parents.toList + } + + def getMissingParentStages(stage: Stage): List[Stage] = { + val missing = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + def visit(rdd: RDD[_]) { + if (!visited(rdd)) { + visited += rdd + val locs = 
getCacheLocs(rdd) + for (p <- 0 until rdd.splits.size) { + if (locs(p) == Nil) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_,_] => + val stage = getShuffleMapStage(shufDep) + if (!stage.isAvailable) { + missing += stage + } + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) + } + } + } + } + } + } + visit(stage.rdd) + missing.toList + } + + override def runJob[T, U]( + finalRdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean) + (implicit m: ClassManifest[U]): Array[U] = { + lock.synchronized { + val runId = nextRunId.getAndIncrement() + + val outputParts = partitions.toArray + val numOutputParts: Int = partitions.size + val finalStage = newStage(finalRdd, None) + val results = new Array[U](numOutputParts) + val finished = new Array[Boolean](numOutputParts) + var numFinished = 0 + + val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done + val running = new HashSet[Stage] // stages we are running right now + val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures + val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage + var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits + + SparkEnv.set(env) + + updateCacheLocs() + + logInfo("Final stage: " + finalStage) + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + + // Optimization for short actions like first() and take() that can be computed locally + // without shipping tasks to the cluster. + if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) { + logInfo("Computing the requested partition locally") + val split = finalRdd.splits(outputParts(0)) + val taskContext = new TaskContext(finalStage.id, outputParts(0), 0) + return Array(func(taskContext, finalRdd.iterator(split))) + } + + // Register the job ID so that we can get completion events for it + eventQueues(runId) = new Queue[CompletionEvent] + + def submitStage(stage: Stage) { + if (!waiting(stage) && !running(stage)) { + val missing = getMissingParentStages(stage) + if (missing == Nil) { + logInfo("Submitting " + stage + ", which has no missing parents") + submitMissingTasks(stage) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage + } + } + } + + def submitMissingTasks(stage: Stage) { + // Get our pending tasks and remember them in our pendingTasks entry + val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) + var tasks = ArrayBuffer[Task[_]]() + if (stage == finalStage) { + for (id <- 0 until numOutputParts if (!finished(id))) { + val part = outputParts(id) + val locs = getPreferredLocs(finalRdd, part) + tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id) + } + } else { + for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + } + } + myPending ++= tasks + submitTasks(tasks, runId) + } + + submitStage(finalStage) + + while (numFinished != numOutputParts) { + val eventOption = waitForEvent(runId, POLL_TIMEOUT) + val time = System.currentTimeMillis // TODO: use a pluggable clock for testability + + // If we got an event off the queue, mark the task done or react to a fetch failure + if (eventOption != None) { + val evt = 
eventOption.get + val stage = idToStage(evt.task.stageId) + pendingTasks(stage) -= evt.task + if (evt.reason == Success) { + // A task ended + logInfo("Completed " + evt.task) + Accumulators.add(evt.accumUpdates) + evt.task match { + case rt: ResultTask[_, _] => + results(rt.outputId) = evt.result.asInstanceOf[U] + finished(rt.outputId) = true + numFinished += 1 + case smt: ShuffleMapTask => + val stage = idToStage(smt.stageId) + stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String]) + if (running.contains(stage) && pendingTasks(stage).isEmpty) { + logInfo(stage + " finished; looking for newly runnable stages") + running -= stage + if (stage.shuffleDep != None) { + mapOutputTracker.registerMapOutputs( + stage.shuffleDep.get.shuffleId, + stage.outputLocs.map(_.head).toArray) + } + updateCacheLocs() + val newlyRunnable = new ArrayBuffer[Stage] + for (stage <- waiting if getMissingParentStages(stage) == Nil) { + newlyRunnable += stage + } + waiting --= newlyRunnable + running ++= newlyRunnable + for (stage <- newlyRunnable) { + submitMissingTasks(stage) + } + } + } + } else { + evt.reason match { + case FetchFailed(serverUri, shuffleId, mapId, reduceId) => + // Mark the stage that the reducer was in as unrunnable + val failedStage = idToStage(evt.task.stageId) + running -= failedStage + failed += failedStage + // TODO: Cancel running tasks in the stage + logInfo("Marking " + failedStage + " for resubmision due to a fetch failure") + // Mark the map whose fetch failed as broken in the map stage + val mapStage = shuffleToMapStage(shuffleId) + mapStage.removeOutputLoc(mapId, serverUri) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, serverUri) + logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission") + failed += mapStage + // Remember that a fetch failed now; this is used to resubmit the broken + // stages later, after a small wait (to give other tasks the chance to fail) + lastFetchFailureTime = time + // TODO: If there are a lot of fetch failures on the same node, maybe mark all + // outputs on the node as dead. + case _ => + // Non-fetch failure -- probably a bug in the job, so bail out + throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason) + // TODO: Cancel all tasks that are still running + } + } + } // end if (evt != null) + + // If fetches have failed recently and we've waited for the right timeout, + // resubmit all the failed stages + if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { + logInfo("Resubmitting failed stages") + updateCacheLocs() + for (stage <- failed) { + submitStage(stage) + } + failed.clear() + } + } + + eventQueues -= runId + return results + } + } + + def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { + // If the partition is cached, return the cache locations + val cached = getCacheLocs(rdd)(partition) + if (cached != Nil) { + return cached + } + // If the RDD has some placement preferences (as is the case for input RDDs), get those + val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList + if (rddPrefs != Nil) { + return rddPrefs + } + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. 
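The preference order spelled out here (cached locations first, then the RDD's own preferences, then locations inherited through narrow dependencies, whose lookup continues just below) amounts to taking the first non-empty candidate. A compact illustration of that selection rule with hypothetical inputs:

    // Illustration only: pick the first non-empty preference list, in priority order.
    object PreferenceChain {
      def firstNonEmpty(candidates: Seq[() => List[String]]): List[String] =
        candidates.view.map(f => f()).find(_.nonEmpty).getOrElse(Nil)

      def main(args: Array[String]): Unit = {
        val cachedLocs = List.empty[String]       // not cached anywhere
        val rddPrefs   = List("host1", "host2")   // e.g. input block locations
        val narrowDeps = List("host3")
        println(firstNonEmpty(Seq(() => cachedLocs, () => rddPrefs, () => narrowDeps)))
        // prints List(host1, host2)
      }
    }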
+ rdd.dependencies.foreach(_ match { + case n: NarrowDependency[_] => + for (inPart <- n.getParents(partition)) { + val locs = getPreferredLocs(n.rdd, inPart) + if (locs != Nil) + return locs; + } + case _ => + }) + return Nil + } + + // Assumes that lock is held on entrance, but will release it to wait for the next event. + def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = { + val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing + while (eventQueues(runId).isEmpty) { + val time = System.currentTimeMillis() + if (time >= endTime) { + return None + } else { + lock.wait(endTime - time) + } + } + return Some(eventQueues(runId).dequeue()) + } +} diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index c0ff94acc6266b3e25e1988d700680100affec24..d93c84924a5038fb202157b907092591b1343ac8 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) { class ShuffleDependency[K, V, C]( val shuffleId: Int, - @transient rdd: RDD[(K, V)], + rdd: RDD[(K, V)], val aggregator: Aggregator[K, V, C], val partitioner: Partitioner) extends Dependency(rdd, true) diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala new file mode 100644 index 0000000000000000000000000000000000000000..e11466eb64eec01e923bde295867653d88bb7706 --- /dev/null +++ b/core/src/main/scala/spark/DiskSpillingCache.scala @@ -0,0 +1,75 @@ +package spark + +import java.io.File +import java.io.{FileOutputStream,FileInputStream} +import java.io.IOException +import java.util.LinkedHashMap +import java.util.UUID + +// TODO: cache into a separate directory using Utils.createTempDir +// TODO: clean up disk cache afterwards +class DiskSpillingCache extends BoundedMemoryCache { + private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true) + + override def get(datasetId: Any, partition: Int): Any = { + synchronized { + val ser = SparkEnv.get.serializer.newInstance() + super.get(datasetId, partition) match { + case bytes: Any => // found in memory + ser.deserialize(bytes.asInstanceOf[Array[Byte]]) + + case _ => diskMap.get((datasetId, partition)) match { + case file: Any => // found on disk + try { + val startTime = System.currentTimeMillis + val bytes = new Array[Byte](file.length.toInt) + new FileInputStream(file).read(bytes) + val timeTaken = System.currentTimeMillis - startTime + logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format( + datasetId, partition, file.length, timeTaken)) + super.put(datasetId, partition, bytes) + ser.deserialize(bytes.asInstanceOf[Array[Byte]]) + } catch { + case e: IOException => + logWarning("Failed to read key (%s, %d) from disk at %s: %s".format( + datasetId, partition, file.getPath(), e.getMessage())) + diskMap.remove((datasetId, partition)) // remove dead entry + null + } + + case _ => // not found + null + } + } + } + } + + override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { + var ser = SparkEnv.get.serializer.newInstance() + super.put(datasetId, partition, ser.serialize(value)) + } + + /** + * Spill the given entry to disk. Assumes that a lock is held on the + * DiskSpillingCache. Assumes that entry.value is a byte array. 
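The method this comment documents (its body follows) writes the already-serialized bytes to a uniquely named file in a configurable directory and remembers the File in diskMap. A minimal sketch of just the write step, with illustrative names for the directory and file prefix:

    import java.io.{File, FileOutputStream, IOException}
    import java.util.UUID

    // Illustration only: write already-serialized bytes to a uniquely named spill file.
    object SpillSketch {
      def spill(bytes: Array[Byte]): Option[File] = {
        val cacheDir = System.getProperty("java.io.tmpdir")
        val file = new File(cacheDir, "spill-" + UUID.randomUUID.toString)
        try {
          val stream = new FileOutputStream(file)
          try stream.write(bytes) finally stream.close()
          Some(file)   // caller would record (datasetId, partition) -> file
        } catch {
          case e: IOException =>
            println("Failed to spill to " + file.getPath + ": " + e.getMessage)
            None        // entry is simply discarded, as in reportEntryDropped
        }
      }
    }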
+ */ + override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { + logInfo("Spilling key (%s, %d) of size %d to make space".format( + datasetId, partition, entry.size)) + val cacheDir = System.getProperty( + "spark.diskSpillingCache.cacheDir", + System.getProperty("java.io.tmpdir")) + val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString) + try { + val stream = new FileOutputStream(file) + stream.write(entry.value.asInstanceOf[Array[Byte]]) + stream.close() + diskMap.put((datasetId, partition), file) + } catch { + case e: IOException => + logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format( + datasetId, partition, file.getPath(), e.getMessage())) + // Do nothing and let the entry be discarded + } + } +} diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala deleted file mode 100644 index 1fbf66b7ded3c2e16ed708159be075e12ea0e8e3..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/DoubleRDDFunctions.scala +++ /dev/null @@ -1,39 +0,0 @@ -package spark - -import spark.partial.BoundedDouble -import spark.partial.MeanEvaluator -import spark.partial.PartialResult -import spark.partial.SumEvaluator - -import spark.util.StatCounter - -/** - * Extra functions available on RDDs of Doubles through an implicit conversion. - */ -class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { - def sum(): Double = { - self.reduce(_ + _) - } - - def stats(): StatCounter = { - self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) - } - - def mean(): Double = stats().mean - - def variance(): Double = stats().variance - - def stdev(): Double = stats().stdev - - def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new MeanEvaluator(self.splits.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } - - def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) - val evaluator = new SumEvaluator(self.splits.size, confidence) - self.context.runApproximateJob(self, processPartition, evaluator, timeout) - } -} diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala index af9eb9c878ede5fd39441c413bf72c56524b0b5f..c795b6c3519332a6ea3fe0a9193918a32ec69b99 100644 --- a/core/src/main/scala/spark/Executor.scala +++ b/core/src/main/scala/spark/Executor.scala @@ -10,10 +10,9 @@ import scala.collection.mutable.ArrayBuffer import com.google.protobuf.ByteString import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} +import org.apache.mesos.Protos._ import spark.broadcast._ -import spark.scheduler._ /** * The Mesos executor for Spark. 
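The Executor hunks below switch task deserialization back to plain byte arrays plus an explicit class loader, the same pattern JavaSerializerInstance uses later in this patch. Sketched in isolation (hypothetical helper name):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream,
      ObjectOutputStream, ObjectStreamClass}

    // Round-trip through Java serialization, resolving classes against a given loader
    // (this is what lets the executor load task classes from the job's jars).
    object LoaderAwareSerde {
      def serialize(obj: AnyRef): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bos)
        out.writeObject(obj)
        out.close()
        bos.toByteArray
      }

      def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = {
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes)) {
          override def resolveClass(desc: ObjectStreamClass) =
            Class.forName(desc.getName, false, loader)
        }
        in.readObject().asInstanceOf[T]
      }
    }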
@@ -30,9 +29,6 @@ class Executor extends org.apache.mesos.Executor with Logging { executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { - // Make sure the local hostname we report matches Mesos's name for this host - Utils.setCustomHostname(slaveInfo.getHostname()) - // Read spark.* system properties from executor arg val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) for ((key, value) <- props) { @@ -43,7 +39,7 @@ class Executor extends org.apache.mesos.Executor with Logging { RemoteActor.classLoader = getClass.getClassLoader // Initialize Spark environment (using system properties read above) - env = SparkEnv.createFromSystemProperties(false, false) + env = SparkEnv.createFromSystemProperties(false) SparkEnv.set(env) // Old stuff that isn't yet using env Broadcast.initialize(false) @@ -61,11 +57,11 @@ class Executor extends org.apache.mesos.Executor with Logging { override def reregistered(d: ExecutorDriver, s: SlaveInfo) {} - override def launchTask(d: ExecutorDriver, task: MTaskInfo) { + override def launchTask(d: ExecutorDriver, task: TaskInfo) { threadPool.execute(new TaskRunner(task, d)) } - class TaskRunner(info: MTaskInfo, d: ExecutorDriver) + class TaskRunner(info: TaskInfo, d: ExecutorDriver) extends Runnable { override def run() = { val tid = info.getTaskId.getValue @@ -78,11 +74,11 @@ class Executor extends org.apache.mesos.Executor with Logging { .setState(TaskState.TASK_RUNNING) .build()) try { - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear - val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader) - env.mapOutputTracker.updateGeneration(task.generation) + val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader) + for (gen <- task.generation) {// Update generation if any is set + env.mapOutputTracker.updateGeneration(gen) + } val value = task.run(tid.toInt) val accumUpdates = Accumulators.values val result = new TaskResult(value, accumUpdates) @@ -109,11 +105,9 @@ class Executor extends org.apache.mesos.Executor with Logging { .setData(ByteString.copyFrom(ser.serialize(reason))) .build()) - // TODO: Should we exit the whole executor here? On the one hand, the failed task may - // have left some weird state around depending on when the exception was thrown, but on - // the other hand, maybe we could detect that when future tasks fail and exit then. 
+ // TODO: Handle errors in tasks less dramatically logError("Exception in task ID " + tid, t) - //System.exit(1) + System.exit(1) } } } diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index 55512f4481af231aa13c7c4b629ccdcc6bd556b5..a3c4e7873d7ac5b11468320008252c1d2b84a549 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -1,9 +1,7 @@ package spark -import spark.storage.BlockManagerId - class FetchFailedException( - val bmAddress: BlockManagerId, + val serverUri: String, val shuffleId: Int, val mapId: Int, val reduceId: Int, @@ -11,10 +9,10 @@ class FetchFailedException( extends Exception { override def getMessage(): String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + "Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId) override def getCause(): Throwable = cause def toTaskEndReason: TaskEndReason = - FetchFailed(bmAddress, shuffleId, mapId, reduceId) + FetchFailed(serverUri, shuffleId, mapId, reduceId) } diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala index ec5c33d1df0f639289401f0c9d5891f9bc57d9be..80f615eeb0a942f183d63128a3adec119101fcbe 100644 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ b/core/src/main/scala/spark/JavaSerializer.scala @@ -1,7 +1,6 @@ package spark import java.io._ -import java.nio.ByteBuffer class JavaSerializationStream(out: OutputStream) extends SerializationStream { val objOut = new ObjectOutputStream(out) @@ -10,11 +9,10 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream { def close() { objOut.close() } } -class JavaDeserializationStream(in: InputStream, loader: ClassLoader) -extends DeserializationStream { +class JavaDeserializationStream(in: InputStream) extends DeserializationStream { val objIn = new ObjectInputStream(in) { override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) + Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) } def readObject[T](): T = objIn.readObject().asInstanceOf[T] @@ -22,36 +20,35 @@ extends DeserializationStream { } class JavaSerializerInstance extends SerializerInstance { - def serialize[T](t: T): ByteBuffer = { + def serialize[T](t: T): Array[Byte] = { val bos = new ByteArrayOutputStream() - val out = serializeStream(bos) + val out = outputStream(bos) out.writeObject(t) out.close() - ByteBuffer.wrap(bos.toByteArray) + bos.toByteArray } - def deserialize[T](bytes: ByteBuffer): T = { - val bis = new ByteArrayInputStream(bytes.array()) - val in = deserializeStream(bis) + def deserialize[T](bytes: Array[Byte]): T = { + val bis = new ByteArrayInputStream(bytes) + val in = inputStream(bis) in.readObject().asInstanceOf[T] } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes.array()) - val in = deserializeStream(bis, loader) - in.readObject().asInstanceOf[T] + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes) + val ois = new ObjectInputStream(bis) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + return ois.readObject.asInstanceOf[T] } - def serializeStream(s: OutputStream): SerializationStream = { + def outputStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s) } - 
def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, currentThread.getContextClassLoader) - } - - def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { - new JavaDeserializationStream(s, loader) + def inputStream(s: InputStream): DeserializationStream = { + new JavaDeserializationStream(s) } } diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala new file mode 100644 index 0000000000000000000000000000000000000000..b7b0361c62c34c0377737b0328fe131a35d772e7 --- /dev/null +++ b/core/src/main/scala/spark/Job.scala @@ -0,0 +1,16 @@ +package spark + +import org.apache.mesos._ +import org.apache.mesos.Protos._ + +/** + * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various + * callbacks. + */ +abstract class Job(val runId: Int, val jobId: Int) { + def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo] + + def statusUpdate(t: TaskStatus): Unit + + def error(message: String): Unit +} diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 65d0532bd58dddaea498fd4d9169eecfc4dea470..5693613d6d45804767aeeab09c8990cb43babf43 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -12,8 +12,6 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.serialize.ClassSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport -import spark.storage._ - /** * Zig-zag encoder used to write object sizes to serialization streams. * Based on Kryo's integer encoder. @@ -66,90 +64,57 @@ object ZigZag { } } -class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) +class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream) extends SerializationStream { val channel = Channels.newChannel(out) def writeObject[T](t: T) { - kryo.writeClassAndObject(threadBuffer, t) - ZigZag.writeInt(threadBuffer.position(), out) - threadBuffer.flip() - channel.write(threadBuffer) - threadBuffer.clear() + kryo.writeClassAndObject(buf, t) + ZigZag.writeInt(buf.position(), out) + buf.flip() + channel.write(buf) + buf.clear() } def flush() { out.flush() } def close() { out.close() } } -class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) +class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream) extends DeserializationStream { def readObject[T](): T = { val len = ZigZag.readInt(in) - objectBuffer.readClassAndObject(in, len).asInstanceOf[T] + buf.readClassAndObject(in, len).asInstanceOf[T] } def close() { in.close() } } class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val kryo = ks.kryo - val threadBuffer = ks.threadBuffer.get() - val objectBuffer = ks.objectBuffer.get() + val buf = ks.threadBuf.get() - def serialize[T](t: T): ByteBuffer = { - // Write it to our thread-local scratch buffer first to figure out the size, then return a new - // ByteBuffer of the appropriate size - threadBuffer.clear() - kryo.writeClassAndObject(threadBuffer, t) - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf + def serialize[T](t: T): Array[Byte] = { + buf.writeClassAndObject(t) } - def deserialize[T](bytes: ByteBuffer): T = { - kryo.readClassAndObject(bytes).asInstanceOf[T] + def deserialize[T](bytes: Array[Byte]): T = { + 
buf.readClassAndObject(bytes).asInstanceOf[T] } - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { - val oldClassLoader = kryo.getClassLoader - kryo.setClassLoader(loader) - val obj = kryo.readClassAndObject(bytes).asInstanceOf[T] - kryo.setClassLoader(oldClassLoader) + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + val oldClassLoader = ks.kryo.getClassLoader + ks.kryo.setClassLoader(loader) + val obj = buf.readClassAndObject(bytes).asInstanceOf[T] + ks.kryo.setClassLoader(oldClassLoader) obj } - def serializeStream(s: OutputStream): SerializationStream = { - threadBuffer.clear() - new KryoSerializationStream(kryo, threadBuffer, s) - } - - def deserializeStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(objectBuffer, s) + def outputStream(s: OutputStream): SerializationStream = { + new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s) } - override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - threadBuffer.clear() - while (iterator.hasNext) { - val element = iterator.next() - // TODO: Do we also want to write the object's size? Doesn't seem necessary. - kryo.writeClassAndObject(threadBuffer, element) - } - val newBuf = ByteBuffer.allocate(threadBuffer.position) - threadBuffer.flip() - newBuf.put(threadBuffer) - newBuf.flip() - newBuf - } - - override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - buffer.rewind() - new Iterator[Any] { - override def hasNext: Boolean = buffer.remaining > 0 - override def next(): Any = kryo.readClassAndObject(buffer) - } + def inputStream(s: InputStream): DeserializationStream = { + new KryoDeserializationStream(buf, s) } } @@ -161,17 +126,20 @@ trait KryoRegistrator { class KryoSerializer extends Serializer with Logging { val kryo = createKryo() - val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 + val bufferSize = + System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 - val objectBuffer = new ThreadLocal[ObjectBuffer] { + val threadBuf = new ThreadLocal[ObjectBuffer] { override def initialValue = new ObjectBuffer(kryo, bufferSize) } - val threadBuffer = new ThreadLocal[ByteBuffer] { + val threadByteBuf = new ThreadLocal[ByteBuffer] { override def initialValue = ByteBuffer.allocate(bufferSize) } def createKryo(): Kryo = { + // This is used so we can serialize/deserialize objects without a zero-arg + // constructor. 
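createKryo (continued just below) pairs KryoReflectionFactorySupport, which can instantiate classes that lack a zero-arg constructor, with up-front registration of commonly shipped shapes so their class names need not be written out per object. A compressed sketch of that setup and the ObjectBuffer round trip it feeds; the ObjectBuffer import path is assumed from the Kryo 1.x layout this file targets:

    import com.esotericsoftware.kryo.ObjectBuffer
    import de.javakaffee.kryoserializers.KryoReflectionFactorySupport

    // Illustration only: registration plus an ObjectBuffer round trip.
    object KryoRoundTrip {
      val kryo = new KryoReflectionFactorySupport()   // no zero-arg constructor required
      kryo.register(("", 1).getClass)                 // pre-register a commonly used shape
      val buf = new ObjectBuffer(kryo, 2 * 1024 * 1024)

      def roundTrip[T](value: T): T = {
        val bytes: Array[Byte] = buf.writeClassAndObject(value)
        buf.readClassAndObject(bytes).asInstanceOf[T]
      }
    }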
val kryo = new KryoReflectionFactorySupport() // Register some commonly used classes @@ -180,20 +148,14 @@ class KryoSerializer extends Serializer with Logging { Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")), Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'), // Specialized Tuple2s - ("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L), + ("", ""), (1, 1), (1.0, 1.0), (1L, 1L), (1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1), // Scala collections List(1), mutable.ArrayBuffer(1), // Options and Either Some(1), Left(1), Right(1), // Higher-dimensional tuples - (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1), - None, - ByteBuffer.allocate(1), - StorageLevel.MEMORY_ONLY_DESER, - PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER), - GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1") + (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1) ) for (obj <- toRegister) { kryo.register(obj.getClass) diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/LocalScheduler.scala similarity index 57% rename from core/src/main/scala/spark/scheduler/local/LocalScheduler.scala rename to core/src/main/scala/spark/LocalScheduler.scala index 8339c0ae9025aab942f26f97a078d31235f99613..3910c7b09e915c173c41c8d6b96bc427d2b6aea1 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/LocalScheduler.scala @@ -1,21 +1,16 @@ -package spark.scheduler.local +package spark import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger -import spark._ -import spark.scheduler._ - /** - * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally - * the scheduler also allows each task to fail up to maxFailures times, which is useful for - * testing fault recovery. + * A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the + * scheduler also allows each task to fail up to maxFailures times, which is useful for testing + * fault recovery. */ -class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging { +private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging { var attemptId = new AtomicInteger(0) var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) - val env = SparkEnv.get - var listener: TaskSchedulerListener = null // TODO: Need to take into account stage priority in scheduling @@ -23,12 +18,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with override def waitForRegister() {} - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener - } - - override def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks + override def submitTasks(tasks: Seq[Task[_]], runId: Int) { val failCount = new Array[Int](tasks.size) def submitTask(task: Task[_], idInJob: Int) { @@ -48,14 +38,23 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. 
Accumulators.clear - val bytes = Utils.serialize(task) - logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes") - val deserializedTask = Utils.deserialize[Task[_]]( - bytes, Thread.currentThread.getContextClassLoader) + val ser = SparkEnv.get.closureSerializer.newInstance() + val startTime = System.currentTimeMillis + val bytes = ser.serialize(task) + val timeTaken = System.currentTimeMillis - startTime + logInfo("Size of task %d is %d bytes and took %d ms to serialize".format( + idInJob, bytes.size, timeTaken)) + val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader) val result: Any = deserializedTask.run(attemptId) + + // Serialize and deserialize the result to emulate what the mesos + // executor does. This is useful to catch serialization errors early + // on in development (so when users move their local Spark programs + // to the cluster, they don't get surprised by serialization errors). + val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = Accumulators.values logInfo("Finished task " + idInJob) - listener.taskEnded(task, Success, result, accumUpdates) + taskEnded(task, Success, resultToReturn, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -65,7 +64,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - listener.taskEnded(task, new ExceptionFailure(t), null, null) + taskEnded(task, new ExceptionFailure(t), null, null) } } } diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 54bd57f6d3c94d2c17160f3ddaf38b1485f12e50..0d11ab9cbd836a5495f5392b942cb39ffd60e385 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -28,11 +28,9 @@ trait Logging { } // Log methods that take only a String - def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg) + def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg) def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg) - - def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg) def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg) @@ -45,9 +43,6 @@ trait Logging { def logDebug(msg: => String, throwable: Throwable) = if (log.isDebugEnabled) log.debug(msg) - def logTrace(msg: => String, throwable: Throwable) = - if (log.isTraceEnabled) log.trace(msg) - def logWarning(msg: => String, throwable: Throwable) = if (log.isWarnEnabled) log.warn(msg, throwable) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index d938a6eb629867b0a45c9a4abbe24233a5947b5b..a934c5a02fe30706ddb9d6ce7194743c91c40ca1 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -2,80 +2,80 @@ package spark import java.util.concurrent.ConcurrentHashMap -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ - +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ import scala.collection.mutable.HashSet -import spark.storage.BlockManagerId - sealed trait MapOutputTrackerMessage case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage case object StopMapOutputTracker extends MapOutputTrackerMessage -class 
MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) -extends Actor with Logging { - def receive = { - case GetMapOutputLocations(shuffleId: Int) => - logInfo("Asked to get map output locations for shuffle " + shuffleId) - self.reply(bmAddresses.get(shuffleId)) - - case StopMapOutputTracker => - logInfo("MapOutputTrackerActor stopped!") - self.reply(true) - self.exit() +class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]]) +extends DaemonActor with Logging { + def act() { + val port = System.getProperty("spark.master.port").toInt + RemoteActor.alive(port) + RemoteActor.register('MapOutputTracker, self) + logInfo("Registered actor on port " + port) + + loop { + react { + case GetMapOutputLocations(shuffleId: Int) => + logInfo("Asked to get map output locations for shuffle " + shuffleId) + reply(serverUris.get(shuffleId)) + + case StopMapOutputTracker => + reply('OK) + exit() + } + } } } class MapOutputTracker(isMaster: Boolean) extends Logging { - val ip: String = System.getProperty("spark.master.host", "localhost") - val port: Int = System.getProperty("spark.master.port", "7077").toInt - val aName: String = "MapOutputTracker" + var trackerActor: AbstractActor = null - private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + private var serverUris = new ConcurrentHashMap[Int, Array[String]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. private var generation: Long = 0 private var generationLock = new java.lang.Object - - var trackerActor: ActorRef = if (isMaster) { - val actor = actorOf(new MapOutputTrackerActor(bmAddresses)) - remote.register(aName, actor) - logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port) - actor + + if (isMaster) { + val tracker = new MapOutputTrackerActor(serverUris) + tracker.start() + trackerActor = tracker } else { - remote.actorFor(aName, ip, port) + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker) } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (bmAddresses.get(shuffleId) != null) { + if (serverUris.get(shuffleId) != null) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps)) + serverUris.put(shuffleId, new Array[String](numMaps)) } - def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = bmAddresses.get(shuffleId) + def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) { + var array = serverUris.get(shuffleId) array.synchronized { - array(mapId) = bmAddress + array(mapId) = serverUri } } - def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) { - bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs) - if (changeGeneration) { - incrementGeneration() - } + def registerMapOutputs(shuffleId: Int, locs: Array[String]) { + serverUris.put(shuffleId, Array[String]() ++ locs) } - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { - var array = bmAddresses.get(shuffleId) + def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) { + var array = serverUris.get(shuffleId) if (array != null) { array.synchronized { - if (array(mapId) == bmAddress) { + if (array(mapId) == serverUri) { array(mapId) = null } } @@ 
-89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { val fetching = new HashSet[Int] // Called on possibly remote nodes to get the server URIs for a given shuffle - def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { - val locs = bmAddresses.get(shuffleId) + def getServerUris(shuffleId: Int): Array[String] = { + val locs = serverUris.get(shuffleId) if (locs == null) { - logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") + logInfo("Don't have map outputs for " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for them to be done @@ -103,17 +103,15 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { case _ => } } - return bmAddresses.get(shuffleId) + return serverUris.get(shuffleId) } else { fetching += shuffleId } } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get - - logInfo("Got the output locations") - bmAddresses.put(shuffleId, fetched) + val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]] + serverUris.put(shuffleId, fetched) fetching.synchronized { fetching -= shuffleId fetching.notifyAll() @@ -123,10 +121,14 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { return locs } } + + def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = { + "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId) + } def stop() { - trackerActor !! StopMapOutputTracker - bmAddresses.clear() + trackerActor !? StopMapOutputTracker + serverUris.clear() trackerActor = null } @@ -151,7 +153,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] + serverUris = new ConcurrentHashMap[Int, Array[String]] generation = newGen } } diff --git a/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala b/core/src/main/scala/spark/MesosScheduler.scala similarity index 58% rename from core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala rename to core/src/main/scala/spark/MesosScheduler.scala index f72618c03fc8a1b996f32c86678b19de6ecf31cd..a7711e0d352f04c004aa3030413f1593f4a76849 100644 --- a/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala +++ b/core/src/main/scala/spark/MesosScheduler.scala @@ -1,4 +1,4 @@ -package spark.scheduler.mesos +package spark import java.io.{File, FileInputStream, FileOutputStream} import java.util.{ArrayList => JArrayList} @@ -17,23 +17,20 @@ import com.google.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} - -import spark._ -import spark.scheduler._ +import org.apache.mesos.Protos._ /** - * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call - * start(), then submit task sets through the runTasks method. + * The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(), + * then submit tasks through the runTasks method. 
*/ -class MesosScheduler( +private class MesosScheduler( sc: SparkContext, master: String, frameworkName: String) - extends TaskScheduler - with MScheduler + extends MScheduler + with DAGScheduler with Logging { - + // Environment variables to pass to our executors val ENV_VARS_TO_SEND_TO_EXECUTORS = Array( "SPARK_MEM", @@ -52,60 +49,55 @@ class MesosScheduler( } } - // How often to check for speculative tasks - val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong - // Lock used to wait for scheduler to be registered - var isRegistered = false - val registeredLock = new Object() + private var isRegistered = false + private val registeredLock = new Object() - val activeTaskSets = new HashMap[String, TaskSetManager] - var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + private val activeJobs = new HashMap[Int, Job] + private var activeJobsQueue = new ArrayBuffer[Job] - val taskIdToTaskSetId = new HashMap[String, String] - val taskIdToSlaveId = new HashMap[String, String] - val taskSetTaskIds = new HashMap[String, HashSet[String]] + private val taskIdToJobId = new HashMap[String, Int] + private val taskIdToSlaveId = new HashMap[String, String] + private val jobTasks = new HashMap[Int, HashSet[String]] - // Incrementing Mesos task IDs - var nextTaskId = 0 + // Incrementing job and task IDs + private var nextJobId = 0 + private var nextTaskId = 0 // Driver for talking to Mesos var driver: SchedulerDriver = null - // Which hosts in the cluster are alive (contains hostnames) - val hostsAlive = new HashSet[String] - - // Which slave IDs we have executors on - val slaveIdsWithExecutors = new HashSet[String] - - val slaveIdToHost = new HashMap[String, String] + // Which nodes we have executors on + private val slavesWithExecutors = new HashSet[String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null // URIs of JARs to pass to executor var jarUris: String = "" - + // Create an ExecutorInfo for our tasks val executorInfo = createExecutorInfo() - // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null - - val mapOutputTracker = SparkEnv.get.mapOutputTracker - - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first) + private val jobOrdering = new Ordering[Job] { + override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId + } + + def newJobId(): Int = this.synchronized { + val id = nextJobId + nextJobId += 1 + return id } def newTaskId(): TaskID = { - val id = TaskID.newBuilder().setValue("" + nextTaskId).build() - nextTaskId += 1 - return id + val id = "" + nextTaskId; + nextTaskId += 1; + return TaskID.newBuilder().setValue(id).build() } override def start() { - new Thread("MesosScheduler driver") { + new Thread("Spark scheduler") { setDaemon(true) override def run { val sched = MesosScheduler.this @@ -118,27 +110,12 @@ class MesosScheduler( case e: Exception => logError("driver.run() failed", e) } } - }.start() - if (System.getProperty("spark.speculation", "false") == "true") { - new Thread("MesosScheduler speculation check") { - setDaemon(true) - override def run { - waitForRegister() - while (true) { - try { - Thread.sleep(SPECULATION_INTERVAL) - } catch { case e: InterruptedException => {} } - checkSpeculatableTasks() - } - } - }.start() - } + }.start } def createExecutorInfo(): ExecutorInfo = { val sparkHome = sc.getSparkHome match { - 
case Some(path) => - path + case Some(path) => path case None => throw new SparkException("Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor") @@ -174,26 +151,27 @@ class MesosScheduler( .build() } - def submitTasks(taskSet: TaskSet) { - val tasks = taskSet.tasks - logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks") + def submitTasks(tasks: Seq[Task[_]], runId: Int) { + logInfo("Got a job with " + tasks.size + " tasks") waitForRegister() this.synchronized { - val manager = new TaskSetManager(this, taskSet) - activeTaskSets(taskSet.id) = manager - activeTaskSetsQueue += manager - taskSetTaskIds(taskSet.id) = new HashSet() + val jobId = newJobId() + val myJob = new SimpleJob(this, tasks, runId, jobId) + activeJobs(jobId) = myJob + activeJobsQueue += myJob + logInfo("Adding job with ID " + jobId) + jobTasks(jobId) = HashSet.empty[String] } - reviveOffers(); + driver.reviveOffers(); } - def taskSetFinished(manager: TaskSetManager) { + def jobFinished(job: Job) { this.synchronized { - activeTaskSets -= manager.taskSet.id - activeTaskSetsQueue -= manager - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) + activeJobs -= job.jobId + activeJobsQueue -= job + taskIdToJobId --= jobTasks(job.jobId) + taskIdToSlaveId --= jobTasks(job.jobId) + jobTasks.remove(job.jobId) } } @@ -218,40 +196,33 @@ class MesosScheduler( override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} /** - * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets - * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that - * tasks are balanced across the cluster. + * Method called by Mesos to offer resources on slaves. We resond by asking our active jobs for + * tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are + * balanced across the cluster. 
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { - // Mark each slave as alive and remember its hostname - for (o <- offers) { - slaveIdToHost(o.getSlaveId.getValue) = o.getHostname - hostsAlive += o.getHostname - } - // Build a list of tasks to assign to each slave - val tasks = offers.map(o => new JArrayList[MTaskInfo]) + val tasks = offers.map(o => new JArrayList[TaskInfo]) val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus")) val enoughMem = offers.map(o => { val mem = getResource(o.getResourcesList(), "mem") val slaveId = o.getSlaveId.getValue - mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId) + mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId) }) var launchedTask = false - for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + for (job <- activeJobsQueue.sorted(jobOrdering)) { do { launchedTask = false for (i <- 0 until offers.size if enoughMem(i)) { - val sid = offers(i).getSlaveId.getValue - val host = offers(i).getHostname - manager.slaveOffer(sid, host, availableCpus(i)) match { + job.slaveOffer(offers(i), availableCpus(i)) match { case Some(task) => tasks(i).add(task) val tid = task.getTaskId.getValue - taskIdToTaskSetId(tid) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += tid + val sid = offers(i).getSlaveId.getValue + taskIdToJobId(tid) = job.jobId + jobTasks(job.jobId) += tid taskIdToSlaveId(tid) = sid - slaveIdsWithExecutors += sid + slavesWithExecutors += sid availableCpus(i) -= getResource(task.getResourcesList(), "cpus") launchedTask = true @@ -285,74 +256,53 @@ class MesosScheduler( } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val tid = status.getTaskId.getValue - var taskSetToUpdate: Option[TaskSetManager] = None - var failedHost: Option[String] = None - var taskFailed = false + var jobToUpdate: Option[Job] = None synchronized { try { - if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { + val tid = status.getTaskId.getValue + if (status.getState == TaskState.TASK_LOST + && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - val slaveId = taskIdToSlaveId(tid) - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } + slavesWithExecutors -= taskIdToSlaveId(tid) } - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (activeTaskSets.contains(taskSetId)) { - //activeTaskSets(taskSetId).statusUpdate(status) - taskSetToUpdate = Some(activeTaskSets(taskSetId)) + taskIdToJobId.get(tid) match { + case Some(jobId) => + if (activeJobs.contains(jobId)) { + jobToUpdate = Some(activeJobs(jobId)) } if (isFinished(status.getState)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid + taskIdToJobId.remove(tid) + if (jobTasks.contains(jobId)) { + jobTasks(jobId) -= tid } taskIdToSlaveId.remove(tid) } - if (status.getState == TaskState.TASK_FAILED) { - taskFailed = true - } case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") + logInfo("Ignoring update from TID " + tid + " because its job is gone") } } catch { case e: Exception => logError("Exception in statusUpdate", e) } } - // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock - if 
(taskSetToUpdate != None) { - taskSetToUpdate.get.statusUpdate(status) - } - if (failedHost != None) { - listener.hostLost(failedHost.get) - reviveOffers(); - } - if (taskFailed) { - // Also revive offers if a task had failed for some reason other than host lost - reviveOffers() + for (j <- jobToUpdate) { + j.statusUpdate(status) } } override def error(d: SchedulerDriver, message: String) { logError("Mesos error: " + message) synchronized { - if (activeTaskSets.size > 0) { - // Have each task set throw a SparkException with the error - for ((taskSetId, manager) <- activeTaskSets) { + if (activeJobs.size > 0) { + // Have each job throw a SparkException with the error + for ((jobId, activeJob) <- activeJobs) { try { - manager.error(message) + activeJob.error(message) } catch { case e: Exception => logError("Exception in error callback", e) } } } else { - // No task sets are active but we still got an error. Just exit since this + // No jobs are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. System.exit(1) @@ -423,68 +373,41 @@ class MesosScheduler( return Utils.serialize(props.toArray) } - override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} + override def frameworkMessage( + d: SchedulerDriver, + e: ExecutorID, + s: SlaveID, + b: Array[Byte]) {} override def slaveLost(d: SchedulerDriver, s: SlaveID) { - var failedHost: Option[String] = None - synchronized { - val slaveId = s.getValue - val host = slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - activeTaskSetsQueue.foreach(_.hostLost(host)) - failedHost = Some(host) - } - } - if (failedHost != None) { - listener.hostLost(failedHost.get) - reviveOffers(); - } + slavesWithExecutors.remove(s.getValue) } override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { - logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) - slaveLost(d, s) + slavesWithExecutors.remove(s.getValue) } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} - - // Check for speculatable tasks in all our active jobs. - def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - for (ts <- activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } - } - if (shouldRevive) { - reviveOffers() - } - } - - def reviveOffers() { - driver.reviveOffers() - } } object MesosScheduler { /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM * environment variable. 
*/ def memoryStringToMb(str: String): Int = { val lower = str.toLowerCase if (lower.endsWith("k")) { - (lower.substring(0, lower.length-1).toLong / 1024).toInt + (lower.substring(0, lower.length - 1).toLong / 1024).toInt } else if (lower.endsWith("m")) { - lower.substring(0, lower.length-1).toInt + lower.substring(0, lower.length - 1).toInt } else if (lower.endsWith("g")) { - lower.substring(0, lower.length-1).toInt * 1024 + lower.substring(0, lower.length - 1).toInt * 1024 } else if (lower.endsWith("t")) { - lower.substring(0, lower.length-1).toInt * 1024 * 1024 - } else {// no suffix, so it's just a number in bytes + lower.substring(0, lower.length - 1).toInt * 1024 * 1024 + } else { + // no suffix, so it's just a number in bytes (lower.toLong / 1024 / 1024).toInt } } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 270447712b2870a05cefb8dc41d109b5c539095c..e880f9872f23c0ac11161c120761853ba6f3160f 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -4,14 +4,14 @@ import java.io.EOFException import java.net.URL import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong -import java.util.{HashMap => JHashMap} +import java.util.HashSet +import java.util.Random import java.util.Date import java.text.SimpleDateFormat -import scala.collection.Map import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.Map import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable @@ -34,9 +34,7 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} import org.apache.hadoop.mapreduce.TaskAttemptID import org.apache.hadoop.mapreduce.TaskAttemptContext -import spark.SparkContext._ -import spark.partial.BoundedDouble -import spark.partial.PartialResult +import SparkContext._ /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. 
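// Illustrative sketch only (not part of this change): the hunk below adds reduceByKeyToDriver,
// which maps each pair to a single-entry HashMap and merges the maps with reduce(), returning
// a local mutable Map on the driver. Hypothetical usage, assuming a SparkContext `sc` and the
// pair-RDD implicits from SparkContext._ in scope:
def localWordCount(sc: SparkContext): scala.collection.mutable.Map[String, Int] =
  sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)), 2).reduceByKeyToDriver(_ + _)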
@@ -45,6 +43,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self: RDD[(K, V)]) extends Logging with Serializable { + + def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = { + def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = { + for ((k, v) <- m2) { + m1.get(k) match { + case None => m1(k) = v + case Some(w) => m1(k) = func(w, v) + } + } + return m1 + } + self.map(pair => HashMap(pair)).reduce(mergeMaps) + } def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, @@ -64,39 +75,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = { combineByKey[V]((v: V) => v, func, func, partitioner) } - - def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { - def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { - val map = new JHashMap[K, V] - for ((k, v) <- iter) { - val old = map.get(k) - map.put(k, if (old == null) v else func(old, v)) - } - Iterator(map) - } - - def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { - for ((k, v) <- m2) { - val old = m1.get(k) - m1.put(k, if (old == null) v else func(old, v)) - } - return m1 - } - - self.mapPartitions(reducePartition).reduce(mergeMaps) - } - - // Alias for backwards compatibility - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - - // TODO: This should probably be a distributed version - def countByKey(): Map[K, Long] = self.map(_._1).countByValue() - - // TODO: This should probably be a distributed version - def countByKeyApprox(timeout: Long, confidence: Double = 0.95) - : PartialResult[Map[K, BoundedDouble]] = { - self.map(_._1).countByValueApprox(timeout, confidence) - } def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = { reduceByKey(new HashPartitioner(numSplits), func) diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala new file mode 100644 index 0000000000000000000000000000000000000000..19eb288e8460e599b501091f75b178cea388501a --- /dev/null +++ b/core/src/main/scala/spark/ParallelShuffleFetcher.scala @@ -0,0 +1,119 @@ +package spark + +import java.io.ByteArrayInputStream +import java.io.EOFException +import java.net.URL +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + + +class ParallelShuffleFetcher extends ShuffleFetcher with Logging { + val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt + + def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { + logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) + + // Figure out a list of input IDs (mapper IDs) for each server + val ser = SparkEnv.get.serializer.newInstance() + val inputsByUri = new HashMap[String, ArrayBuffer[Int]] + val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId) + for ((serverUri, index) <- serverUris.zipWithIndex) { + inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index + } + + // Randomize them and put them in a LinkedBlockingQueue + val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])] + for (pair <- 
Utils.randomize(inputsByUri)) { + serverQueue.put(pair) + } + + // Create a queue to hold the fetched data + val resultQueue = new LinkedBlockingQueue[Array[Byte]] + + // Atomic variables to communicate failures and # of fetches done + var failure = new AtomicReference[FetchFailedException](null) + + // Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously) + for (i <- 0 until parallelFetches) { + new Thread("Fetch thread " + i + " for reduce " + reduceId) { + override def run() { + while (true) { + val pair = serverQueue.poll() + if (pair == null) + return + val (serverUri, inputIds) = pair + //logInfo("Pulled out server URI " + serverUri) + for (i <- inputIds) { + if (failure.get != null) + return + val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId) + logInfo("Starting HTTP request for " + url) + try { + val conn = new URL(url).openConnection() + conn.connect() + val len = conn.getContentLength() + if (len == -1) { + throw new SparkException("Content length was not specified by server") + } + val buf = new Array[Byte](len) + val in = new FastBufferedInputStream(conn.getInputStream()) + var pos = 0 + while (pos < len) { + val n = in.read(buf, pos, len-pos) + if (n == -1) { + throw new SparkException("EOF before reading the expected " + len + " bytes") + } else { + pos += n + } + } + // Done reading everything + resultQueue.put(buf) + in.close() + } catch { + case e: Exception => + logError("Fetch failed from " + url, e) + failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e)) + return + } + } + //logInfo("Done with server URI " + serverUri) + } + } + }.start() + } + + // Wait for results from the threads (either a failure or all servers done) + var resultsDone = 0 + var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum + while (failure.get == null && resultsDone < totalResults) { + try { + val result = resultQueue.poll(100, TimeUnit.MILLISECONDS) + if (result != null) { + //logInfo("Pulled out a result") + val in = ser.inputStream(new ByteArrayInputStream(result)) + try { + while (true) { + val pair = in.readObject().asInstanceOf[(K, V)] + func(pair._1, pair._2) + } + } catch { + case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel + } + resultsDone += 1 + //logInfo("Results done = " + resultsDone) + } + } catch { case e: InterruptedException => {} } + } + if (failure.get != null) { + throw failure.get + } + } +} diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index 0e45ebd35cfb3954871216d901823c1f2133ace0..024a4580acce5f4e10ad29c935c47fedaf12e0bb 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -71,3 +71,4 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( false } } + diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala index 9e0a01b5f9fb0357ca5eb0f599ccc2e567aef83b..8a5de3d7e96055ca839b476e661b0a9ed10035ad 100644 --- a/core/src/main/scala/spark/PipedRDD.scala +++ b/core/src/main/scala/spark/PipedRDD.scala @@ -3,7 +3,6 @@ package spark import java.io.PrintWriter import java.util.StringTokenizer -import scala.collection.Map import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 
1191523ccc752bdfffaa4063a252635d9e773f54..4c4b2ee30d604b963ebd17a6dccdf9a6dc70a915 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -4,14 +4,11 @@ import java.io.EOFException import java.net.URL import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong +import java.util.HashSet import java.util.Random import java.util.Date -import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions.mapAsScalaMap import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable @@ -25,14 +22,6 @@ import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.TextOutputFormat -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -import spark.partial.BoundedDouble -import spark.partial.CountEvaluator -import spark.partial.GroupedCountEvaluator -import spark.partial.PartialResult -import spark.storage.StorageLevel - import SparkContext._ /** @@ -72,32 +61,19 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Get a unique ID for this RDD val id = sc.newRddId() - // Variables relating to persistence - private var storageLevel: StorageLevel = StorageLevel.NONE + // Variables relating to caching + private var shouldCache = false - // Change this RDD's storage level - def persist(newLevel: StorageLevel): RDD[T] = { - // TODO: Handle changes of StorageLevel - if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { - throw new UnsupportedOperationException( - "Cannot change storage level of an RDD after it was already assigned a level") - } - storageLevel = newLevel + // Change this RDD's caching + def cache(): RDD[T] = { + shouldCache = true this } - - // Turn on the default caching level for this RDD - def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER) - - // Turn on the default caching level for this RDD - def cache(): RDD[T] = persist() - - def getStorageLevel = storageLevel // Read this RDD; will read from cache if applicable, or otherwise compute final def iterator(split: Split): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) + if (shouldCache) { + SparkEnv.get.cacheTracker.getOrCompute[T](this, split) } else { compute(split) } @@ -186,8 +162,6 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial Array.concat(results: _*) } - def toArray(): Array[T] = collect() - def reduce(f: (T, T) => T): T = { val cleanF = sc.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { @@ -248,67 +222,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial }).sum } - /** - * Approximate version of count() that returns a potentially incomplete result after a timeout. - */ - def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { - val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next - } - result - } - val evaluator = new CountEvaluator(splits.size, confidence) - sc.runApproximateJob(this, countElements, evaluator, timeout) - } - - /** - * Count elements equal to each value, returning a map of (value, count) pairs. 
The final combine - * step happens locally on the master, equivalent to running a single reduce task. - * - * TODO: This should perhaps be distributed by default. - */ - def countByValue(): Map[T, Long] = { - def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - Iterator(map) - } - def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { - val iter = m2.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) - } - return m1 - } - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map - } - - /** - * Approximate version of countByValue(). - */ - def countByValueApprox( - timeout: Long, - confidence: Double = 0.95 - ): PartialResult[Map[T, BoundedDouble]] = { - val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) - } - map - } - val evaluator = new GroupedCountEvaluator[T](splits.size, confidence) - sc.runApproximateJob(this, countPartition, evaluator, timeout) - } + def toArray(): Array[T] = collect() /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/ResultTask.scala similarity index 71% rename from core/src/main/scala/spark/scheduler/ResultTask.scala rename to core/src/main/scala/spark/ResultTask.scala index d2fab55b5e8a1aa3af9d0ea4f1f9607449dc5b2a..3952bf85b2cdb89f83aaed4bbca8c73086e08f5d 100644 --- a/core/src/main/scala/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/spark/ResultTask.scala @@ -1,15 +1,14 @@ -package spark.scheduler - -import spark._ +package spark class ResultTask[T, U]( - stageId: Int, - rdd: RDD[T], + runId: Int, + stageId: Int, + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, - val partition: Int, - @transient locs: Seq[String], + val partition: Int, + locs: Seq[String], val outputId: Int) - extends Task[U](stageId) { + extends DAGTask[U](runId, stageId) { val split = rdd.splits(partition) diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..6c7e569313b9f6a325b39c1606700715b90c56d9 --- /dev/null +++ b/core/src/main/scala/spark/Scheduler.scala @@ -0,0 +1,27 @@ +package spark + +/** + * Scheduler trait, implemented by both MesosScheduler and LocalScheduler. + */ +private trait Scheduler { + def start() + + // Wait for registration with Mesos. + def waitForRegister() + + /** + * Run a function on some partitions of an RDD, returning an array of results. The allowLocal + * flag specifies whether the scheduler is allowed to run the job on the master machine rather + * than shipping it to the cluster, for actions that create short jobs such as first() and take(). + */ + def runJob[T, U: ClassManifest]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean): Array[U] + + def stop() + + // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. 
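// Illustrative sketch only (not part of this change): callers that are not handed an explicit
// split count can fall back on the hint declared just below, for example:
def splitsOrDefault(requested: Option[Int], scheduler: Scheduler): Int =
  requested.getOrElse(scheduler.defaultParallelism())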
+ def defaultParallelism(): Int +} diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index 9da73c4b028c8f70a085f5ec22a5891516c575d7..b213ca9dcbde6c70ad6ef03ca4c2150a84a1390f 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla } // TODO: use something like WritableConverter to avoid reflection } - c.asInstanceOf[Class[_ <: Writable]] + c.asInstanceOf[Class[ _ <: Writable]] } def saveAsSequenceFile(path: String) { diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala index 61a70beaf1fd73566443f8cf7e05c2317eceafd4..2429bbfeb927445e887359465d54a8c8aafcade8 100644 --- a/core/src/main/scala/spark/Serializer.scala +++ b/core/src/main/scala/spark/Serializer.scala @@ -1,12 +1,6 @@ package spark -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer -import java.nio.channels.Channels - -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import spark.util.ByteBufferInputStream +import java.io.{InputStream, OutputStream} /** * A serializer. Because some serialization libraries are not thread safe, this class is used to @@ -20,31 +14,11 @@ trait Serializer { * An instance of the serializer, for use by one thread at a time. */ trait SerializerInstance { - def serialize[T](t: T): ByteBuffer - - def deserialize[T](bytes: ByteBuffer): T - - def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T - - def serializeStream(s: OutputStream): SerializationStream - - def deserializeStream(s: InputStream): DeserializationStream - - def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { - // Default implementation uses serializeStream - val stream = new FastByteArrayOutputStream() - serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.allocate(stream.position.toInt) - buffer.put(stream.array, 0, stream.position.toInt) - buffer.flip() - buffer - } - - def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { - // Default implementation uses deserializeStream - buffer.rewind() - deserializeStream(new ByteBufferInputStream(buffer)).toIterator - } + def serialize[T](t: T): Array[Byte] + def deserialize[T](bytes: Array[Byte]): T + def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T + def outputStream(s: OutputStream): SerializationStream + def inputStream(s: InputStream): DeserializationStream } /** @@ -54,13 +28,6 @@ trait SerializationStream { def writeObject[T](t: T): Unit def flush(): Unit def close(): Unit - - def writeAll[T](iter: Iterator[T]): SerializationStream = { - while (iter.hasNext) { - writeObject(iter.next()) - } - this - } } /** @@ -69,45 +36,4 @@ trait SerializationStream { trait DeserializationStream { def readObject[T](): T def close(): Unit - - /** - * Read the elements of this stream through an iterator. This can only be called once, as - * reading each element will consume data from the input source. 
- */ - def toIterator: Iterator[Any] = new Iterator[Any] { - var gotNext = false - var finished = false - var nextValue: Any = null - - private def getNext() { - try { - nextValue = readObject[Any]() - } catch { - case eof: EOFException => - finished = true - } - gotNext = true - } - - override def hasNext: Boolean = { - if (!gotNext) { - getNext() - } - if (finished) { - close() - } - !finished - } - - override def next(): Any = { - if (!gotNext) { - getNext() - } - if (finished) { - throw new NoSuchElementException("End of stream") - } - gotNext = false - nextValue - } - } } diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala new file mode 100644 index 0000000000000000000000000000000000000000..3d192f24034a0f5a59a7247bf2850ba29efbbc80 --- /dev/null +++ b/core/src/main/scala/spark/SerializingCache.scala @@ -0,0 +1,26 @@ +package spark + +import java.io._ + +/** + * Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to + * reduce storage cost and GC overhead + */ +class SerializingCache extends Cache with Logging { + val bmc = new BoundedMemoryCache + + override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { + val ser = SparkEnv.get.serializer.newInstance() + bmc.put(datasetId, partition, ser.serialize(value)) + } + + override def get(datasetId: Any, partition: Int): Any = { + val bytes = bmc.get(datasetId, partition) + if (bytes != null) { + val ser = SparkEnv.get.serializer.newInstance() + return ser.deserialize(bytes.asInstanceOf[Array[Byte]]) + } else { + return null + } + } +} diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala new file mode 100644 index 0000000000000000000000000000000000000000..5fc59af06c039f6d74638c63cea13ad824058e40 --- /dev/null +++ b/core/src/main/scala/spark/ShuffleMapTask.scala @@ -0,0 +1,56 @@ +package spark + +import java.io.BufferedOutputStream +import java.io.FileOutputStream +import java.io.ObjectOutputStream +import java.util.{HashMap => JHashMap} + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +class ShuffleMapTask( + runId: Int, + stageId: Int, + rdd: RDD[_], + dep: ShuffleDependency[_,_,_], + val partition: Int, + locs: Seq[String]) + extends DAGTask[String](runId, stageId) + with Logging { + + val split = rdd.splits(partition) + + override def run (attemptId: Int): String = { + val numOutputSplits = dep.partitioner.numPartitions + val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]] + val partitioner = dep.partitioner.asInstanceOf[Partitioner] + val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any]) + for (elem <- rdd.iterator(split)) { + val (k, v) = elem.asInstanceOf[(Any, Any)] + var bucketId = partitioner.getPartition(k) + val bucket = buckets(bucketId) + var existing = bucket.get(k) + if (existing == null) { + bucket.put(k, aggregator.createCombiner(v)) + } else { + bucket.put(k, aggregator.mergeValue(existing, v)) + } + } + val ser = SparkEnv.get.serializer.newInstance() + for (i <- 0 until numOutputSplits) { + val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i) + val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file))) + val iter = buckets(i).entrySet().iterator() + while (iter.hasNext()) { + val entry = iter.next() + out.writeObject((entry.getKey, entry.getValue)) + } + // TODO: have some kind of EOF marker + out.close() + } + return 
SparkEnv.get.shuffleManager.getServerUri + } + + override def preferredLocations: Seq[String] = locs + + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) +} diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala index 5434197ecad3330fb000b6c5a3238453e16a3b19..5efc8cf50b8ef27154c59a2bf00bd7a3d2220114 100644 --- a/core/src/main/scala/spark/ShuffledRDD.scala +++ b/core/src/main/scala/spark/ShuffledRDD.scala @@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split { } class ShuffledRDD[K, V, C]( - @transient parent: RDD[(K, V)], + parent: RDD[(K, V)], aggregator: Aggregator[K, V, C], part : Partitioner) extends RDD[(K, C)](parent.context) { diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala b/core/src/main/scala/spark/SimpleJob.scala similarity index 50% rename from core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala rename to core/src/main/scala/spark/SimpleJob.scala index 535c17d9d4db78f29acca2b7e458159664a28391..01c7efff1e0af2bed9c6085b0958847968441c37 100644 --- a/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala +++ b/core/src/main/scala/spark/SimpleJob.scala @@ -1,32 +1,28 @@ -package spark.scheduler.mesos +package spark -import java.util.Arrays import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min import com.google.protobuf.ByteString import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} - -import spark._ -import spark.scheduler._ +import org.apache.mesos.Protos._ /** - * Schedules the tasks within a single TaskSet in the MesosScheduler. + * A Job that runs a set of tasks with no interdependencies. */ -class TaskSetManager( +class SimpleJob( sched: MesosScheduler, - val taskSet: TaskSet) - extends Logging { + tasksSeq: Seq[Task[_]], + runId: Int, + jobId: Int) + extends Job(runId, jobId) + with Logging { // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong // CPUs to request per task val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble @@ -34,20 +30,18 @@ class TaskSetManager( // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 - // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble - // Serializer for closures and tasks. 
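// Illustrative sketch only (not part of this change): serializers are not assumed to be thread
// safe, so the job keeps its own instance, and under this patch serialize() returns a plain
// Array[Byte]. Assuming Serializer.newInstance() as it is used elsewhere in this patch:
def taskPayload(serializer: Serializer, task: Task[_]): Array[Byte] =
  serializer.newInstance().serialize(task)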
val ser = SparkEnv.get.closureSerializer.newInstance() - val priority = taskSet.priority - val tasks = taskSet.tasks + val callingThread = Thread.currentThread + val tasks = tasksSeq.toArray val numTasks = tasks.length - val copiesRunning = new Array[Int](numTasks) + val launched = new Array[Boolean](numTasks) val finished = new Array[Boolean](numTasks) val numFailures = new Array[Int](numTasks) - val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) + val tidToIndex = HashMap[String, Int]() + + var tasksLaunched = 0 var tasksFinished = 0 // Last time when we launched a preferred task (for delay scheduling) @@ -68,13 +62,6 @@ class TaskSetManager( // List containing all pending tasks (also used as a stack, as above) val allPendingTasks = new ArrayBuffer[Int] - // Tasks that can be specualted. Since these will be a small fraction of total - // tasks, we'll just hold them in a HaskSet. - val speculatableTasks = new HashSet[Int] - - // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[String, TaskInfo] - // Did the job fail? var failed = false var causeOfFailure = "" @@ -89,12 +76,6 @@ class TaskSetManager( // exceptions automatically. val recentExceptions = HashMap[String, (Int, Long)]() - // Figure out the current map output tracker generation and set it on all tasks - val generation = sched.mapOutputTracker.getGeneration - for (t <- tasks) { - t.generation = generation - } - // Add all our tasks to the pending lists. We do this in reverse order // of task index so that tasks with low indices get launched first. for (i <- (0 until numTasks).reverse) { @@ -103,7 +84,7 @@ class TaskSetManager( // Add a task to all the pending-task lists that it should be on. def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + val locations = tasks(index).preferredLocations if (locations.size == 0) { pendingTasksWithNoPrefs += index } else { @@ -129,37 +110,13 @@ class TaskSetManager( while (!list.isEmpty) { val index = list.last list.trimEnd(1) - if (copiesRunning(index) == 0 && !finished(index)) { + if (!launched(index) && !finished(index)) { return Some(index) } } return None } - // Return a speculative task for a given host if any are available. The task should not have an - // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the - // task must have a preference for this host (or no preferred locations at all). - def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { - speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set - val localTask = speculatableTasks.find { index => - val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive - val attemptLocs = taskAttempts(index).map(_.host) - (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) - } - if (localTask != None) { - speculatableTasks -= localTask.get - return localTask - } - if (!localOnly && speculatableTasks.size > 0) { - val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host)) - if (nonLocalTask != None) { - speculatableTasks -= nonLocalTask.get - return nonLocalTask - } - } - return None - } - // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. 
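// Illustrative sketch only (not part of this change): findTask below is driven by a simple
// delay-scheduling rule; slaveOffer only allows local (or no-preference) tasks until
// LOCALITY_WAIT ms have passed since the last preferred launch, roughly:
def restrictToLocal(now: Long, lastPreferredLaunchTime: Long, localityWaitMs: Long): Boolean =
  now - lastPreferredLaunchTime < localityWaitMs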
def findTask(host: String, localOnly: Boolean): Option[Int] = { @@ -172,13 +129,10 @@ class TaskSetManager( return noPrefTask } if (!localOnly) { - val nonLocalTask = findTaskFromList(allPendingTasks) - if (nonLocalTask != None) { - return nonLocalTask - } + return findTaskFromList(allPendingTasks) // Look for non-local task + } else { + return None } - // Finally, if all else has failed, find a speculative task - return findSpeculativeTask(host, localOnly) } // Does a host count as a preferred location for a task? This is true if @@ -190,11 +144,11 @@ class TaskSetManager( } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = { - if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { + def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = { + if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis - var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) - + val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) + val host = offer.getHostname findTask(host, localOnly) match { case Some(index) => { // Found a task; do some bookkeeping and return a Mesos task for it @@ -202,17 +156,17 @@ class TaskSetManager( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if (preferred) "preferred" else "non-preferred" - logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( - taskSet.id, index, taskId.getValue, slaveId, host, prefStr)) + val prefStr = if(preferred) "preferred" else "non-preferred" + val message = + "Starting task %d:%d as TID %s on slave %s: %s (%s)".format( + jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr) + logInfo(message) // Do various bookkeeping - copiesRunning(index) += 1 - val info = new TaskInfo(taskId.getValue, index, time, host) - taskInfos(taskId.getValue) = info - taskAttempts(index) = info :: taskAttempts(index) - if (preferred) { + tidToIndex(taskId.getValue) = index + launched(index) = true + tasksLaunched += 1 + if (preferred) lastPreferredLaunchTime = time - } // Create and return the Mesos task object val cpuRes = Resource.newBuilder() .setName("cpus") @@ -224,13 +178,13 @@ class TaskSetManager( val serializedTask = ser.serialize(task) val timeTaken = System.currentTimeMillis - startTime - logInfo("Serialized task %s:%d as %d bytes in %d ms".format( - taskSet.id, index, serializedTask.limit, timeTaken)) + logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s" + .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName)) - val taskName = "task %s:%d".format(taskSet.id, index) - return Some(MTaskInfo.newBuilder() + val taskName = "task %d:%d".format(jobId, index) + return Some(TaskInfo.newBuilder() .setTaskId(taskId) - .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) + .setSlaveId(offer.getSlaveId) .setExecutor(sched.executorInfo) .setName(taskName) .addResources(cpuRes) @@ -259,21 +213,18 @@ class TaskSetManager( def taskFinished(status: TaskStatus) { val tid = status.getTaskId.getValue - val info = taskInfos(tid) - val index = info.index - info.markSuccessful() + val index = tidToIndex(tid) if (!finished(index)) { tasksFinished += 1 - logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( - tid, info.duration, tasksFinished, numTasks)) - // Deserialize task result and 
pass it to the scheduler - val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer) - sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates) + logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks)) + // Deserialize task result + val result = ser.deserialize[TaskResult[_]]( + status.getData.toByteArray, getClass.getClassLoader) + sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(index) = true - if (tasksFinished == numTasks) { - sched.taskSetFinished(this) - } + if (tasksFinished == numTasks) + sched.jobFinished(this) } else { logInfo("Ignoring task-finished event for TID " + tid + " because task " + index + " is already finished") @@ -282,29 +233,30 @@ class TaskSetManager( def taskLost(status: TaskStatus) { val tid = status.getTaskId.getValue - val info = taskInfos(tid) - val index = info.index - info.markFailed() + val index = tidToIndex(tid) if (!finished(index)) { - logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 + logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index)) + launched(index) = false + tasksLaunched -= 1 // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. if (status.getData != null && status.getData.size > 0) { - val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer) + val reason = ser.deserialize[TaskEndReason]( + status.getData.toByteArray, getClass.getClassLoader) reason match { case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null) + logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri) + sched.taskEnded(tasks(index), fetchFailed, null, null) finished(index) = true tasksFinished += 1 - sched.taskSetFinished(this) + if (tasksFinished == numTasks) { + sched.jobFinished(this) + } return - case ef: ExceptionFailure => val key = ef.exception.toString val now = System.currentTimeMillis - val (printFull, dupCount) = { + val (printFull, dupCount) = if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) if (now - printTime > EXCEPTION_PRINT_INTERVAL) { @@ -315,28 +267,32 @@ class TaskSetManager( (false, dupCount + 1) } } else { - recentExceptions(key) = (0, now) + recentExceptions += Tuple(key, (0, now)) (true, 0) } - } + if (printFull) { - val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) - logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) + val stackTrace = + for (elem <- ef.exception.getStackTrace) + yield "\tat %s".format(elem.toString) + logInfo("Loss was due to %s\n%s".format( + ef.exception.toString, stackTrace.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) + logInfo("Loss was due to %s [duplicate %d]".format( + ef.exception.toString, dupCount)) } - case _ => {} } } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries + // On other failures, re-enqueue the task as pending for a max number of retries addPendingTask(index) - // Count failed attempts only on FAILED and LOST state (not on KILLED) - if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) { + // Count attempts only 
on FAILED and LOST state (not on KILLED) + if (status.getState == TaskState.TASK_FAILED || + status.getState == TaskState.TASK_LOST) { numFailures(index) += 1 if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( - taskSet.id, index, MAX_TASK_FAILURES)) + logError("Task %d:%d failed more than %d times; aborting job".format( + jobId, index, MAX_TASK_FAILURES)) abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES)) } } @@ -355,71 +311,6 @@ class TaskSetManager( failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.taskSetFinished(this) - } - - def hostLost(hostname: String) { - logInfo("Re-queueing tasks for " + hostname) - // If some task has preferred locations only on hostname, put it in the no-prefs list - // to avoid the wait from delay scheduling - for (index <- getPendingTasksForHost(hostname)) { - val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive - if (newLocs.isEmpty) { - pendingTasksWithNoPrefs += index - } - } - // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage - if (tasks(0).isInstanceOf[ShuffleMapTask]) { - for ((tid, info) <- taskInfos if info.host == hostname) { - val index = taskInfos(tid).index - if (finished(index)) { - finished(index) = false - copiesRunning(index) -= 1 - tasksFinished -= 1 - addPendingTask(index) - // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our - // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null) - } - } - } - } - - /** - * Check for tasks to be speculated and return true if there are any. This is called periodically - * by the MesosScheduler. - * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. - */ - def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksFinished == numTasks) { - return false - } - var foundTasks = false - val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt - logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksFinished >= minFinishedForSpeculation) { - val time = System.currentTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) - val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) - // TODO: Threshold should also look at standard deviation of task durations and have a lower - // bound based on that. 
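// Illustrative sketch only (not part of this change): for reference, the speculation check
// deleted in this hunk derived its threshold from the median duration of successful tasks,
// scaled by SPECULATION_MULTIPLIER and floored at 100 ms, roughly (assuming a non-empty,
// sorted array of durations):
def speculationThreshold(sortedDurations: Array[Long], multiplier: Double): Double =
  math.max(multiplier * sortedDurations(sortedDurations.length / 2), 100.0)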
- logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { - val index = info.index - if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && - !speculatableTasks.contains(index)) { - logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( - taskSet.id, index, info.host, threshold)) - speculatableTasks += index - foundTasks = true - } - } - } - return foundTasks + sched.jobFinished(this) } } diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala new file mode 100644 index 0000000000000000000000000000000000000000..196c64cf1fb76758c9d1251dc296ddcb58d863cd --- /dev/null +++ b/core/src/main/scala/spark/SimpleShuffleFetcher.scala @@ -0,0 +1,46 @@ +package spark + +import java.io.EOFException +import java.net.URL + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + +class SimpleShuffleFetcher extends ShuffleFetcher with Logging { + def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { + logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) + val ser = SparkEnv.get.serializer.newInstance() + val splitsByUri = new HashMap[String, ArrayBuffer[Int]] + val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId) + for ((serverUri, index) <- serverUris.zipWithIndex) { + splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index + } + for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) { + for (i <- inputIds) { + try { + val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId) + // TODO: multithreaded fetch + // TODO: would be nice to retry multiple times + val inputStream = ser.inputStream( + new FastBufferedInputStream(new URL(url).openStream())) + try { + while (true) { + val pair = inputStream.readObject().asInstanceOf[(K, V)] + func(pair._1, pair._2) + } + } finally { + inputStream.close() + } + } catch { + case e: EOFException => {} // We currently assume EOF means we read the whole thing + case other: Exception => { + logError("Fetch failed", other) + throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other) + } + } + } + } + } +} diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index b43aca2b97facac3b68dd84355469027f3ed78dd..6e019d6e7f10c345bb79a7452124384e46a8c12b 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,9 +3,6 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger -import akka.actor.Actor -import akka.actor.Actor._ - import scala.actors.remote.RemoteActor import scala.collection.mutable.ArrayBuffer @@ -35,17 +32,6 @@ import org.apache.mesos.MesosNativeLibrary import spark.broadcast._ -import spark.partial.ApproximateEvaluator -import spark.partial.PartialResult - -import spark.scheduler.ShuffleMapTask -import spark.scheduler.DAGScheduler -import spark.scheduler.TaskScheduler -import spark.scheduler.local.LocalScheduler -import spark.scheduler.mesos.MesosScheduler -import spark.scheduler.mesos.CoarseMesosScheduler -import spark.storage.BlockManagerMaster - class SparkContext( master: String, frameworkName: String, @@ -68,19 +54,14 @@ class SparkContext( if (RemoteActor.classLoader == null) { RemoteActor.classLoader = getClass.getClassLoader } - - 
remote.start(System.getProperty("spark.master.host"), - System.getProperty("spark.master.port").toInt) - private val isLocal = master.startsWith("local") // TODO: better check for local - // Create the Spark execution environment (cache, map output tracker, etc) - val env = SparkEnv.createFromSystemProperties(true, isLocal) + val env = SparkEnv.createFromSystemProperties(true) SparkEnv.set(env) Broadcast.initialize(true) // Create and start the scheduler - private var taskScheduler: TaskScheduler = { + private var scheduler: Scheduler = { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks @@ -93,17 +74,13 @@ class SparkContext( case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => new LocalScheduler(threads.toInt, maxFailures.toInt) case _ => - System.loadLibrary("mesos") - if (System.getProperty("spark.mesos.coarse", "false") == "true") { - new CoarseMesosScheduler(this, master, frameworkName) - } else { - new MesosScheduler(this, master, frameworkName) - } + MesosNativeLibrary.load() + new MesosScheduler(this, master, frameworkName) } } - taskScheduler.start() + scheduler.start() - private var dagScheduler = new DAGScheduler(taskScheduler) + private val isLocal = scheduler.isInstanceOf[LocalScheduler] // Methods for creating RDDs @@ -260,25 +237,19 @@ class SparkContext( // Stop the SparkContext def stop() { - remote.shutdownServerModule() - dagScheduler.stop() - dagScheduler = null - taskScheduler = null + scheduler.stop() + scheduler = null // TODO: Broadcast.stop(), Cache.stop()? env.mapOutputTracker.stop() env.cacheTracker.stop() env.shuffleFetcher.stop() env.shuffleManager.stop() - env.blockManager.stop() - BlockManagerMaster.stopBlockManagerMaster() - env.connectionManager.stop() SparkEnv.set(null) - ShuffleMapTask.clearCache() } - // Wait for the scheduler to be registered with the cluster manager + // Wait for the scheduler to be registered def waitForRegister() { - taskScheduler.waitForRegister() + scheduler.waitForRegister() } // Get Spark's home location from either a value set through the constructor, @@ -310,7 +281,7 @@ class SparkContext( ): Array[U] = { logInfo("Starting job...") val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, allowLocal) + val result = scheduler.runJob(rdd, func, partitions, allowLocal) logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") result } @@ -335,22 +306,6 @@ class SparkContext( runJob(rdd, func, 0 until rdd.splits.size, false) } - /** - * Run a job that can return approximate results. - */ - def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long - ): PartialResult[R] = { - logInfo("Starting job...") - val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout) - logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") - result - } - // Clean a closure to make it ready to serialized and send to tasks // (removes unreferenced variables in $outer's, updates REPL variables) private[spark] def clean[F <: AnyRef](f: F): F = { @@ -359,7 +314,7 @@ class SparkContext( } // Default level of parallelism to use when not given by user (e.g. 
for reduce tasks) - def defaultParallelism: Int = taskScheduler.defaultParallelism + def defaultParallelism: Int = scheduler.defaultParallelism // Default min number of splits for Hadoop RDDs when not given by user def defaultMinSplits: Int = math.min(defaultParallelism, 2) @@ -394,23 +349,15 @@ object SparkContext { } // TODO: Add AccumulatorParams for other types, e.g. lists and strings - implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = new PairRDDFunctions(rdd) - - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( - rdd: RDD[(K, V)]) = + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( - rdd: RDD[(K, V)]) = + implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = new OrderedRDDFunctions(rdd) - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) - - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) - // Implicit conversions to common Writable types, for saveAsSequenceFile implicit def intToIntWritable(i: Int) = new IntWritable(i) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 897a5ef82d0913cf3d263d0d7db4e6986c4387d9..cd752f8b6597e6feb97a1d1e582070dae745f628 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,26 +1,14 @@ package spark -import akka.actor.Actor - -import spark.storage.BlockManager -import spark.storage.BlockManagerMaster -import spark.network.ConnectionManager - class SparkEnv ( - val cache: Cache, - val serializer: Serializer, - val closureSerializer: Serializer, - val cacheTracker: CacheTracker, - val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, - val shuffleManager: ShuffleManager, - val blockManager: BlockManager, - val connectionManager: ConnectionManager - ) { - - /** No-parameter constructor for unit tests. 
*/ - def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null) -} + val cache: Cache, + val serializer: Serializer, + val closureSerializer: Serializer, + val cacheTracker: CacheTracker, + val mapOutputTracker: MapOutputTracker, + val shuffleFetcher: ShuffleFetcher, + val shuffleManager: ShuffleManager +) object SparkEnv { private val env = new ThreadLocal[SparkEnv] @@ -33,55 +21,36 @@ object SparkEnv { env.get() } - def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = { - val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer") - val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] - - BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal) - - var blockManager = new BlockManager(serializer) - - val connectionManager = blockManager.connectionManager + def createFromSystemProperties(isMaster: Boolean): SparkEnv = { + val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") + val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] - val shuffleManager = new ShuffleManager() + val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer") + val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] val closureSerializerClass = System.getProperty("spark.closure.serializer", "spark.JavaSerializer") val closureSerializer = Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer] - val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") - val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] - val cacheTracker = new CacheTracker(isMaster, blockManager) - blockManager.cacheTracker = cacheTracker + val cacheTracker = new CacheTracker(isMaster, cache) val mapOutputTracker = new MapOutputTracker(isMaster) val shuffleFetcherClass = - System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") + System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher") val shuffleFetcher = Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher] - /* - if (System.getProperty("spark.stream.distributed", "false") == "true") { - val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] - if (isLocal || !isMaster) { - (new Thread() { - override def run() { - println("Wait started") - Thread.sleep(60000) - println("Wait ended") - val receiverClass = Class.forName("spark.stream.TestStreamReceiver4") - val constructor = receiverClass.getConstructor(blockManagerClass) - val receiver = constructor.newInstance(blockManager) - receiver.asInstanceOf[Thread].start() - } - }).start() - } - } - */ + val shuffleMgr = new ShuffleManager() - new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher, - shuffleManager, blockManager, connectionManager) + new SparkEnv( + cache, + serializer, + closureSerializer, + cacheTracker, + mapOutputTracker, + shuffleFetcher, + shuffleMgr) } } diff --git a/core/src/main/scala/spark/Stage.scala b/core/src/main/scala/spark/Stage.scala new file mode 100644 index 0000000000000000000000000000000000000000..9452ea3a8e57db93c4cc31744a80bef8b3dfbd15 --- /dev/null +++ b/core/src/main/scala/spark/Stage.scala @@ -0,0 +1,41 @@ +package spark + +class Stage( + val id: Int, + val rdd: RDD[_], + val shuffleDep: Option[ShuffleDependency[_,_,_]], + val parents: List[Stage]) { + + val isShuffleMap = 
shuffleDep != None + val numPartitions = rdd.splits.size + val outputLocs = Array.fill[List[String]](numPartitions)(Nil) + var numAvailableOutputs = 0 + + def isAvailable: Boolean = { + if (parents.size == 0 && !isShuffleMap) { + true + } else { + numAvailableOutputs == numPartitions + } + } + + def addOutputLoc(partition: Int, host: String) { + val prevList = outputLocs(partition) + outputLocs(partition) = host :: prevList + if (prevList == Nil) + numAvailableOutputs += 1 + } + + def removeOutputLoc(partition: Int, host: String) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_ == host) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + override def toString = "Stage " + id + + override def hashCode(): Int = id +} diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala new file mode 100644 index 0000000000000000000000000000000000000000..bc3b3743447bda9d887bbbe970beb2ef52dbf38e --- /dev/null +++ b/core/src/main/scala/spark/Task.scala @@ -0,0 +1,9 @@ +package spark + +class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable + +abstract class Task[T] extends Serializable { + def run(id: Int): T + def preferredLocations: Seq[String] = Nil + def generation: Option[Long] = None +} diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala deleted file mode 100644 index 7a6214aab6648f6e7f5670b9839f3582dbe628bb..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/TaskContext.scala +++ /dev/null @@ -1,3 +0,0 @@ -package spark - -class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala deleted file mode 100644 index 6e4eb25ed44ff07e94085ebaa0d01c736a2839ed..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/TaskEndReason.scala +++ /dev/null @@ -1,16 +0,0 @@ -package spark - -import spark.storage.BlockManagerId - -/** - * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry - * tasks several times for "ephemeral" failures, and only report back failures that require some - * old stages to be resubmitted, such as shuffle map fetch failures. - */ -sealed trait TaskEndReason - -case object Success extends TaskEndReason -case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it -case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason -case class ExceptionFailure(exception: Throwable) extends TaskEndReason -case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala new file mode 100644 index 0000000000000000000000000000000000000000..2b7fd1a4b225e74dae4da46ad14d8b2cba0a87e9 --- /dev/null +++ b/core/src/main/scala/spark/TaskResult.scala @@ -0,0 +1,8 @@ +package spark + +import scala.collection.mutable.Map + +// Task result. Also contains updates to accumulator variables. 
+// TODO: Use of distributed cache to return result is a hack to get around +// what seems to be a bug with messages over 60KB in libprocess; fix it +private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala index 17522e2bbb6d1077d4d8caefc778753229d820d2..4c0f255e6bb767e61ed3864f3e3600f237692247 100644 --- a/core/src/main/scala/spark/UnionRDD.scala +++ b/core/src/main/scala/spark/UnionRDD.scala @@ -33,8 +33,7 @@ class UnionRDD[T: ClassManifest]( override def splits = splits_ - @transient - override val dependencies = { + @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for ((rdd, index) <- rdds.zipWithIndex) { diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index 89624eb37013a89223f56274062fd184b2a0be9b..68ccab24db3867af2747529ae0264a37e6905bc3 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -124,23 +124,6 @@ object Utils { * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). */ def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress - - private var customHostname: Option[String] = None - - /** - * Allow setting a custom host name because when we run on Mesos we need to use the same - * hostname it reports to the master. - */ - def setCustomHostname(hostname: String) { - customHostname = Some(hostname) - } - - /** - * Get the local machine's hostname - */ - def localHostName(): String = { - customHostname.getOrElse(InetAddress.getLocalHost.getHostName) - } /** * Returns a standard ThreadFactory except all threads are daemons. @@ -165,14 +148,6 @@ object Utils { return threadPool } - - /** - * Return the string to tell how long has passed in seconds. The passing parameter should be in - * millisecond. - */ - def getUsedTimeMs(startTimeMs: Long): String = { - return " " + (System.currentTimeMillis - startTimeMs) + " ms " - } /** * Wrapper over newFixedThreadPool. @@ -185,6 +160,16 @@ object Utils { return threadPool } + /** + * Get the local machine's hostname. + */ + def localHostName(): String = InetAddress.getLocalHost.getHostName + + /** + * Get current host + */ + def getHost = System.getProperty("spark.hostname", localHostName()) + /** * Delete a file or directory and its contents recursively. 
*/ diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala deleted file mode 100644 index 4546dfa0fac1b6c7f07d708a42abac2f4cedbdaa..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/network/Connection.scala +++ /dev/null @@ -1,364 +0,0 @@ -package spark.network - -import spark._ - -import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} - -import java.io._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.net._ - - -abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { - - channel.configureBlocking(false) - channel.socket.setTcpNoDelay(true) - channel.socket.setReuseAddress(true) - channel.socket.setKeepAlive(true) - /*channel.socket.setReceiveBufferSize(32768) */ - - var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null - var onKeyInterestChangeCallback: (Connection, Int) => Unit = null - - lazy val remoteAddress = getRemoteAddress() - lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) - - def key() = channel.keyFor(selector) - - def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] - - def read() { - throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) - } - - def write() { - throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) - } - - def close() { - key.cancel() - channel.close() - callOnCloseCallback() - } - - def onClose(callback: Connection => Unit) {onCloseCallback = callback} - - def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback} - - def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback} - - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + remoteConnectionManagerId + - " and OnExceptionCallback not registered", e) - } - } - - def callOnCloseCallback() { - if (onCloseCallback != null) { - onCloseCallback(this) - } else { - logWarning("Connection to " + remoteConnectionManagerId + - " closed and OnExceptionCallback not registered") - } - - } - - def changeConnectionKeyInterest(ops: Int) { - if (onKeyInterestChangeCallback != null) { - onKeyInterestChangeCallback(this, ops) - } else { - throw new Exception("OnKeyInterestChangeCallback not registered") - } - } - - def printRemainingBuffer(buffer: ByteBuffer) { - val bytes = new Array[Byte](buffer.remaining) - val curPosition = buffer.position - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - buffer.position(curPosition) - print(" (" + bytes.size + ")") - } - - def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { - val bytes = new Array[Byte](length) - val curPosition = buffer.position - buffer.position(position) - buffer.get(bytes) - bytes.foreach(x => print(x + " ")) - print(" (" + position + ", " + length + ")") - buffer.position(curPosition) - } - -} - - -class SendingConnection(val address: InetSocketAddress, selector_ : Selector) -extends Connection(SocketChannel.open, selector_) { - - class Outbox(fair: Int = 0) { - val messages = new Queue[Message]() - val defaultChunkSize = 65536 //32768 //16384 - var nextMessageToBeUsed = 0 - - def addMessage(message: Message): Unit = { - messages.synchronized{ - 
/*messages += message*/ - messages.enqueue(message) - logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") - } - } - - def getChunk(): Option[MessageChunk] = { - fair match { - case 0 => getChunkFIFO() - case 1 => getChunkRR() - case _ => throw new Exception("Unexpected fairness policy in outbox") - } - } - - private def getChunkFIFO(): Option[MessageChunk] = { - /*logInfo("Using FIFO")*/ - messages.synchronized { - while (!messages.isEmpty) { - val message = messages(0) - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages += message // this is probably incorrect, it wont work as fifo - if (!message.started) logDebug("Starting to send [" + message + "]") - message.started = true - return chunk - } - /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ - logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) - } - } - None - } - - private def getChunkRR(): Option[MessageChunk] = { - messages.synchronized { - while (!messages.isEmpty) { - /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ - /*val message = messages(nextMessageToBeUsed)*/ - val message = messages.dequeue - val chunk = message.getChunkForSending(defaultChunkSize) - if (chunk.isDefined) { - messages.enqueue(message) - nextMessageToBeUsed = nextMessageToBeUsed + 1 - if (!message.started) { - logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]") - message.started = true - message.startTime = System.currentTimeMillis - } - logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") - return chunk - } - /*messages -= message*/ - message.finishTime = System.currentTimeMillis - logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) - } - } - None - } - } - - val outbox = new Outbox(1) - val currentBuffers = new ArrayBuffer[ByteBuffer]() - - /*channel.socket.setSendBufferSize(256 * 1024)*/ - - override def getRemoteAddress() = address - - def send(message: Message) { - outbox.synchronized { - outbox.addMessage(message) - if (channel.isConnected) { - changeConnectionKeyInterest(SelectionKey.OP_WRITE) - } - } - } - - def connect() { - try{ - channel.connect(address) - channel.register(selector, SelectionKey.OP_CONNECT) - logInfo("Initiating connection to [" + address + "]") - } catch { - case e: Exception => { - logError("Error connecting to " + address, e) - callOnExceptionCallback(e) - } - } - } - - def finishConnect() { - try { - channel.finishConnect - changeConnectionKeyInterest(SelectionKey.OP_WRITE) - logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") - } catch { - case e: Exception => { - logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) - } - } - } - - override def write() { - try{ - while(true) { - if (currentBuffers.size == 0) { - outbox.synchronized { - outbox.getChunk match { - case Some(chunk) => { - currentBuffers ++= chunk.buffers - } - case None => { - changeConnectionKeyInterest(0) - /*key.interestOps(0)*/ - return - } - } - } - } - - if (currentBuffers.size > 0) { - val buffer = currentBuffers(0) - val remainingBytes = buffer.remaining - val writtenBytes = channel.write(buffer) - if (buffer.remaining == 0) { - currentBuffers -= buffer - } - if (writtenBytes < remainingBytes) { - return - } - } - } - } catch { - case e: Exception => { - 
logWarning("Error writing in connection to " + remoteConnectionManagerId, e) - callOnExceptionCallback(e) - close() - } - } - } -} - - -class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) -extends Connection(channel_, selector_) { - - class Inbox() { - val messages = new HashMap[Int, BufferMessage]() - - def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { - - def createNewMessage: BufferMessage = { - val newMessage = Message.create(header).asInstanceOf[BufferMessage] - newMessage.started = true - newMessage.startTime = System.currentTimeMillis - logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") - messages += ((newMessage.id, newMessage)) - newMessage - } - - val message = messages.getOrElseUpdate(header.id, createNewMessage) - logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]") - message.getChunkForReceiving(header.chunkSize) - } - - def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { - messages.get(chunk.header.id) - } - - def removeMessage(message: Message) { - messages -= message.id - } - } - - val inbox = new Inbox() - val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection , Message) => Unit = null - var currentChunk: MessageChunk = null - - channel.register(selector, SelectionKey.OP_READ) - - override def read() { - try { - while (true) { - if (currentChunk == null) { - val headerBytesRead = channel.read(headerBuffer) - if (headerBytesRead == -1) { - close() - return - } - if (headerBuffer.remaining > 0) { - return - } - headerBuffer.flip - if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { - throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") - } - val header = MessageChunkHeader.create(headerBuffer) - headerBuffer.clear() - header.typ match { - case Message.BUFFER_MESSAGE => { - if (header.totalSize == 0) { - if (onReceiveCallback != null) { - onReceiveCallback(this, Message.create(header)) - } - currentChunk = null - return - } else { - currentChunk = inbox.getChunk(header).orNull - } - } - case _ => throw new Exception("Message of unknown type received") - } - } - - if (currentChunk == null) throw new Exception("No message chunk to receive data") - - val bytesRead = channel.read(currentChunk.buffer) - if (bytesRead == 0) { - return - } else if (bytesRead == -1) { - close() - return - } - - /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ - - if (currentChunk.buffer.remaining == 0) { - /*println("Filled buffer at " + System.currentTimeMillis)*/ - val bufferMessage = inbox.getMessageForChunk(currentChunk).get - if (bufferMessage.isCompletelyReceived) { - bufferMessage.flip - bufferMessage.finishTime = System.currentTimeMillis - logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) - if (onReceiveCallback != null) { - onReceiveCallback(this, bufferMessage) - } - inbox.removeMessage(bufferMessage) - } - currentChunk = null - } - } - } catch { - case e: Exception => { - logWarning("Error reading from connection to " + remoteConnectionManagerId, e) - callOnExceptionCallback(e) - close() - } - } - } - - def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} -} diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala deleted file mode 100644 index 
3222187990eaa630997d012a0babbcd1a61cbb20..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ /dev/null @@ -1,468 +0,0 @@ -package spark.network - -import spark._ - -import scala.actors.Future -import scala.actors.Futures.future -import scala.collection.mutable.HashMap -import scala.collection.mutable.SynchronizedMap -import scala.collection.mutable.SynchronizedQueue -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer - -import java.io._ -import java.nio._ -import java.nio.channels._ -import java.nio.channels.spi._ -import java.net._ -import java.util.concurrent.Executors - -case class ConnectionManagerId(val host: String, val port: Int) { - def toSocketAddress() = new InetSocketAddress(host, port) -} - -object ConnectionManagerId { - def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { - new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) - } -} - -class ConnectionManager(port: Int) extends Logging { - - case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) { - var ackMessage: Option[Message] = None - var attempted = false - var acked = false - } - - val selector = SelectorProvider.provider.openSelector() - val handleMessageExecutor = Executors.newFixedThreadPool(4) - val serverChannel = ServerSocketChannel.open() - val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] - val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - val messageStatuses = new HashMap[Int, MessageStatus] - val connectionRequests = new SynchronizedQueue[SendingConnection] - val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] - val sendMessageRequests = new Queue[(Message, SendingConnection)] - - var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null - - serverChannel.configureBlocking(false) - serverChannel.socket.setReuseAddress(true) - serverChannel.socket.setReceiveBufferSize(256 * 1024) - - serverChannel.socket.bind(new InetSocketAddress(port)) - serverChannel.register(selector, SelectionKey.OP_ACCEPT) - - val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) - logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) - - val thisInstance = this - val selectorThread = new Thread("connection-manager-thread") { - override def run() { - thisInstance.run() - } - } - selectorThread.setDaemon(true) - selectorThread.start() - - def run() { - try { - var interrupted = false - while(!interrupted) { - while(!connectionRequests.isEmpty) { - val sendingConnection = connectionRequests.dequeue - sendingConnection.connect() - addConnection(sendingConnection) - } - sendMessageRequests.synchronized { - while(!sendMessageRequests.isEmpty) { - val (message, connection) = sendMessageRequests.dequeue - connection.send(message) - } - } - - while(!keyInterestChangeRequests.isEmpty) { - val (key, ops) = keyInterestChangeRequests.dequeue - val connection = connectionsByKey(key) - val lastOps = key.interestOps() - key.interestOps(ops) - - def intToOpStr(op: Int): String = { - val opStrs = ArrayBuffer[String]() - if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" - if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE" - if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" - if ((op & 
SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" - if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " - } - - logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId + - "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") - - } - - val selectedKeysCount = selector.select() - if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") - - interrupted = selectorThread.isInterrupted - - val selectedKeys = selector.selectedKeys().iterator() - while (selectedKeys.hasNext()) { - val key = selectedKeys.next.asInstanceOf[SelectionKey] - selectedKeys.remove() - if (key.isValid) { - if (key.isAcceptable) { - acceptConnection(key) - } else - if (key.isConnectable) { - connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() - } else - if (key.isReadable) { - connectionsByKey(key).read() - } else - if (key.isWritable) { - connectionsByKey(key).write() - } - } - } - } - } catch { - case e: Exception => logError("Error in select loop", e) - } - } - - def acceptConnection(key: SelectionKey) { - val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] - val newChannel = serverChannel.accept() - val newConnection = new ReceivingConnection(newChannel, selector) - newConnection.onReceive(receiveMessage) - newConnection.onClose(removeConnection) - addConnection(newConnection) - logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") - } - - def addConnection(connection: Connection) { - connectionsByKey += ((connection.key, connection)) - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection)) - } - connection.onKeyInterestChange(changeConnectionKeyInterest) - connection.onException(handleConnectionError) - connection.onClose(removeConnection) - } - - def removeConnection(connection: Connection) { - /*logInfo("Removing connection")*/ - connectionsByKey -= connection.key - if (connection.isInstanceOf[SendingConnection]) { - val sendingConnection = connection.asInstanceOf[SendingConnection] - val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId - logInfo("Removing SendingConnection to " + sendingConnectionManagerId) - - connectionsById -= sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.notifyAll - } - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } else if (connection.isInstanceOf[ReceivingConnection]) { - val receivingConnection = connection.asInstanceOf[ReceivingConnection] - val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId - logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) - - val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull - if (sendingConnectionManagerId == null) { - logError("Corresponding SendingConnectionManagerId not found") - return - } - logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId) - - val sendingConnection = connectionsById(sendingConnectionManagerId) - sendingConnection.close() - connectionsById -= 
sendingConnectionManagerId - - messageStatuses.synchronized { - messageStatuses - .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { - logInfo("Notifying " + status) - status.synchronized { - status.attempted = true - status.acked = false - status.notifyAll - } - }) - - messageStatuses.retain((i, status) => { - status.connectionManagerId != sendingConnectionManagerId - }) - } - } - } - - def handleConnectionError(connection: Connection, e: Exception) { - logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) - removeConnection(connection) - } - - def changeConnectionKeyInterest(connection: Connection, ops: Int) { - keyInterestChangeRequests += ((connection.key, ops)) - } - - def receiveMessage(connection: Connection, message: Message) { - val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) - logInfo("Received [" + message + "] from [" + connectionManagerId + "]") - val runnable = new Runnable() { - val creationTime = System.currentTimeMillis - def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") - } - } - handleMessageExecutor.execute(runnable) - /*handleMessage(connection, message)*/ - } - - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { - logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") - message match { - case bufferMessage: BufferMessage => { - if (bufferMessage.hasAckId) { - val sentMessageStatus = messageStatuses.synchronized { - messageStatuses.get(bufferMessage.ackId) match { - case Some(status) => { - messageStatuses -= bufferMessage.ackId - status - } - case None => { - throw new Exception("Could not find reference for received ack message " + message.id) - null - } - } - } - sentMessageStatus.synchronized { - sentMessageStatus.ackMessage = Some(message) - sentMessageStatus.attempted = true - sentMessageStatus.acked = true - sentMessageStatus.notifyAll - } - } else { - val ackMessage = if (onReceiveCallback != null) { - logDebug("Calling back") - onReceiveCallback(bufferMessage, connectionManagerId) - } else { - logWarning("Not calling back as callback is null") - None - } - - if (ackMessage.isDefined) { - if (!ackMessage.get.isInstanceOf[BufferMessage]) { - logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) - } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { - logWarning("Response to " + bufferMessage + " does not have ack id set") - ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id - } - } - - sendMessage(connectionManagerId, ackMessage.getOrElse { - Message.createBufferMessage(bufferMessage.id) - }) - } - } - case _ => throw new Exception("Unknown type message received") - } - } - - private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { - def startNewConnection(): SendingConnection = { - val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector) - connectionRequests += newConnection - newConnection - } - val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection) - message.senderAddress = id.toSocketAddress() - logInfo("Sending [" + message + "] to [" + 
connectionManagerId + "]") - /*connection.send(message)*/ - sendMessageRequests.synchronized { - sendMessageRequests += ((message, connection)) - } - selector.wakeup() - } - - def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = { - val messageStatus = new MessageStatus(message, connectionManagerId) - messageStatuses.synchronized { - messageStatuses += ((message.id, messageStatus)) - } - sendMessage(connectionManagerId, message) - future { - messageStatus.synchronized { - if (!messageStatus.attempted) { - logTrace("Waiting, " + messageStatuses.size + " statuses" ) - messageStatus.wait() - logTrace("Done waiting") - } - } - messageStatus.ackMessage - } - } - - def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { - sendMessageReliably(connectionManagerId, message)() - } - - def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { - onReceiveCallback = callback - } - - def stop() { - if (!selectorThread.isAlive) { - selectorThread.interrupt() - selectorThread.join() - selector.close() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") - } - handleMessageExecutor.shutdown() - logInfo("ConnectionManager stopped") - } - } -} - - -object ConnectionManager { - - def main(args: Array[String]) { - - val manager = new ConnectionManager(9999) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - /*testSequentialSending(manager)*/ - /*System.gc()*/ - - /*testParallelSending(manager)*/ - /*System.gc()*/ - - /*testParallelDecreasingSending(manager)*/ - /*System.gc()*/ - - testContinuousSending(manager) - System.gc() - } - - def testSequentialSending(manager: ConnectionManager) { - println("--------------------------") - println("Sequential Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliablySync(manager.id, bufferMessage) - }) - println("--------------------------") - println() - } - - def testParallelSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) - val finishTime = System.currentTimeMillis - - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println("Started at " + startTime + ", finished at " + finishTime) - println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testParallelDecreasingSending(manager: ConnectionManager) { - println("--------------------------") - println("Parallel Decreasing Sending") 
- println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) - buffers.foreach(_.flip) - val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 - - val startTime = System.currentTimeMillis - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) - val finishTime = System.currentTimeMillis - - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - /*println("Started at " + startTime + ", finished at " + finishTime) */ - println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") - println("--------------------------") - println() - } - - def testContinuousSending(manager: ConnectionManager) { - println("--------------------------") - println("Continuous Sending") - println("--------------------------") - val size = 10 * 1024 * 1024 - val count = 10 - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - while(true) { - (0 until count).map(i => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - manager.sendMessageReliably(manager.id, bufferMessage) - }).foreach(f => {if (!f().isDefined) println("Failed")}) - val finishTime = System.currentTimeMillis - Thread.sleep(1000) - val mb = size * count / 1024.0 / 1024.0 - val ms = finishTime - startTime - val tput = mb * 1000.0 / ms - println("--------------------------") - println() - } - } -} diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 5d21bb793f3dcefce2af736edeb602c47ff0c56f..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,74 +0,0 @@ -package spark.network - -import spark._ -import spark.SparkContext._ - -import scala.io.Source - -import java.nio.ByteBuffer -import java.net.InetAddress - -object ConnectionManagerTest extends Logging{ - def main(args: Array[String]) { - if (args.length < 2) { - println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /*println("Slaves")*/ - /*slaves.foreach(println)*/ - - val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - val count = 10 - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { - val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val 
startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - }) - val results = futures.map(f => f()) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} - diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala deleted file mode 100644 index 2e858036791d2e5e80020c7527d6cfecf6bd9f07..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/network/Message.scala +++ /dev/null @@ -1,219 +0,0 @@ -package spark.network - -import spark._ - -import scala.collection.mutable.ArrayBuffer - -import java.nio.ByteBuffer -import java.net.InetAddress -import java.net.InetSocketAddress - -class MessageChunkHeader( - val typ: Long, - val id: Int, - val totalSize: Int, - val chunkSize: Int, - val other: Int, - val address: InetSocketAddress) { - lazy val buffer = { - val ip = address.getAddress.getAddress() - val port = address.getPort() - ByteBuffer. - allocate(MessageChunkHeader.HEADER_SIZE). - putLong(typ). - putInt(id). - putInt(totalSize). - putInt(chunkSize). - putInt(other). - putInt(ip.size). - put(ip). - putInt(port). - position(MessageChunkHeader.HEADER_SIZE). 
- flip.asInstanceOf[ByteBuffer] - } - - override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" -} - -class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { - val size = if (buffer == null) 0 else buffer.remaining - lazy val buffers = { - val ab = new ArrayBuffer[ByteBuffer]() - ab += header.buffer - if (buffer != null) { - ab += buffer - } - ab - } - - override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" -} - -abstract class Message(val typ: Long, val id: Int) { - var senderAddress: InetSocketAddress = null - var started = false - var startTime = -1L - var finishTime = -1L - - def size: Int - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] - - def timeTaken(): String = (finishTime - startTime).toString + " ms" - - override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" -} - -class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) -extends Message(Message.BUFFER_MESSAGE, id_) { - - val initialSize = currentSize() - var gotChunkForSendingOnce = false - - def size = initialSize - - def currentSize() = { - if (buffers == null || buffers.isEmpty) { - 0 - } else { - buffers.map(_.remaining).reduceLeft(_ + _) - } - } - - def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { - if (maxChunkSize <= 0) { - throw new Exception("Max chunk size is " + maxChunkSize) - } - - if (size == 0 && gotChunkForSendingOnce == false) { - val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) - gotChunkForSendingOnce = true - return Some(newChunk) - } - - while(!buffers.isEmpty) { - val buffer = buffers(0) - if (buffer.remaining == 0) { - buffers -= buffer - } else { - val newBuffer = if (buffer.remaining <= maxChunkSize) { - buffer.duplicate - } else { - buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] - } - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - gotChunkForSendingOnce = true - return Some(newChunk) - } - } - None - } - - def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { - // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer - if (buffers.size > 1) { - throw new Exception("Attempting to get chunk from message with multiple data buffers") - } - val buffer = buffers(0) - if (buffer.remaining > 0) { - if (buffer.remaining < chunkSize) { - throw new Exception("Not enough space in data buffer for receiving chunk") - } - val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] - buffer.position(buffer.position + newBuffer.remaining) - val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) - return Some(newChunk) - } - None - } - - def flip() { - buffers.foreach(_.flip) - } - - def hasAckId() = (ackId != 0) - - def isCompletelyReceived() = !buffers(0).hasRemaining - - override def toString = { - if (hasAckId) { - "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" - } else { - "BufferMessage(id = " + id + ", size = " + size + ")" - } - - } -} - -object MessageChunkHeader { - val HEADER_SIZE = 40 - - def create(buffer: 
ByteBuffer): MessageChunkHeader = { - if (buffer.remaining != HEADER_SIZE) { - throw new IllegalArgumentException("Cannot convert buffer data to Message") - } - val typ = buffer.getLong() - val id = buffer.getInt() - val totalSize = buffer.getInt() - val chunkSize = buffer.getInt() - val other = buffer.getInt() - val ipSize = buffer.getInt() - val ipBytes = new Array[Byte](ipSize) - buffer.get(ipBytes) - val ip = InetAddress.getByAddress(ipBytes) - val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) - } -} - -object Message { - val BUFFER_MESSAGE = 1111111111L - - var lastId = 1 - - def getNewId() = synchronized { - lastId += 1 - if (lastId == 0) lastId += 1 - lastId - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { - if (dataBuffers == null) { - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) - } - if (dataBuffers.exists(_ == null)) { - throw new Exception("Attempting to create buffer message with null buffer") - } - return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) - } - - def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = - createBufferMessage(dataBuffers, 0) - - def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { - if (dataBuffer == null) { - return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) - } else { - return createBufferMessage(Array(dataBuffer), ackId) - } - } - - def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = - createBufferMessage(dataBuffer, 0) - - def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId) - - def create(header: MessageChunkHeader): Message = { - val newMessage: Message = header.typ match { - case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) - } - newMessage.senderAddress = header.address - newMessage - } -} diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala deleted file mode 100644 index e1ba7c06c04dfd615ef5f23ae710fc73faaf6e11..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/network/ReceiverTest.scala +++ /dev/null @@ -1,20 +0,0 @@ -package spark.network - -import java.nio.ByteBuffer -import java.net.InetAddress - -object ReceiverTest { - - def main(args: Array[String]) { - val manager = new ConnectionManager(9999) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ - val buffer = ByteBuffer.wrap("response".getBytes()) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} - diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala deleted file mode 100644 index 4ab6dd34140992fdc9d6b1642b5b4d6ae1e69e2c..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/network/SenderTest.scala +++ /dev/null @@ -1,53 +0,0 @@ -package spark.network - -import java.nio.ByteBuffer -import java.net.InetAddress - -object SenderTest { - - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest <target host> <target port>") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt 
- val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /*println("Started timer at " + startTime)*/ - val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { - case Some(response) => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array) - case None => "none" - } - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/ - val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala deleted file mode 100644 index 260547902bb4a743e7a48ec1fb2d5a8b3b56da9c..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/ApproximateActionListener.scala +++ /dev/null @@ -1,66 +0,0 @@ -package spark.partial - -import spark._ -import spark.scheduler.JobListener - -/** - * A JobListener for an approximate single-result action, such as count() or non-parallel reduce(). - * This listener waits up to timeout milliseconds and will return a partial answer even if the - * complete answer is not available by then. - * - * This class assumes that the action is performed on an entire RDD[T] via a function that computes - * a result of type U for each partition, and that the action returns a partial or complete result - * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt). - */ -class ApproximateActionListener[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long) - extends JobListener { - - val startTime = System.currentTimeMillis() - val totalTasks = rdd.splits.size - var finishedTasks = 0 - var failure: Option[Exception] = None // Set if the job has failed (permanently) - var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult - - override def taskSucceeded(index: Int, result: Any): Unit = synchronized { - evaluator.merge(index, result.asInstanceOf[U]) - finishedTasks += 1 - if (finishedTasks == totalTasks) { - // If we had already returned a PartialResult, set its final value - resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) - // Notify any waiting thread that may have called getResult - this.notifyAll() - } - } - - override def jobFailed(exception: Exception): Unit = synchronized { - failure = Some(exception) - this.notifyAll() - } - - /** - * Waits for up to timeout milliseconds since the listener was created and then returns a - * PartialResult with the result so far. This may be complete if the whole job is done. 
- */ - def getResult(): PartialResult[R] = synchronized { - val finishTime = startTime + timeout - while (true) { - val time = System.currentTimeMillis() - if (failure != None) { - throw failure.get - } else if (finishedTasks == totalTasks) { - return new PartialResult(evaluator.currentResult(), true) - } else if (time >= finishTime) { - resultObject = Some(new PartialResult(evaluator.currentResult(), false)) - return resultObject.get - } else { - this.wait(finishTime - time) - } - } - // Should never be reached, but required to keep the compiler happy - return null - } -} diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala deleted file mode 100644 index 4772e43ef04118cc25a2555ca3c250268496264f..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala +++ /dev/null @@ -1,10 +0,0 @@ -package spark.partial - -/** - * An object that computes a function incrementally by merging in results of type U from multiple - * tasks. Allows partial evaluation at any point by calling currentResult(). - */ -trait ApproximateEvaluator[U, R] { - def merge(outputId: Int, taskResult: U): Unit - def currentResult(): R -} diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala deleted file mode 100644 index 463c33d6e238ebc688390accd0b66e4b4ef10cf5..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/BoundedDouble.scala +++ /dev/null @@ -1,8 +0,0 @@ -package spark.partial - -/** - * A Double with error bars on it. - */ -class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { - override def toString(): String = "[%.3f, %.3f]".format(low, high) -} diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala deleted file mode 100644 index 1bc90d6b3930aab7b870cbca4a2b0731723be1e8..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/CountEvaluator.scala +++ /dev/null @@ -1,38 +0,0 @@ -package spark.partial - -import cern.jet.stat.Probability - -/** - * An ApproximateEvaluator for counts. - * - * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might - * be best to make this a special case of GroupedCountEvaluator with one group. 
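To make the extrapolation performed by the class below concrete, here is a minimal standalone sketch of the same arithmetic with made-up numbers. It hard-codes 1.96 (the ~95% normal critical value) in place of the Colt call Probability.normalInverse(1 - (1 - confidence) / 2) so it runs without the cern.jet.stat dependency; the object and variable names are illustrative only.

object CountExtrapolationSketch {
  def main(args: Array[String]) {
    val totalOutputs = 10                       // partitions in the job
    val outputsMerged = 4                       // partitions whose counts have been seen
    val sum = 2000L                             // records counted in those partitions
    val p = outputsMerged.toDouble / totalOutputs
    val mean = (sum + 1 - p) / p                // extrapolated total count
    val stdev = math.sqrt((sum + 1) * (1 - p) / (p * p))
    val confFactor = 1.96                       // stand-in for the 95% normal quantile
    println("estimate %.0f, approx 95%% interval [%.0f, %.0f]".format(
      mean, mean - confFactor * stdev, mean + confFactor * stdev))
  }
}

The same per-partition extrapolation appears key-by-key in GroupedCountEvaluator further down.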
- */ -class CountEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[Long, BoundedDouble] { - - var outputsMerged = 0 - var sum: Long = 0 - - override def merge(outputId: Int, taskResult: Long) { - outputsMerged += 1 - sum += taskResult - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(sum, 1.0, sum, sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val p = outputsMerged.toDouble / totalOutputs - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala deleted file mode 100644 index 3e631c0efc5517c184126ff4602988d1e79297e6..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala +++ /dev/null @@ -1,62 +0,0 @@ -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.Map -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions.mapAsScalaMap - -import cern.jet.stat.Probability - -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} - -/** - * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. - */ -class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new OLMap[T] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: OLMap[T]) { - outputsMerged += 1 - val iter = taskResult.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue) - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue() - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue - val mean = (sum + 1 - p) / p - val variance = (sum + 1) * (1 - p) / (p * p) - val stdev = math.sqrt(variance) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala deleted file mode 100644 index 2a9ccba2055efc5121de8789b225a9808bb475b9..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala +++ /dev/null @@ 
-1,65 +0,0 @@ -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.mutable.HashMap -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. - */ -class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val mean = entry.getValue.mean - result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = studentTCacher.get(counter.count) - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala deleted file mode 100644 index 6a2ec7a7bd30e53bf4844ff1f4382f3118bbc635..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala +++ /dev/null @@ -1,72 +0,0 @@ -package spark.partial - -import java.util.{HashMap => JHashMap} -import java.util.{Map => JMap} - -import scala.collection.mutable.HashMap -import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. 
- */ -class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { - - var outputsMerged = 0 - var sums = new JHashMap[T, StatCounter] // Sum of counts for each key - - override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { - outputsMerged += 1 - val iter = taskResult.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val old = sums.get(entry.getKey) - if (old != null) { - old.merge(entry.getValue) - } else { - sums.put(entry.getKey, entry.getValue) - } - } - } - - override def currentResult(): Map[T, BoundedDouble] = { - if (outputsMerged == totalOutputs) { - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getValue.sum - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) - } - result - } else if (outputsMerged == 0) { - new HashMap[T, BoundedDouble] - } else { - val p = outputsMerged.toDouble / totalOutputs - val studentTCacher = new StudentTCacher(confidence) - val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.entrySet.iterator() - while (iter.hasNext) { - val entry = iter.next() - val counter = entry.getValue - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = studentTCacher.get(counter.count) - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) - } - result - } - } -} diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala deleted file mode 100644 index b8c7cb8863539096ec9577e1c43ec1831c545423..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/MeanEvaluator.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark.partial - -import cern.jet.stat.Probability - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for means. 
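The interval this evaluator produces is the textbook one for a sample mean. As a gloss on the code that follows (my summary, not taken from any comment in the source):

\[
  \hat{\mu} \;\pm\; c\,\sqrt{s^{2}/n}, \qquad P(|T| \le c) = \text{confidence},
\]

where \(\hat{\mu}\) and \(s^{2}\) are the merged StatCounter's mean and sample variance, \(n\) is its count, and \(T\) is taken to be standard normal when the count is large (over 100) and Student's \(t\) with \(n-1\) degrees of freedom otherwise — essentially the same cutoff StudentTCacher calls NORMAL_APPROX_SAMPLE_SIZE.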
- */ -class MeanEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[StatCounter, BoundedDouble] { - - var outputsMerged = 0 - var counter = new StatCounter - - override def merge(outputId: Int, taskResult: StatCounter) { - outputsMerged += 1 - counter.merge(taskResult) - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val mean = counter.mean - val stdev = math.sqrt(counter.sampleVariance / counter.count) - val confFactor = { - if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) - } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) - } - } - val low = mean - confFactor * stdev - val high = mean + confFactor * stdev - new BoundedDouble(mean, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala deleted file mode 100644 index 7095bc8ca1bbf4d134a3ce01b3cd1826e3a93722..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/PartialResult.scala +++ /dev/null @@ -1,86 +0,0 @@ -package spark.partial - -class PartialResult[R](initialVal: R, isFinal: Boolean) { - private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None - private var failure: Option[Exception] = None - private var completionHandler: Option[R => Unit] = None - private var failureHandler: Option[Exception => Unit] = None - - def initialValue: R = initialVal - - def isInitialValueFinal: Boolean = isFinal - - /** - * Blocking method to wait for and return the final value. - */ - def getFinalValue(): R = synchronized { - while (finalValue == None && failure == None) { - this.wait() - } - if (finalValue != None) { - return finalValue.get - } else { - throw failure.get - } - } - - /** - * Set a handler to be called when this PartialResult completes. Only one completion handler - * is supported per PartialResult. - */ - def onComplete(handler: R => Unit): PartialResult[R] = synchronized { - if (completionHandler != None) { - throw new UnsupportedOperationException("onComplete cannot be called twice") - } - completionHandler = Some(handler) - if (finalValue != None) { - // We already have a final value, so let's call the handler - handler(finalValue.get) - } - return this - } - - /** - * Set a handler to be called if this PartialResult's job fails. Only one failure handler - * is supported per PartialResult. 
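As a usage illustration (not part of the original file): a caller typically registers handlers and/or blocks on getFinalValue() while the scheduler fills the value in from another thread. The sketch assumes the PartialResult class in this file is on the classpath and lives in spark.partial so it can reach the private[spark] setter; the throwaway thread stands in for the DAGScheduler.

package spark.partial

object PartialResultUsageSketch {
  def main(args: Array[String]) {
    // A partial result whose initial value is only an estimate.
    val result = new PartialResult[Double](0.42, false)
    result.onComplete(v => println("final value arrived: " + v))

    // In the real code the DAGScheduler calls setFinalValue() once all tasks
    // have finished; here a helper thread plays that role.
    new Thread("fake-scheduler") {
      override def run() {
        Thread.sleep(100)
        result.setFinalValue(3.14)
      }
    }.start()

    println("partial value now: " + result.initialValue)
    println("blocking for final value: " + result.getFinalValue())
  }
}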
- */ - def onFail(handler: Exception => Unit): Unit = synchronized { - if (failureHandler != None) { - throw new UnsupportedOperationException("onFail cannot be called twice") - } - failureHandler = Some(handler) - if (failure != None) { - // We already have a failure, so let's call the handler - handler(failure.get) - } - } - - private[spark] def setFinalValue(value: R): Unit = synchronized { - if (finalValue != None) { - throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult") - } - finalValue = Some(value) - // Call the completion handler if it was set - completionHandler.foreach(h => h(value)) - // Notify any threads that may be calling getFinalValue() - this.notifyAll() - } - - private[spark] def setFailure(exception: Exception): Unit = synchronized { - if (failure != None) { - throw new UnsupportedOperationException("setFailure called twice on a PartialResult") - } - failure = Some(exception) - // Call the failure handler if it was set - failureHandler.foreach(h => h(exception)) - // Notify any threads that may be calling getFinalValue() - this.notifyAll() - } - - override def toString: String = synchronized { - finalValue match { - case Some(value) => "(final: " + value + ")" - case None => "(partial: " + initialValue + ")" - } - } -} diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala deleted file mode 100644 index 6263ee3518d8c21beb081d4c26dd0aa837f683d5..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/StudentTCacher.scala +++ /dev/null @@ -1,26 +0,0 @@ -package spark.partial - -import cern.jet.stat.Probability - -/** - * A utility class for caching Student's T distribution values for a given confidence level - * and various sample sizes. This is used by the MeanEvaluator to efficiently calculate - * confidence intervals for many keys. - */ -class StudentTCacher(confidence: Double) { - val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) - val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) - - def get(sampleSize: Long): Double = { - if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) { - normalApprox - } else { - val size = sampleSize.toInt - if (cache(size) < 0) { - cache(size) = Probability.studentTInverse(1 - confidence, size - 1) - } - cache(size) - } - } -} diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala deleted file mode 100644 index 0357a6bff860a78729f759d44ff63feae76236fa..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/partial/SumEvaluator.scala +++ /dev/null @@ -1,51 +0,0 @@ -package spark.partial - -import cern.jet.stat.Probability - -import spark.util.StatCounter - -/** - * An ApproximateEvaluator for sums. It estimates the mean and the cont and multiplies them - * together, then uses the formula for the variance of two independent random variables to get - * a variance for the result and compute a confidence interval. 
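The "formula for the variance of two independent random variables" the comment above alludes to is presumably the product form. Spelling it out (my gloss, not from the source): for independent random variables \(X\) and \(Y\),

\[
  \operatorname{Var}(XY) \;=\; \mathbb{E}[X]^{2}\operatorname{Var}(Y) \;+\; \mathbb{E}[Y]^{2}\operatorname{Var}(X) \;+\; \operatorname{Var}(X)\operatorname{Var}(Y).
\]

The code below applies this with \(X\) the extrapolated count (estimated exactly as in CountEvaluator) and \(Y\) the mean from the merged StatCounter, substituting the estimates for the expectations — which is term for term the sumVar expression meanEstimate^2 * countVar + countEstimate^2 * meanVar + meanVar * countVar.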
- */ -class SumEvaluator(totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[StatCounter, BoundedDouble] { - - var outputsMerged = 0 - var counter = new StatCounter - - override def merge(outputId: Int, taskResult: StatCounter) { - outputsMerged += 1 - counter.merge(taskResult) - } - - override def currentResult(): BoundedDouble = { - if (outputsMerged == totalOutputs) { - new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) - } else if (outputsMerged == 0) { - new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) - } else { - val p = outputsMerged.toDouble / totalOutputs - val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count - val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) - val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = { - if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) - } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) - } - } - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) - } - } -} diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala deleted file mode 100644 index 0ecff9ce77ea773c30d9947a342327d2bf88fa29..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/ActiveJob.scala +++ /dev/null @@ -1,18 +0,0 @@ -package spark.scheduler - -import spark.TaskContext - -/** - * Tracks information about an active job in the DAGScheduler. - */ -class ActiveJob( - val runId: Int, - val finalStage: Stage, - val func: (TaskContext, Iterator[_]) => _, - val partitions: Array[Int], - val listener: JobListener) { - - val numPartitions = partitions.length - val finished = Array.fill[Boolean](numPartitions)(false) - var numFinished = 0 -} diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala deleted file mode 100644 index f9d53d3b5d4457a975b696552af87e3ded3f7bc5..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ /dev/null @@ -1,535 +0,0 @@ -package spark.scheduler - -import java.net.URI -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.Future -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map} - -import spark._ -import spark.partial.ApproximateActionListener -import spark.partial.ApproximateEvaluator -import spark.partial.PartialResult -import spark.storage.BlockManagerMaster -import spark.storage.BlockManagerId - -/** - * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for - * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal - * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster - * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). 
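For orientation, a job whose lineage crosses one shuffle boundary is split by this scheduler into a shuffle map stage feeding a result stage. The sketch below uses the SparkContext API of this era as I understand it (a "local" master string and the implicit pair functions pulled in via spark.SparkContext._); treat it as illustrative rather than a verified build against this exact revision.

import spark.SparkContext
import spark.SparkContext._

object TwoStageJobSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "TwoStageJobSketch")
    val pairs  = sc.parallelize(1 to 1000).map(i => (i % 10, 1)) // narrow dependencies only
    val counts = pairs.reduceByKey(_ + _)                        // introduces a ShuffleDependency
    // The count() action triggers runJob(): newStage() builds the result stage,
    // getParentStages() walks the lineage and creates one shuffle map stage for
    // the reduceByKey boundary, and the map stage runs before the result stage.
    println(counts.count())
    sc.stop()
  }
}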
- */ -class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { - taskSched.setListener(this) - - // Called by TaskScheduler to report task completions or failures. - override def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any]) { - eventQueue.put(CompletionEvent(task, reason, result, accumUpdates)) - } - - // Called by TaskScheduler when a host fails. - override def hostLost(host: String) { - eventQueue.put(HostLost(host)) - } - - // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; - // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one - // as more failure events come in - val RESUBMIT_TIMEOUT = 50L - - // The time, in millis, to wake up between polls of the completion queue in order to potentially - // resubmit failed stages - val POLL_TIMEOUT = 10L - - private val lock = new Object // Used for access to the entire DAGScheduler - - private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] - - val nextRunId = new AtomicInteger(0) - - val nextStageId = new AtomicInteger(0) - - val idToStage = new HashMap[Int, Stage] - - val shuffleToMapStage = new HashMap[Int, Stage] - - var cacheLocs = new HashMap[Int, Array[List[String]]] - - val env = SparkEnv.get - val cacheTracker = env.cacheTracker - val mapOutputTracker = env.mapOutputTracker - - val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; - // that's not going to be a realistic assumption in general - - val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done - val running = new HashSet[Stage] // Stages we are running right now - val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures - val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage - var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits - - val activeJobs = new HashSet[ActiveJob] - val resultStageToJob = new HashMap[Stage, ActiveJob] - - // Start a thread to run the DAGScheduler event loop - new Thread("DAGScheduler") { - setDaemon(true) - override def run() { - DAGScheduler.this.run() - } - }.start() - - def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { - cacheLocs(rdd.id) - } - - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() - } - - /** - * Get or create a shuffle map stage for the given shuffle dependency's map side. - * The priority value passed in will be used if the stage doesn't already exist with - * a lower priority (we assume that priorities always increase across jobs for now). - */ - def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = { - shuffleToMapStage.get(shuffleDep.shuffleId) match { - case Some(stage) => stage - case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority) - shuffleToMapStage(shuffleDep.shuffleId) = stage - stage - } - } - - /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be given - * the provided priority. 
- */ - def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of splits is unknown - logInfo("Registering RDD " + rdd.id + ": " + rdd) - cacheTracker.registerRDD(rdd.id, rdd.splits.size) - if (shuffleDep != None) { - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) - } - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority) - idToStage(id) = stage - stage - } - - /** - * Get or create the list of parent stages for a given RDD. The stages will be assigned the - * provided priority if they haven't already been created with a lower priority. - */ - def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { - val parents = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(r: RDD[_]) { - if (!visited(r)) { - visited += r - // Kind of ugly: need to register RDDs with the cache here since - // we can't do it in its constructor because # of splits is unknown - logInfo("Registering parent RDD " + r.id + ": " + r) - cacheTracker.registerRDD(r.id, r.splits.size) - for (dep <- r.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_,_] => - parents += getShuffleMapStage(shufDep, priority) - case _ => - visit(dep.rdd) - } - } - } - } - visit(rdd) - parents.toList - } - - def getMissingParentStages(stage: Stage): List[Stage] = { - val missing = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(rdd: RDD[_]) { - if (!visited(rdd)) { - visited += rdd - val locs = getCacheLocs(rdd) - for (p <- 0 until rdd.splits.size) { - if (locs(p) == Nil) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) - if (!mapStage.isAvailable) { - missing += mapStage - } - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } - } - } - } - } - } - visit(stage.rdd) - missing.toList - } - - def runJob[T, U]( - finalRdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean) - (implicit m: ClassManifest[U]): Array[U] = - { - if (partitions.size == 0) { - return new Array[U](0) - } - val waiter = new JobWaiter(partitions.size) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter)) - waiter.getResult() match { - case JobSucceeded(results: Seq[_]) => - return results.asInstanceOf[Seq[U]].toArray - case JobFailed(exception: Exception) => - throw exception - } - } - - def runApproximateJob[T, U, R]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - evaluator: ApproximateEvaluator[U, R], - timeout: Long - ): PartialResult[R] = - { - val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) - val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val partitions = (0 until rdd.splits.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener)) - return listener.getResult() // Will throw an exception if the job fails - } - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. 
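A stripped-down, runnable version of the queue-polling idiom described above (the event names here are invented, not the real DAGSchedulerEvent types): poll() returning null simply means POLL_TIMEOUT expired with nothing to do, which is the loop's opportunity for periodic work such as resubmitting failed stages.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

object EventLoopSketch {
  sealed trait Event
  case class Submitted(id: Int) extends Event
  case object Shutdown extends Event

  def main(args: Array[String]) {
    val eventQueue = new LinkedBlockingQueue[Event]
    eventQueue.put(Submitted(1))   // any thread may enqueue; only this loop mutates state
    eventQueue.put(Shutdown)

    var done = false
    while (!done) {
      eventQueue.poll(10, TimeUnit.MILLISECONDS) match {
        case Submitted(id) => println("handling job " + id)
        case Shutdown      => done = true
        case null          => () // timed out with no event: do periodic housekeeping here
      }
    }
  }
}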
- */ - def run() = { - SparkEnv.set(env) - - while (true) { - val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) - val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability - if (event != null) { - logDebug("Got event of type " + event.getClass.getName) - } - - event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) => - val runId = nextRunId.getAndIncrement() - val finalStage = newStage(finalRDD, None, runId) - val job = new ActiveJob(runId, finalStage, func, partitions, listener) - updateCacheLocs() - logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions") - logInfo("Final stage: " + finalStage) - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { - // Compute very short actions like first() or take() with no parent stages locally. - runLocally(job) - } else { - activeJobs += job - resultStageToJob(finalStage) = job - submitStage(finalStage) - } - - case HostLost(host) => - handleHostLost(host) - - case completion: CompletionEvent => - handleTaskCompletion(completion) - - case null => - // queue.poll() timed out, ignore it - } - - // Periodically resubmit failed stages if some map output fetches have failed and we have - // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, - // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at - // the same time, so we want to make sure we've identified all the reduce tasks that depend - // on the failed node. - if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - logInfo("Resubmitting failed stages") - updateCacheLocs() - val failed2 = failed.toArray - failed.clear() - for (stage <- failed2.sortBy(_.priority)) { - submitStage(stage) - } - } else { - // TODO: We might want to run this less often, when we are sure that something has become - // runnable that wasn't before. - logDebug("Checking for newly runnable parent stages") - logDebug("running: " + running) - logDebug("waiting: " + waiting) - logDebug("failed: " + failed) - val waiting2 = waiting.toArray - waiting.clear() - for (stage <- waiting2.sortBy(_.priority)) { - submitStage(stage) - } - } - } - } - - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. 
- */ - def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.runId) { - override def run() { - try { - SparkEnv.set(env) - val rdd = job.finalStage.rdd - val split = rdd.splits(job.partitions(0)) - val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) - val result = job.func(taskContext, rdd.iterator(split)) - job.listener.taskSucceeded(0, result) - } catch { - case e: Exception => - job.listener.jobFailed(e) - } - } - }.start() - } - - def submitStage(stage: Stage) { - logDebug("submitStage(" + stage + ")") - if (!waiting(stage) && !running(stage) && !failed(stage)) { - val missing = getMissingParentStages(stage).sortBy(_.id) - logDebug("missing: " + missing) - if (missing == Nil) { - logInfo("Submitting " + stage + ", which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) - } - waiting += stage - } - } - } - - def submitMissingTasks(stage: Stage) { - logDebug("submitMissingTasks(" + stage + ")") - // Get our pending tasks and remember them in our pendingTasks entry - val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) - myPending.clear() - var tasks = ArrayBuffer[Task[_]]() - if (stage.isShuffleMap) { - for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { - val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) - } - } else { - // This is a final stage; figure out its job's missing partitions - val job = resultStageToJob(stage) - for (id <- 0 until job.numPartitions if (!job.finished(id))) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) - } - } - if (tasks.size > 0) { - logInfo("Submitting " + tasks.size + " missing tasks from " + stage) - myPending ++= tasks - logDebug("New pending tasks: " + myPending) - taskSched.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) - } else { - logDebug("Stage " + stage + " is actually done; %b %d %d".format( - stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) - running -= stage - } - } - - /** - * Responds to a task finishing. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. 
- */ - def handleTaskCompletion(event: CompletionEvent) { - val task = event.task - val stage = idToStage(task.stageId) - event.reason match { - case Success => - logInfo("Completed " + task) - if (event.accumUpdates != null) { - Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted - } - pendingTasks(stage) -= task - task match { - case rt: ResultTask[_, _] => - resultStageToJob.get(stage) match { - case Some(job) => - if (!job.finished(rt.outputId)) { - job.finished(rt.outputId) = true - job.numFinished += 1 - job.listener.taskSucceeded(rt.outputId, event.result) - // If the whole job has finished, remove it - if (job.numFinished == job.numPartitions) { - activeJobs -= job - resultStageToJob -= stage - running -= stage - } - } - case None => - logInfo("Ignoring result from " + rt + " because its job has finished") - } - - case smt: ShuffleMapTask => - val stage = idToStage(smt.stageId) - val bmAddress = event.result.asInstanceOf[BlockManagerId] - val host = bmAddress.ip - logInfo("ShuffleMapTask finished with host " + host) - if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos - stage.addOutputLoc(smt.partition, bmAddress) - } - if (running.contains(stage) && pendingTasks(stage).isEmpty) { - logInfo(stage + " finished; looking for newly runnable stages") - running -= stage - logInfo("running: " + running) - logInfo("waiting: " + waiting) - logInfo("failed: " + failed) - if (stage.shuffleDep != None) { - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray) - } - updateCacheLocs() - if (stage.outputLocs.count(_ == Nil) != 0) { - // Some tasks had failed; let's resubmit this stage - // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + stage + " because some of its tasks had failed: " + - stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) - submitStage(stage) - } else { - val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waiting) { - logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage)) - } - for (stage <- waiting if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage - } - waiting --= newlyRunnable - running ++= newlyRunnable - for (stage <- newlyRunnable.sortBy(_.id)) { - submitMissingTasks(stage) - } - } - } - } - - case Resubmitted => - logInfo("Resubmitted " + task + ", so marking it as still running") - pendingTasks(stage) += task - - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => - // Mark the stage that the reducer was in as unrunnable - val failedStage = idToStage(task.stageId) - running -= failedStage - failed += failedStage - // TODO: Cancel running tasks in the stage - logInfo("Marking " + failedStage + " for resubmision due to a fetch failure") - // Mark the map whose fetch failed as broken in the map stage - val mapStage = shuffleToMapStage(shuffleId) - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission") - failed += mapStage - // Remember that a fetch failed now; this is used to resubmit the broken - // stages later, after a small wait (to give other tasks the chance to fail) - lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock - // TODO: mark the host as failed only if there were lots of fetch failures on it - if (bmAddress != 
null) { - handleHostLost(bmAddress.ip) - } - - case _ => - // Non-fetch failure -- probably a bug in the job, so bail out - // TODO: Cancel all tasks that are still running - resultStageToJob.get(stage) match { - case Some(job) => - val error = new SparkException("Task failed: " + task + ", reason: " + event.reason) - job.listener.jobFailed(error) - activeJobs -= job - resultStageToJob -= stage - case None => - logInfo("Ignoring result from " + task + " because its job has finished") - } - } - } - - /** - * Responds to a host being lost. This is called inside the event loop so it assumes that it can - * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. - */ - def handleHostLost(host: String) { - if (!deadHosts.contains(host)) { - logInfo("Host lost: " + host) - deadHosts += host - BlockManagerMaster.notifyADeadHost(host) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnHost(host) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, true) - } - cacheTracker.cacheLost(host) - updateCacheLocs() - } - } - - def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { - // If the partition is cached, return the cache locations - val cached = getCacheLocs(rdd)(partition) - if (cached != Nil) { - return cached - } - // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList - if (rddPrefs != Nil) { - return rddPrefs - } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. - rdd.dependencies.foreach(_ match { - case n: NarrowDependency[_] => - for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) - if (locs != Nil) - return locs; - } - case _ => - }) - return Nil - } - - def stop() { - // TODO: Put a stop event on our queue and break the event loop - taskSched.stop() - } -} diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala deleted file mode 100644 index c10abc92028993d9200676d60139493ee5df5f62..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.scheduler - -import scala.collection.mutable.Map - -import spark._ - -/** - * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue - * architecture where any thread can post an event (e.g. a task finishing or a new job being - * submitted) but there is a single "logic" thread that reads these events and takes decisions. - * This greatly simplifies synchronization. 
- */ -sealed trait DAGSchedulerEvent - -case class JobSubmitted( - finalRDD: RDD[_], - func: (TaskContext, Iterator[_]) => _, - partitions: Array[Int], - allowLocal: Boolean, - listener: JobListener) - extends DAGSchedulerEvent - -case class CompletionEvent( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any]) - extends DAGSchedulerEvent - -case class HostLost(host: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala deleted file mode 100644 index d4dd536a7de553f92d3c8a506df39805bb89d77f..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/JobListener.scala +++ /dev/null @@ -1,11 +0,0 @@ -package spark.scheduler - -/** - * Interface used to listen for job completion or failure events after submitting a job to the - * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole - * job fails (and no further taskSucceeded events will happen). - */ -trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) -} diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala deleted file mode 100644 index 62b458eccbd22822592b236ba2c67ad15c4a2b4b..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/JobResult.scala +++ /dev/null @@ -1,9 +0,0 @@ -package spark.scheduler - -/** - * A result of a job in the DAGScheduler. - */ -sealed trait JobResult - -case class JobSucceeded(results: Seq[_]) extends JobResult -case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala deleted file mode 100644 index be8ec9bd7b07e9d8ac8e986ae9a20b575b9bbd0c..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/JobWaiter.scala +++ /dev/null @@ -1,43 +0,0 @@ -package spark.scheduler - -import scala.collection.mutable.ArrayBuffer - -/** - * An object that waits for a DAGScheduler job to complete. - */ -class JobWaiter(totalTasks: Int) extends JobListener { - private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) - private var finishedTasks = 0 - - private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? 
- private var jobResult: JobResult = null // If the job is finished, this will be its result - - override def taskSucceeded(index: Int, result: Any) = synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") - } - taskResults(index) = result - finishedTasks += 1 - if (finishedTasks == totalTasks) { - jobFinished = true - jobResult = JobSucceeded(taskResults) - this.notifyAll() - } - } - - override def jobFailed(exception: Exception) = synchronized { - if (jobFinished) { - throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") - } - jobFinished = true - jobResult = JobFailed(exception) - this.notifyAll() - } - - def getResult(): JobResult = synchronized { - while (!jobFinished) { - this.wait() - } - return jobResult - } -} diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala deleted file mode 100644 index 79cca0f294593154d1d667debb261db4ad836974..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.scheduler - -import java.io._ -import java.util.HashMap -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.ArrayBuffer - -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import com.ning.compress.lzf.LZFInputStream -import com.ning.compress.lzf.LZFOutputStream - -import spark._ -import spark.storage._ - -object ShuffleMapTask { - val serializedInfoCache = new HashMap[Int, Array[Byte]] - val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])] - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId) - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - return bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = { - synchronized { - val old = deserializedInfoCache.get(stageId) - if (old != null) { - return old - } else { - val loader = currentThread.getContextClassLoader - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] - val tuple = (rdd, dep) - deserializedInfoCache.put(stageId, tuple) - return tuple - } - } - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - deserializedInfoCache.clear() - } - } -} - -class ShuffleMapTask( - stageId: Int, - var rdd: RDD[_], - var dep: ShuffleDependency[_,_,_], - var partition: Int, - @transient var locs: Seq[String]) - extends Task[BlockManagerId](stageId) - with Externalizable - with Logging { - - def this() = this(0, null, null, 0, null) - - var split = if (rdd == null) { - null - } else { - rdd.splits(partition) - } - - override def writeExternal(out: ObjectOutput) { - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - 
out.writeInt(partition) - out.writeObject(split) - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partition = in.readInt() - split = in.readObject().asInstanceOf[Split] - } - - override def run(attemptId: Int): BlockManagerId = { - val numOutputSplits = dep.partitioner.numPartitions - val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]] - val partitioner = dep.partitioner.asInstanceOf[Partitioner] - val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any]) - for (elem <- rdd.iterator(split)) { - val (k, v) = elem.asInstanceOf[(Any, Any)] - var bucketId = partitioner.getPartition(k) - val bucket = buckets(bucketId) - var existing = bucket.get(k) - if (existing == null) { - bucket.put(k, aggregator.createCombiner(v)) - } else { - bucket.put(k, aggregator.mergeValue(existing, v)) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val blockManager = SparkEnv.get.blockManager - for (i <- 0 until numOutputSplits) { - val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i - val arr = new ArrayBuffer[Any] - val iter = buckets(i).entrySet().iterator() - while (iter.hasNext()) { - val entry = iter.next() - arr += ((entry.getKey(), entry.getValue())) - } - // TODO: This should probably be DISK_ONLY - blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false) - } - return SparkEnv.get.blockManager.blockManagerId - } - - override def preferredLocations: Seq[String] = locs - - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) -} diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala deleted file mode 100644 index cd660c9085a751193bcc99cc93c3499276b7b72a..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ /dev/null @@ -1,86 +0,0 @@ -package spark.scheduler - -import java.net.URI - -import spark._ -import spark.storage.BlockManagerId - -/** - * A stage is a set of independent tasks all computing the same function that need to run as part - * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run - * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the - * DAGScheduler runs these stages in topological order. - * - * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for - * another stage, or a result stage, in which case its tasks directly compute the action that - * initiated a job (e.g. count(), save(), etc). For shuffle map stages, we also track the nodes - * that each output partition is on. - * - * Each Stage also has a priority, which is (by default) based on the job it was submitted in. - * This allows Stages from earlier jobs to be computed first or recovered faster on failure. 
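A toy model (invented names, plain Strings instead of BlockManagerIds) of the output-location bookkeeping described above: a shuffle map stage becomes available once every partition has at least one live output location, and losing the only host holding a partition makes it unavailable again, which is what removeOutputsOnHost() in the class below detects.

object StageAvailabilitySketch {
  def main(args: Array[String]) {
    val numPartitions = 3
    val outputLocs = Array.fill[List[String]](numPartitions)(Nil)
    def numAvailableOutputs = outputLocs.count(_.nonEmpty)
    def isAvailable = numAvailableOutputs == numPartitions

    // Each completed ShuffleMapTask registers where its output lives.
    outputLocs(0) = List("hostA")
    outputLocs(1) = List("hostA", "hostB")   // e.g. a speculative or resubmitted copy
    outputLocs(2) = List("hostB")
    println("available after all tasks report in: " + isAvailable)   // true

    // Losing hostB drops the only copy of partition 2, so the stage must be rerun.
    for (p <- 0 until numPartitions)
      outputLocs(p) = outputLocs(p).filterNot(_ == "hostB")
    println("available after hostB is lost: " + isAvailable)         // false
  }
}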
- */ -class Stage( - val id: Int, - val rdd: RDD[_], - val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage - val parents: List[Stage], - val priority: Int) - extends Logging { - - val isShuffleMap = shuffleDep != None - val numPartitions = rdd.splits.size - val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil) - var numAvailableOutputs = 0 - - private var nextAttemptId = 0 - - def isAvailable: Boolean = { - if (/*parents.size == 0 &&*/ !isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - outputLocs(partition) = bmAddress :: prevList - if (prevList == Nil) - numAvailableOutputs += 1 - } - - def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_ == bmAddress) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - def removeOutputsOnHost(host: String) { - var becameUnavailable = false - for (partition <- 0 until numPartitions) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_.ip == host) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - becameUnavailable = true - numAvailableOutputs -= 1 - } - } - if (becameUnavailable) { - logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable)) - } - } - - def newAttemptId(): Int = { - val id = nextAttemptId - nextAttemptId += 1 - return id - } - - override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" - - override def hashCode(): Int = id -} diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala deleted file mode 100644 index 42325956baa51cf1681799ad9a2b82531a7ef4ce..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/Task.scala +++ /dev/null @@ -1,11 +0,0 @@ -package spark.scheduler - -/** - * A task to execute on a worker node. - */ -abstract class Task[T](val stageId: Int) extends Serializable { - def run(attemptId: Int): T - def preferredLocations: Seq[String] = Nil - - var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler. -} diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala deleted file mode 100644 index 868ddb237c0a23ca8f55d443df8a2473f1604ddd..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ /dev/null @@ -1,34 +0,0 @@ -package spark.scheduler - -import java.io._ - -import scala.collection.mutable.Map - -// Task result. Also contains updates to accumulator variables. 
-// TODO: Use of distributed cache to return result is a hack to get around -// what seems to be a bug with messages over 60KB in libprocess; fix it -class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable { - def this() = this(null.asInstanceOf[T], null) - - override def writeExternal(out: ObjectOutput) { - out.writeObject(value) - out.writeInt(accumUpdates.size) - for ((key, value) <- accumUpdates) { - out.writeLong(key) - out.writeObject(value) - } - } - - override def readExternal(in: ObjectInput) { - value = in.readObject().asInstanceOf[T] - val numUpdates = in.readInt - if (numUpdates == 0) { - accumUpdates = null - } else { - accumUpdates = Map() - for (i <- 0 until numUpdates) { - accumUpdates(in.readLong()) = in.readObject() - } - } - } -} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala deleted file mode 100644 index cb7c375d97e09e07c022fc3dcca238971efbf425..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/TaskScheduler.scala +++ /dev/null @@ -1,27 +0,0 @@ -package spark.scheduler - -/** - * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler. - * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, - * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. - */ -trait TaskScheduler { - def start(): Unit - - // Wait for registration with Mesos. - def waitForRegister(): Unit - - // Disconnect from the cluster. - def stop(): Unit - - // Submit a sequence of tasks to run. - def submitTasks(taskSet: TaskSet): Unit - - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit - - // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. - def defaultParallelism(): Int -} diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index a647eec9e477831f5c77b84f05344efaaa7ec2d5..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,16 +0,0 @@ -package spark.scheduler - -import scala.collection.mutable.Map - -import spark.TaskEndReason - -/** - * Interface for getting events back from the TaskScheduler. - */ -trait TaskSchedulerListener { - // A task has finished or failed. - def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit - - // A node was lost from the cluster. - def hostLost(host: String): Unit -} diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala deleted file mode 100644 index 6f29dd2e9d6dd0688c3a9ac4a38f3fae4fcddb4e..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ /dev/null @@ -1,9 +0,0 @@ -package spark.scheduler - -/** - * A set of tasks submitted together to the low-level TaskScheduler, usually representing - * missing partitions of a particular stage. - */ -class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { - val id: String = stageId + "." 
+ attempt -} diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala deleted file mode 100644 index 8182901ce3abb6d80b5f8bbcf1008098fd44b304..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala +++ /dev/null @@ -1,364 +0,0 @@ -package spark.scheduler.mesos - -import java.io.{File, FileInputStream, FileOutputStream} -import java.util.{ArrayList => JArrayList} -import java.util.{List => JList} -import java.util.{HashMap => JHashMap} -import java.util.concurrent._ - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.Map -import scala.collection.mutable.PriorityQueue -import scala.collection.JavaConversions._ -import scala.math.Ordering - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.actor.Channel -import akka.serialization.RemoteActorSerialization._ - -import com.google.protobuf.ByteString - -import org.apache.mesos.{Scheduler => MScheduler} -import org.apache.mesos._ -import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} - -import spark._ -import spark.scheduler._ - -sealed trait CoarseMesosSchedulerMessage -case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage -case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage -case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage -case class ReviveOffers() extends CoarseMesosSchedulerMessage - -case class FakeOffer(slaveId: String, host: String, cores: Int) - -/** - * Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside - * them using Akka actors for messaging. Clients should first call start(), then submit task sets - * through the runTasks method. - * - * TODO: This is a pretty big hack for now. - */ -class CoarseMesosScheduler( - sc: SparkContext, - master: String, - frameworkName: String) - extends MesosScheduler(sc, master, frameworkName) { - - val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt - - class MasterActor extends Actor { - val slaveActor = new HashMap[String, ActorRef] - val slaveHost = new HashMap[String, String] - val freeCores = new HashMap[String, Int] - - def receive = { - case RegisterSlave(slaveId, host, port) => - slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port) - logInfo("Slave actor: " + slaveActor(slaveId)) - slaveHost(slaveId) = host - freeCores(slaveId) = CORES_PER_SLAVE - makeFakeOffers() - - case StatusUpdate(slaveId, status) => - fakeStatusUpdate(status) - if (isFinished(status.getState)) { - freeCores(slaveId) += 1 - makeFakeOffers(slaveId) - } - - case LaunchTask(slaveId, task) => - freeCores(slaveId) -= 1 - slaveActor(slaveId) ! 
LaunchTask(slaveId, task) - - case ReviveOffers() => - logInfo("Reviving offers") - makeFakeOffers() - } - - // Make fake resource offers for all slaves - def makeFakeOffers() { - fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))}) - } - - // Make fake resource offers for all slaves - def makeFakeOffers(slaveId: String) { - fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId)))) - } - } - - val masterActor: ActorRef = actorOf(new MasterActor) - remote.register("MasterActor", masterActor) - masterActor.start() - - val taskIdsOnSlave = new HashMap[String, HashSet[String]] - - /** - * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets - * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that - * tasks are balanced across the cluster. - */ - override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - synchronized { - val tasks = offers.map(o => new JArrayList[MTaskInfo]) - for (i <- 0 until offers.size) { - val o = offers.get(i) - val slaveId = o.getSlaveId.getValue - if (!slaveIdToHost.contains(slaveId)) { - slaveIdToHost(slaveId) = o.getHostname - hostsAlive += o.getHostname - taskIdsOnSlave(slaveId) = new HashSet[String] - // Launch an infinite task on the node that will talk to the MasterActor to get fake tasks - val cpuRes = Resource.newBuilder() - .setName("cpus") - .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(1).build()) - .build() - val task = new WorkerTask(slaveId, o.getHostname) - val serializedTask = Utils.serialize(task) - tasks(i).add(MTaskInfo.newBuilder() - .setTaskId(newTaskId()) - .setSlaveId(o.getSlaveId) - .setExecutor(executorInfo) - .setName("worker task") - .addResources(cpuRes) - .setData(ByteString.copyFrom(serializedTask)) - .build()) - } - } - val filters = Filters.newBuilder().setRefuseSeconds(10).build() - for (i <- 0 until offers.size) { - d.launchTasks(offers(i).getId(), tasks(i), filters) - } - } - } - - override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val tid = status.getTaskId.getValue - var taskSetToUpdate: Option[TaskSetManager] = None - var taskFailed = false - synchronized { - try { - taskIdToTaskSetId.get(tid) match { - case Some(taskSetId) => - if (activeTaskSets.contains(taskSetId)) { - //activeTaskSets(taskSetId).statusUpdate(status) - taskSetToUpdate = Some(activeTaskSets(taskSetId)) - } - if (isFinished(status.getState)) { - taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } - val slaveId = taskIdToSlaveId(tid) - taskIdToSlaveId -= tid - taskIdsOnSlave(slaveId) -= tid - } - if (status.getState == TaskState.TASK_FAILED) { - taskFailed = true - } - case None => - logInfo("Ignoring update from TID " + tid + " because its task set is gone") - } - } catch { - case e: Exception => logError("Exception in statusUpdate", e) - } - } - // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock - if (taskSetToUpdate != None) { - taskSetToUpdate.get.statusUpdate(status) - } - if (taskFailed) { - // Revive offers if a task had failed for some reason other than host lost - reviveOffers() - } - } - - override def slaveLost(d: SchedulerDriver, s: SlaveID) { - logInfo("Slave lost: " + s.getValue) - var failedHost: Option[String] = None - var lostTids: Option[HashSet[String]] = None - synchronized { - val slaveId = s.getValue - val host = 
slaveIdToHost(slaveId) - if (hostsAlive.contains(host)) { - slaveIdsWithExecutors -= slaveId - hostsAlive -= host - failedHost = Some(host) - lostTids = Some(taskIdsOnSlave(slaveId)) - logInfo("failedHost: " + host) - logInfo("lostTids: " + lostTids) - taskIdsOnSlave -= slaveId - activeTaskSetsQueue.foreach(_.hostLost(host)) - } - } - if (failedHost != None) { - // Report all the tasks on the failed host as lost, without holding a lock on this - for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) { - // TODO: Maybe call our statusUpdate() instead to clean our internal data structures - activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(tid).build()) - .setState(TaskState.TASK_LOST) - .build()) - } - // Also report the loss to the DAGScheduler - listener.hostLost(failedHost.get) - reviveOffers(); - } - } - - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} - - // Check for speculatable tasks in all our active jobs. - override def checkSpeculatableTasks() { - var shouldRevive = false - synchronized { - for (ts <- activeTaskSetsQueue) { - shouldRevive |= ts.checkSpeculatableTasks() - } - } - if (shouldRevive) { - reviveOffers() - } - } - - - val lock2 = new Object - var firstWait = true - - override def waitForRegister() { - lock2.synchronized { - if (firstWait) { - super.waitForRegister() - Thread.sleep(5000) - firstWait = false - } - } - } - - def fakeStatusUpdate(status: TaskStatus) { - statusUpdate(driver, status) - } - - def fakeResourceOffers(offers: Seq[FakeOffer]) { - logDebug("fakeResourceOffers: " + offers) - val availableCpus = offers.map(_.cores.toDouble).toArray - var launchedTask = false - for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { - do { - launchedTask = false - for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) { - manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match { - case Some(task) => - val tid = task.getTaskId.getValue - val sid = offers(i).slaveId - taskIdToTaskSetId(tid) = manager.taskSet.id - taskSetTaskIds(manager.taskSet.id) += tid - taskIdToSlaveId(tid) = sid - taskIdsOnSlave(sid) += tid - slaveIdsWithExecutors += sid - availableCpus(i) -= getResource(task.getResourcesList(), "cpus") - launchedTask = true - masterActor ! LaunchTask(sid, task) - - case None => {} - } - } - } while (launchedTask) - } - } - - override def reviveOffers() { - masterActor ! 
ReviveOffers() - } -} - -class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) { - generation = 0 - - def run(id: Int): Unit = { - val actor = actorOf(new WorkerActor(slaveId, host)) - if (!remote.isRunning) { - remote.start(Utils.localIpAddress, 7078) - } - remote.register("WorkerActor", actor) - actor.start() - while (true) { - Thread.sleep(10000) - } - } -} - -class WorkerActor(slaveId: String, host: String) extends Actor with Logging { - val env = SparkEnv.get - val classLoader = currentThread.getContextClassLoader - val threadPool = new ThreadPoolExecutor( - 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) - - val masterIp: String = System.getProperty("spark.master.host", "localhost") - val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt - val masterActor = remote.actorFor("MasterActor", masterIp, masterPort) - - class TaskRunner(desc: MTaskInfo) - extends Runnable { - override def run() = { - val tid = desc.getTaskId.getValue - logInfo("Running task ID " + tid) - try { - SparkEnv.set(env) - Thread.currentThread.setContextClassLoader(classLoader) - Accumulators.clear - val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader) - env.mapOutputTracker.updateGeneration(task.generation) - val value = task.run(tid.toInt) - val accumUpdates = Accumulators.values - val result = new TaskResult(value, accumUpdates) - masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder() - .setTaskId(desc.getTaskId) - .setState(TaskState.TASK_FINISHED) - .setData(ByteString.copyFrom(Utils.serialize(result))) - .build()) - logInfo("Finished task ID " + tid) - } catch { - case ffe: FetchFailedException => { - val reason = ffe.toTaskEndReason - masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder() - .setTaskId(desc.getTaskId) - .setState(TaskState.TASK_FAILED) - .setData(ByteString.copyFrom(Utils.serialize(reason))) - .build()) - } - case t: Throwable => { - val reason = ExceptionFailure(t) - masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder() - .setTaskId(desc.getTaskId) - .setState(TaskState.TASK_FAILED) - .setData(ByteString.copyFrom(Utils.serialize(reason))) - .build()) - - // TODO: Should we exit the whole executor here? On the one hand, the failed task may - // have left some weird state around depending on when the exception was thrown, but on - // the other hand, maybe we could detect that when future tasks fail and exit then. - logError("Exception in task ID " + tid, t) - //System.exit(1) - } - } - } - } - - override def preStart { - val ref = toRemoteActorRefProtocol(self).toByteArray - logInfo("Registering with master") - masterActor ! RegisterSlave(slaveId, host, remote.address.getPort) - } - - override def receive = { - case LaunchTask(slaveId, task) => - threadPool.execute(new TaskRunner(task)) - } -} diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala deleted file mode 100644 index af2f80ea6671756f768c66be2f4ae2142c9f23d4..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala +++ /dev/null @@ -1,32 +0,0 @@ -package spark.scheduler.mesos - -/** - * Information about a running task attempt. 
- */ -class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) { - var finishTime: Long = 0 - var failed = false - - def markSuccessful(time: Long = System.currentTimeMillis) { - finishTime = time - } - - def markFailed(time: Long = System.currentTimeMillis) { - finishTime = time - failed = true - } - - def finished: Boolean = finishTime != 0 - - def successful: Boolean = finished && !failed - - def duration: Long = { - if (!finished) { - throw new UnsupportedOperationException("duration() called on unfinished tasks") - } else { - finishTime - launchTime - } - } - - def timeRunning(currentTime: Long): Long = currentTime - launchTime -} diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala deleted file mode 100644 index 9e4816f7ce1418c6bed93c82b74284546a102cb6..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ /dev/null @@ -1,588 +0,0 @@ -package spark.storage - -import java.io._ -import java.nio._ -import java.nio.channels.FileChannel.MapMode -import java.util.{HashMap => JHashMap} -import java.util.LinkedHashMap -import java.util.UUID -import java.util.Collections - -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.Future -import scala.actors.Futures.future -import scala.actors.remote._ -import scala.actors.remote.RemoteActor._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ - -import it.unimi.dsi.fastutil.io._ - -import spark.CacheTracker -import spark.Logging -import spark.Serializer -import spark.SizeEstimator -import spark.SparkEnv -import spark.SparkException -import spark.Utils -import spark.util.ByteBufferInputStream -import spark.network._ - -class BlockManagerId(var ip: String, var port: Int) extends Externalizable { - def this() = this(null, 0) - - override def writeExternal(out: ObjectOutput) { - out.writeUTF(ip) - out.writeInt(port) - } - - override def readExternal(in: ObjectInput) { - ip = in.readUTF() - port = in.readInt() - } - - override def toString = "BlockManagerId(" + ip + ", " + port + ")" - - override def hashCode = ip.hashCode * 41 + port - - override def equals(that: Any) = that match { - case id: BlockManagerId => port == id.port && ip == id.ip - case _ => false - } -} - - -case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) - - -class BlockLocker(numLockers: Int) { - private val hashLocker = Array.fill(numLockers)(new Object()) - - def getLock(blockId: String): Object = { - return hashLocker(Math.abs(blockId.hashCode % numLockers)) - } -} - - - -class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging { - - case class BlockInfo(level: StorageLevel, tellMaster: Boolean) - - private val NUM_LOCKS = 337 - private val locker = new BlockLocker(NUM_LOCKS) - - private val blockInfo = Collections.synchronizedMap(new JHashMap[String, BlockInfo]) - private val memoryStore: BlockStore = new MemoryStore(this, maxMemory) - private val diskStore: BlockStore = new DiskStore(this, - System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) - - val connectionManager = new ConnectionManager(0) - - val connectionManagerId = connectionManager.id - val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) - - // TODO: This will be removed after cacheTracker is removed from the code base. 
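// A minimal sketch of the Externalizable pattern that TaskResult, BlockManagerId and
// HeartBeat in this change rely on: a no-arg constructor for the deserializer plus
// hand-written writeExternal/readExternal. ExampleId and roundTrip are illustrative
// names assumed for this sketch, not part of the deleted sources.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, Externalizable}
import java.io.{ObjectInput, ObjectInputStream, ObjectOutput, ObjectOutputStream}

class ExampleId(var ip: String, var port: Int) extends Externalizable {
  def this() = this(null, 0)  // required so the deserializer can instantiate the class

  override def writeExternal(out: ObjectOutput) {
    out.writeUTF(ip)
    out.writeInt(port)
  }

  override def readExternal(in: ObjectInput) {
    ip = in.readUTF()
    port = in.readInt()
  }
}

object ExampleId {
  // Round-trips an id through Java serialization, e.g.
  // roundTrip(new ExampleId("127.0.0.1", 7077)) yields a copy with the same ip and port.
  def roundTrip(id: ExampleId): ExampleId = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(id)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    in.readObject().asInstanceOf[ExampleId]
  }
}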
- var cacheTracker: CacheTracker = null - - initLogging() - - initialize() - - /** - * Construct a BlockManager with a memory limit set based on system properties. - */ - def this(serializer: Serializer) = - this(BlockManager.getMaxMemoryFromSystemProperties(), serializer) - - /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. - */ - private def initialize() { - BlockManagerMaster.mustRegisterBlockManager( - RegisterBlockManager(blockManagerId, maxMemory, maxMemory)) - BlockManagerWorker.startBlockManagerWorker(this) - } - - /** - * Get storage level of local block. If no info exists for the block, then returns null. - */ - def getLevel(blockId: String): StorageLevel = { - val info = blockInfo.get(blockId) - if (info != null) info.level else null - } - - /** - * Change storage level for a local block and tell master is necesary. - * If new level is invalid, then block info (if it exists) will be silently removed. - */ - def setLevel(blockId: String, level: StorageLevel, tellMaster: Boolean = true) { - if (level == null) { - throw new IllegalArgumentException("Storage level is null") - } - - // If there was earlier info about the block, then use earlier tellMaster - val oldInfo = blockInfo.get(blockId) - val newTellMaster = if (oldInfo != null) oldInfo.tellMaster else tellMaster - if (oldInfo != null && oldInfo.tellMaster != tellMaster) { - logWarning("Ignoring tellMaster setting as it is different from earlier setting") - } - - // If level is valid, store the block info, else remove the block info - if (level.isValid) { - blockInfo.put(blockId, new BlockInfo(level, newTellMaster)) - logDebug("Info for block " + blockId + " updated with new level as " + level) - } else { - blockInfo.remove(blockId) - logDebug("Info for block " + blockId + " removed as new level is null or invalid") - } - - // Tell master if necessary - if (newTellMaster) { - logDebug("Told master about block " + blockId) - notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0)) - } else { - logDebug("Did not tell master about block " + blockId) - } - } - - /** - * Get locations of the block. - */ - def getLocations(blockId: String): Seq[String] = { - val startTimeMs = System.currentTimeMillis - var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId)) - val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq - logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) - return locations - } - - /** - * Get locations of an array of blocks. - */ - def getLocations(blockIds: Array[String]): Array[Seq[String]] = { - val startTimeMs = System.currentTimeMillis - val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds( - GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray - logDebug("Get multiple block location in " + Utils.getUsedTimeMs(startTimeMs)) - return locations - } - - /** - * Get block from local block manager. 
- */ - def getLocal(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting local block " + blockId) - locker.getLock(blockId).synchronized { - - // Check storage level of block - val level = getLevel(blockId) - if (level != null) { - logDebug("Level for block " + blockId + " is " + level + " on local machine") - - // Look for the block in memory - if (level.useMemory) { - logDebug("Getting block " + blockId + " from memory") - memoryStore.getValues(blockId) match { - case Some(iterator) => { - logDebug("Block " + blockId + " found in memory") - return Some(iterator) - } - case None => { - logDebug("Block " + blockId + " not found in memory") - } - } - } else { - logDebug("Not getting block " + blockId + " from memory") - } - - // Look for block in disk - if (level.useDisk) { - logDebug("Getting block " + blockId + " from disk") - diskStore.getValues(blockId) match { - case Some(iterator) => { - logDebug("Block " + blockId + " found in disk") - return Some(iterator) - } - case None => { - throw new Exception("Block " + blockId + " not found in disk") - return None - } - } - } else { - logDebug("Not getting block " + blockId + " from disk") - } - - } else { - logDebug("Level for block " + blockId + " not found") - } - } - return None - } - - /** - * Get block from remote block managers. - */ - def getRemote(blockId: String): Option[Iterator[Any]] = { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - logDebug("Getting remote block " + blockId) - // Get locations of block - val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId)) - - // Get block from remote locations - for (loc <- locations) { - logDebug("Getting remote block " + blockId + " from " + loc) - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port)) - if (data != null) { - logDebug("Data is not null: " + data) - return Some(dataDeserialize(data)) - } - logDebug("Data is null") - } - logDebug("Data not found") - return None - } - - /** - * Get a block from the block manager (either local or remote). - */ - def get(blockId: String): Option[Iterator[Any]] = { - getLocal(blockId).orElse(getRemote(blockId)) - } - - /** - * Get many blocks from local and remote block manager using their BlockManagerIds. 
- */ - def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = { - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks") - var startTime = System.currentTimeMillis - val blocks = new HashMap[String,Option[Iterator[Any]]]() - val localBlockIds = new ArrayBuffer[String]() - val remoteBlockIds = new ArrayBuffer[String]() - val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() - - // Split local and remote blocks - for ((address, blockIds) <- blocksByAddress) { - if (address == blockManagerId) { - localBlockIds ++= blockIds - } else { - remoteBlockIds ++= blockIds - remoteBlockIdsPerLocation(address) = blockIds - } - } - - // Start getting remote blocks - val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) => - val cmId = ConnectionManagerId(bmId.ip, bmId.port) - val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId))) - val blockMessageArray = new BlockMessageArray(blockMessages) - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - (cmId, future) - } - logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + - Utils.getUsedTimeMs(startTime) + " ms") - - // Get the local blocks while remote blocks are being fetched - startTime = System.currentTimeMillis - localBlockIds.foreach(id => { - get(id) match { - case Some(block) => { - blocks.update(id, Some(block)) - logDebug("Got local block " + id) - } - case None => { - throw new BlockException(id, "Could not get block " + id + " from local machine") - } - } - }) - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - - // wait for and gather all the remote blocks - for ((cmId, future) <- remoteBlockFutures) { - var count = 0 - val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first - future() match { - case Some(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - blockMessageArray.foreach(blockMessage => { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new BlockException(oneBlockId, "Unexpected message received from " + cmId) - } - val buffer = blockMessage.getData() - val blockId = blockMessage.getId() - val block = dataDeserialize(buffer) - blocks.update(blockId, Some(block)) - logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime)) - count += 1 - }) - } - case None => { - throw new BlockException(oneBlockId, "Could not get blocks from " + cmId) - } - } - logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + - Utils.getUsedTimeMs(startTime) + " ms") - } - - logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - return blocks - } - - /** - * Put a new block of values to the block manager. 
- */ - def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (values == null) { - throw new IllegalArgumentException("Values is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - val startTimeMs = System.currentTimeMillis - var bytes: ByteBuffer = null - - locker.getLock(blockId).synchronized { - logDebug("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - - // Check and warn if block with same id already exists - if (getLevel(blockId) != null) { - logWarning("Block " + blockId + " already exists in local machine") - return - } - - if (level.useMemory && level.useDisk) { - // If saving to both memory and disk, then serialize only once - memoryStore.putValues(blockId, values, level) match { - case Left(newValues) => - diskStore.putValues(blockId, newValues, level) match { - case Right(newBytes) => bytes = newBytes - case _ => throw new Exception("Unexpected return value") - } - case Right(newBytes) => - bytes = newBytes - diskStore.putBytes(blockId, newBytes, level) - } - } else if (level.useMemory) { - // If only save to memory - memoryStore.putValues(blockId, values, level) match { - case Right(newBytes) => bytes = newBytes - case _ => - } - } else { - // If only save to disk - diskStore.putValues(blockId, values, level) match { - case Right(newBytes) => bytes = newBytes - case _ => throw new Exception("Unexpected return value") - } - } - - // Store the storage level - setLevel(blockId, level, tellMaster) - } - logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs)) - - // Replicate block if required - if (level.replication > 1) { - if (bytes == null) { - bytes = dataSerialize(values) // serialize the block if not already done - } - replicate(blockId, bytes, level) - } - - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyTheCacheTracker(blockId) - } - logDebug("Put block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)) - } - - - /** - * Put a new block of serialized bytes to the block manager. - */ - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) { - if (blockId == null) { - throw new IllegalArgumentException("Block Id is null") - } - if (bytes == null) { - throw new IllegalArgumentException("Bytes is null") - } - if (level == null || !level.isValid) { - throw new IllegalArgumentException("Storage level is null or invalid") - } - - val startTimeMs = System.currentTimeMillis - - // Initiate the replication before storing it locally. 
This is faster as - // data is already serialized and ready for sending - val replicationFuture = if (level.replication > 1) { - future { - replicate(blockId, bytes, level) - } - } else { - null - } - - locker.getLock(blockId).synchronized { - logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs) - + " to get into synchronized block") - if (getLevel(blockId) != null) { - logWarning("Block " + blockId + " already exists") - return - } - - if (level.useMemory) { - memoryStore.putBytes(blockId, bytes, level) - } - if (level.useDisk) { - diskStore.putBytes(blockId, bytes, level) - } - - // Store the storage level - setLevel(blockId, level, tellMaster) - } - - // TODO: This code will be removed when CacheTracker is gone. - if (blockId.startsWith("rdd")) { - notifyTheCacheTracker(blockId) - } - - // If replication had started, then wait for it to finish - if (level.replication > 1) { - if (replicationFuture == null) { - throw new Exception("Unexpected") - } - replicationFuture() - } - - val finishTime = System.currentTimeMillis - if (level.replication > 1) { - logDebug("PutBytes for block " + blockId + " with replication took " + - Utils.getUsedTimeMs(startTimeMs)) - } else { - logDebug("PutBytes for block " + blockId + " without replication took " + - Utils.getUsedTimeMs(startTimeMs)) - } - } - - /** - * Replicate block to another node. - */ - - private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { - val tLevel: StorageLevel = - new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers( - GetPeers(blockManagerId, level.replication - 1)) - for (peer: BlockManagerId <- peers) { - val start = System.nanoTime - logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " - + data.array().length + " Bytes. To node: " + peer) - if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel), - new ConnectionManagerId(peer.ip, peer.port))) { - logError("Failed to call syncPutBlock to " + peer) - } - logDebug("Replicated BlockId " + blockId + " once used " + - (System.nanoTime - start) / 1e6 + " s; The size of the data is " + - data.array().length + " bytes.") - } - } - - // TODO: This code will be removed when CacheTracker is gone. - private def notifyTheCacheTracker(key: String) { - val rddInfo = key.split(":") - val rddId: Int = rddInfo(1).toInt - val splitIndex: Int = rddInfo(2).toInt - val host = System.getProperty("spark.hostname", Utils.localHostName) - cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host)) - } - - /** - * Read a block consisting of a single object. - */ - def getSingle(blockId: String): Option[Any] = { - get(blockId).map(_.next) - } - - /** - * Write a block consisting of a single object. 
- */ - def putSingle(blockId: String, value: Any, level: StorageLevel, tellMaster: Boolean = true) { - put(blockId, Iterator(value), level, tellMaster) - } - - /** - * Drop block from memory (called when memory store has reached it limit) - */ - def dropFromMemory(blockId: String) { - locker.getLock(blockId).synchronized { - val level = getLevel(blockId) - if (level == null) { - logWarning("Block " + blockId + " cannot be removed from memory as it does not exist") - return - } - if (!level.useMemory) { - logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory") - return - } - memoryStore.remove(blockId) - val newLevel = new StorageLevel(level.useDisk, false, level.deserialized, level.replication) - setLevel(blockId, newLevel) - } - } - - def dataSerialize(values: Iterator[Any]): ByteBuffer = { - /*serializer.newInstance().serializeMany(values)*/ - val byteStream = new FastByteArrayOutputStream(4096) - serializer.newInstance().serializeStream(byteStream).writeAll(values).close() - byteStream.trim() - ByteBuffer.wrap(byteStream.array) - } - - def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = { - /*serializer.newInstance().deserializeMany(bytes)*/ - val ser = serializer.newInstance() - bytes.rewind() - return ser.deserializeStream(new ByteBufferInputStream(bytes)).toIterator - } - - private def notifyMaster(heartBeat: HeartBeat) { - BlockManagerMaster.mustHeartBeat(heartBeat) - } - - def stop() { - connectionManager.stop() - blockInfo.clear() - memoryStore.clear() - diskStore.clear() - logInfo("BlockManager stopped") - } -} - - -object BlockManager extends Logging { - def getMaxMemoryFromSystemProperties(): Long = { - val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble - val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong - logInfo("Maximum memory to use: " + bytes) - bytes - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala deleted file mode 100644 index d8400a1f65bde55736dfdae9f8a19ab624615e85..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ /dev/null @@ -1,517 +0,0 @@ -package spark.storage - -import java.io._ -import java.util.{HashMap => JHashMap} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.util.Random - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.util.duration._ - -import spark.Logging -import spark.Utils - -sealed trait ToBlockManagerMaster - -case class RegisterBlockManager( - blockManagerId: BlockManagerId, - maxMemSize: Long, - maxDiskSize: Long) - extends ToBlockManagerMaster - -class HeartBeat( - var blockManagerId: BlockManagerId, - var blockId: String, - var storageLevel: StorageLevel, - var deserializedSize: Long, - var size: Long) - extends ToBlockManagerMaster - with Externalizable { - - def this() = this(null, null, null, 0, 0) // For deserialization only - - override def writeExternal(out: ObjectOutput) { - blockManagerId.writeExternal(out) - out.writeUTF(blockId) - storageLevel.writeExternal(out) - out.writeInt(deserializedSize.toInt) - out.writeInt(size.toInt) - } - - override def readExternal(in: ObjectInput) { - blockManagerId = new BlockManagerId() - blockManagerId.readExternal(in) - blockId = in.readUTF() - storageLevel = new StorageLevel() - storageLevel.readExternal(in) - deserializedSize = 
in.readInt() - size = in.readInt() - } -} - -object HeartBeat { - def apply(blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - deserializedSize: Long, - size: Long): HeartBeat = { - new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) - } - - - // For pattern-matching - def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size)) - } -} - -case class GetLocations( - blockId: String) - extends ToBlockManagerMaster - -case class GetLocationsMultipleBlockIds( - blockIds: Array[String]) - extends ToBlockManagerMaster - -case class GetPeers( - blockManagerId: BlockManagerId, - size: Int) - extends ToBlockManagerMaster - -case class RemoveHost( - host: String) - extends ToBlockManagerMaster - - -class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { - - class BlockManagerInfo( - timeMs: Long, - maxMem: Long, - maxDisk: Long) { - private var lastSeenMs = timeMs - private var remainedMem = maxMem - private var remainedDisk = maxDisk - private val blocks = new JHashMap[String, StorageLevel] - - def updateLastSeenMs() { - lastSeenMs = System.currentTimeMillis() / 1000 - } - - def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) = - synchronized { - updateLastSeenMs() - - if (blocks.containsKey(blockId)) { - val oriLevel: StorageLevel = blocks.get(blockId) - - if (oriLevel.deserialized) { - remainedMem += deserializedSize - } - if (oriLevel.useMemory) { - remainedMem += size - } - if (oriLevel.useDisk) { - remainedDisk += size - } - } - - if (storageLevel.isValid) { - blocks.put(blockId, storageLevel) - if (storageLevel.deserialized) { - remainedMem -= deserializedSize - } - if (storageLevel.useMemory) { - remainedMem -= size - } - if (storageLevel.useDisk) { - remainedDisk -= size - } - } else { - blocks.remove(blockId) - } - } - - def getLastSeenMs(): Long = { - return lastSeenMs - } - - def getRemainedMem(): Long = { - return remainedMem - } - - def getRemainedDisk(): Long = { - return remainedDisk - } - - override def toString(): String = { - return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk - } - - def clear() { - blocks.clear() - } - } - - private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo] - private val blockInfo = new JHashMap[String, Pair[Int, HashSet[BlockManagerId]]] - - initLogging() - - def removeHost(host: String) { - logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") - logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) - val ip = host.split(":")(0) - val port = host.split(":")(1) - blockManagerInfo.remove(new BlockManagerId(ip, port.toInt)) - logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) - self.reply(true) - } - - def receive = { - case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) => - register(blockManagerId, maxMemSize, maxDiskSize) - - case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => - heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) - - case GetLocations(blockId) => - getLocations(blockId) - - case GetLocationsMultipleBlockIds(blockIds) => - getLocationsMultipleBlockIds(blockIds) - - case GetPeers(blockManagerId, size) => - getPeers_Deterministic(blockManagerId, size) - /*getPeers(blockManagerId, size)*/ - - case RemoveHost(host) => - removeHost(host) - - case msg => - logInfo("Got 
unknown msg: " + msg) - } - - private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " - logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - logInfo("Got Register Msg from " + blockManagerId) - if (blockManagerId.ip == Utils.localHostName() && !isLocal) { - logInfo("Got Register Msg from master node, don't register it") - } else { - blockManagerInfo += (blockManagerId -> new BlockManagerInfo( - System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize)) - } - logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) - self.reply(true) - } - - private def heartBeat( - blockManagerId: BlockManagerId, - blockId: String, - storageLevel: StorageLevel, - deserializedSize: Long, - size: Long) { - - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockManagerId + " " + blockId + " " - - if (blockId == null) { - blockManagerInfo(blockManagerId).updateLastSeenMs() - logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) - self.reply(true) - } - - blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size) - - var locations: HashSet[BlockManagerId] = null - if (blockInfo.containsKey(blockId)) { - locations = blockInfo.get(blockId)._2 - } else { - locations = new HashSet[BlockManagerId] - blockInfo.put(blockId, (storageLevel.replication, locations)) - } - - if (storageLevel.isValid) { - locations += blockManagerId - } else { - locations.remove(blockManagerId) - } - - if (locations.size == 0) { - blockInfo.remove(blockId) - } - self.reply(true) - } - - private def getLocations(blockId: String) { - val startTimeMs = System.currentTimeMillis() - val tmp = " " + blockId + " " - logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " - + Utils.getUsedTimeMs(startTimeMs)) - self.reply(res.toSeq) - } else { - logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - self.reply(res) - } - } - - private def getLocationsMultipleBlockIds(blockIds: Array[String]) { - def getLocations(blockId: String): Seq[BlockManagerId] = { - val tmp = blockId - logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) - if (blockInfo.containsKey(blockId)) { - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - res.appendAll(blockInfo.get(blockId)._2) - logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) - return res.toSeq - } else { - logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - return res.toSeq - } - } - - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) - var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] - for (blockId <- blockIds) { - res.append(getLocations(blockId)) - } - logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) - self.reply(res.toSeq) - } - - private def getPeers(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new 
ArrayBuffer[BlockManagerId] - res.appendAll(peers) - res -= blockManagerId - val rand = new Random(System.currentTimeMillis()) - while (res.length > size) { - res.remove(rand.nextInt(res.length)) - } - self.reply(res.toSeq) - } - - private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) { - var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray - var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - - val peersWithIndices = peers.zipWithIndex - val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) - if (selfIndex == -1) { - throw new Exception("Self index for " + blockManagerId + " not found") - } - - var index = selfIndex - while (res.size < size) { - index += 1 - if (index == selfIndex) { - throw new Exception("More peer expected than available") - } - res += peers(index % peers.size) - } - val resStr = res.map(_.toString).reduceLeft(_ + ", " + _) - self.reply(res.toSeq) - } -} - -object BlockManagerMaster extends Logging { - initLogging() - - val AKKA_ACTOR_NAME: String = "BlockMasterManager" - val REQUEST_RETRY_INTERVAL_MS = 100 - val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") - val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt - val DEFAULT_MANAGER_IP: String = Utils.localHostName() - val DEFAULT_MANAGER_PORT: String = "10902" - - implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis) - var masterActor: ActorRef = null - - def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) { - if (isMaster) { - masterActor = actorOf(new BlockManagerMaster(isLocal)) - remote.register(AKKA_ACTOR_NAME, masterActor) - logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT) - masterActor.start() - } else { - masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT) - } - } - - def stopBlockManagerMaster() { - if (masterActor != null) { - masterActor.stop() - masterActor = null - logInfo("BlockManagerMaster stopped") - } - } - - def notifyADeadHost(host: String) { - (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match { - case Some(true) => - logInfo("Removed " + host + " successfully. @ notifyADeadHost") - case Some(oops) => - logError("Failed @ notifyADeadHost: " + oops) - case None => - logError("None @ notifyADeadHost.") - } - } - - def mustRegisterBlockManager(msg: RegisterBlockManager) { - while (! syncRegisterBlockManager(msg)) { - logWarning("Failed to register " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - } - - def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { - //val masterActor = RemoteActor.select(node, name) - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? msg).as[Any] match { - case Some(true) => - logInfo("BlockManager registered successfully @ syncRegisterBlockManager.") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - case Some(oops) => - logError("Failed @ syncRegisterBlockManager: " + oops) - return false - case None => - logError("None @ syncRegisterBlockManager.") - return false - } - } - - def mustHeartBeat(msg: HeartBeat) { - while (! 
syncHeartBeat(msg)) { - logWarning("Failed to send heartbeat" + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - } - } - - def syncHeartBeat(msg: HeartBeat): Boolean = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? msg).as[Any] match { - case Some(true) => - logInfo("Heartbeat sent successfully.") - logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return true - case Some(oops) => - logError("Failed: " + oops) - return false - case None => - logError("None.") - return false - } - } - - def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = { - var res: Array[BlockManagerId] = syncGetLocations(msg) - while (res == null) { - logInfo("Failed to get locations " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocations(msg) - } - return res - } - - def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis() - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? msg).as[Seq[BlockManagerId]] match { - case Some(arr) => - logDebug("GetLocations successfully.") - logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - for (ele <- arr) { - res += ele - } - logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return res.toArray - case None => - logError("GetLocations call returned None.") - return null - } - } - - def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) - while (res == null) { - logWarning("Failed to GetLocationsMultipleBlockIds " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetLocationsMultipleBlockIds(msg) - } - return res - } - - def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): - Seq[Seq[BlockManagerId]] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? msg).as[Any] match { - case Some(arr: Seq[Seq[BlockManagerId]]) => - logDebug("GetLocationsMultipleBlockIds successfully: " + arr) - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return arr - case Some(oops) => - logError("Failed: " + oops) - return null - case None => - logInfo("None.") - return null - } - } - - def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = { - var res: Array[BlockManagerId] = syncGetPeers(msg) - while ((res == null) || (res.length != msg.size)) { - logInfo("Failed to get peers " + msg) - Thread.sleep(REQUEST_RETRY_INTERVAL_MS) - res = syncGetPeers(msg) - } - - return res - } - - def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = { - val startTimeMs = System.currentTimeMillis - val tmp = " msg " + msg + " " - logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? 
msg).as[Seq[BlockManagerId]] match { - case Some(arr) => - logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - logInfo("GetPeers successfully: " + arr.length) - res.appendAll(arr) - logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return res.toArray - case None => - logError("GetPeers call returned None.") - return null - } - } -} diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala deleted file mode 100644 index 3a8574a815aa86a1c98993a704c12bcd287b4de8..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala +++ /dev/null @@ -1,142 +0,0 @@ -package spark.storage - -import java.nio._ - -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.util.Random - -import spark.Logging -import spark.Utils -import spark.SparkEnv -import spark.network._ - -/** - * This should be changed to use event model late. - */ -class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - initLogging() - - blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) - - def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => { - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get) - /*logDebug("Processed block messages")*/ - return Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => logError("Exception handling buffer message: " + e.getMessage) - return None - } - } - case otherMessage: Any => { - logError("Unknown type message received: " + otherMessage) - return None - } - } - } - - def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType() match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel()) - logInfo("Received [" + pB + "]") - putBlock(pB.id, pB.data, pB.level) - return None - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId()) - logInfo("Received [" + gB + "]") - val buffer = getBlock(gB.id) - if (buffer == null) { - return None - } - return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) - } - case _ => return None - } - } - - private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) - blockManager.putBytes(id, bytes, level) - logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.array().length) - } - - private def getBlock(id: String): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("Getblock " + id + " started from " + startTimeMs) - val block = blockManager.getLocal(id) - val buffer = block match { - case Some(tValues) => { - val values = tValues.asInstanceOf[Iterator[Any]] - val 
buffer = blockManager.dataSerialize(values) - buffer - } - case None => { - null - } - } - logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - return buffer - } -} - -object BlockManagerWorker extends Logging { - private var blockManagerWorker: BlockManagerWorker = null - private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 - private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 - - initLogging() - - def startBlockManagerWorker(manager: BlockManager) { - blockManagerWorker = new BlockManagerWorker(manager) - } - - def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer - val blockMessage = BlockMessage.fromPutBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage()) - return (resultMessage != None) - } - - def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val serializer = blockManager.serializer - val blockMessage = BlockMessage.fromGetBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = connectionManager.sendMessageReliablySync( - toConnManagerId, blockMessageArray.toBufferMessage()) - responseMessage match { - case Some(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - logDebug("Response message received " + bufferMessage) - BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { - logDebug("Found " + blockMessage) - return blockMessage.getData - }) - } - case None => logDebug("No response message received"); return null - } - return null - } -} diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala deleted file mode 100644 index bb128dce7a6b8ad45c476c59d87ccf17c77ab667..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/BlockMessage.scala +++ /dev/null @@ -1,219 +0,0 @@ -package spark.storage - -import java.nio._ - -import scala.collection.mutable.StringBuilder -import scala.collection.mutable.ArrayBuffer - -import spark._ -import spark.network._ - -case class GetBlock(id: String) -case class GotBlock(id: String, data: ByteBuffer) -case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) - -class BlockMessage() extends Logging{ - // Un-initialized: typ = 0 - // GetBlock: typ = 1 - // GotBlock: typ = 2 - // PutBlock: typ = 3 - private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED - private var id: String = null - private var data: ByteBuffer = null - private var level: StorageLevel = null - - initLogging() - - def set(getBlock: GetBlock) { - typ = BlockMessage.TYPE_GET_BLOCK - id = getBlock.id - } - - def set(gotBlock: GotBlock) { - typ = BlockMessage.TYPE_GOT_BLOCK - id = gotBlock.id - data = gotBlock.data - } - - def set(putBlock: PutBlock) { - typ = BlockMessage.TYPE_PUT_BLOCK - id = putBlock.id - data = putBlock.data - level = putBlock.level - } - - def set(buffer: ByteBuffer) { - val startTime = System.currentTimeMillis - /* - println() - println("BlockMessage: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - typ = 
buffer.getInt() - val idLength = buffer.getInt() - val idBuilder = new StringBuilder(idLength) - for (i <- 1 to idLength) { - idBuilder += buffer.getChar() - } - id = idBuilder.toString() - - logDebug("Set from buffer Result: " + typ + " " + id) - logDebug("Buffer position is " + buffer.position) - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - - val booleanInt = buffer.getInt() - val replication = buffer.getInt() - level = new StorageLevel(booleanInt, replication) - - val dataLength = buffer.getInt() - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - logDebug("Set from buffer Result 2: " + level + " " + data) - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - - val dataLength = buffer.getInt() - logDebug("Data length is "+ dataLength) - logDebug("Buffer position is " + buffer.position) - data = ByteBuffer.allocate(dataLength) - if (dataLength != buffer.remaining) { - throw new Exception("Error parsing buffer") - } - data.put(buffer) - data.flip() - logDebug("Set from buffer Result 3: " + data) - } - - val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) / 1000.0 + " s") - } - - def set(bufferMsg: BufferMessage) { - val buffer = bufferMsg.buffers.apply(0) - buffer.clear() - set(buffer) - } - - def getType(): Int = { - return typ - } - - def getId(): String = { - return id - } - - def getData(): ByteBuffer = { - return data - } - - def getLevel(): StorageLevel = { - return level - } - - def toBufferMessage(): BufferMessage = { - val startTime = System.currentTimeMillis - val buffers = new ArrayBuffer[ByteBuffer]() - var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) - buffer.putInt(typ).putInt(id.length()) - id.foreach((x: Char) => buffer.putChar(x)) - buffer.flip() - buffers += buffer - - if (typ == BlockMessage.TYPE_PUT_BLOCK) { - buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication) - buffer.flip() - buffers += buffer - - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { - buffer = ByteBuffer.allocate(4).putInt(data.remaining) - buffer.flip() - buffers += buffer - - buffers += data - } - - logDebug("Start to log buffers.") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessage: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - val finishTime = System.currentTimeMillis - logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s") - return Message.createBufferMessage(buffers) - } - - override def toString(): String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + - ", data = " + (if (data != null) data.remaining.toString else "null") + "]" - } -} - -object BlockMessage { - val TYPE_NON_INITIALIZED: Int = 0 - val TYPE_GET_BLOCK: Int = 1 - val TYPE_GOT_BLOCK: Int = 2 - val TYPE_PUT_BLOCK: Int = 3 - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(bufferMessage) - newBlockMessage - } - - def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(buffer) - newBlockMessage - } - - def fromGetBlock(getBlock: GetBlock): BlockMessage = { - val 
newBlockMessage = new BlockMessage() - newBlockMessage.set(getBlock) - newBlockMessage - } - - def fromGotBlock(gotBlock: GotBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(gotBlock) - newBlockMessage - } - - def fromPutBlock(putBlock: PutBlock): BlockMessage = { - val newBlockMessage = new BlockMessage() - newBlockMessage.set(putBlock) - newBlockMessage - } - - def main(args: Array[String]) { - val B = new BlockMessage() - B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2)) - val bMsg = B.toBufferMessage() - val C = new BlockMessage() - C.set(bMsg) - - println(B.getId() + " " + B.getLevel()) - println(C.getId() + " " + C.getLevel()) - } -} diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala deleted file mode 100644 index 5f411d34884e12871405b12b24dcb0765af01427..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/BlockMessageArray.scala +++ /dev/null @@ -1,140 +0,0 @@ -package spark.storage -import java.nio._ - -import scala.collection.mutable.StringBuilder -import scala.collection.mutable.ArrayBuffer - -import spark._ -import spark.network._ - -class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { - - def this(bm: BlockMessage) = this(Array(bm)) - - def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - - def apply(i: Int) = blockMessages(i) - - def iterator = blockMessages.iterator - - def length = blockMessages.length - - initLogging() - - def set(bufferMessage: BufferMessage) { - val startTime = System.currentTimeMillis - val newBlockMessages = new ArrayBuffer[BlockMessage]() - val buffer = bufferMessage.buffers(0) - buffer.clear() - /* - println() - println("BlockMessageArray: ") - while(buffer.remaining > 0) { - print(buffer.get()) - } - buffer.rewind() - println() - println() - */ - while(buffer.remaining() > 0) { - val size = buffer.getInt() - logDebug("Creating block message of size " + size + " bytes") - val newBuffer = buffer.slice() - newBuffer.clear() - newBuffer.limit(size) - logDebug("Trying to convert buffer " + newBuffer + " to block message") - val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) - logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage - buffer.position(buffer.position() + size) - } - val finishTime = System.currentTimeMillis - logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages - } - - def toBufferMessage(): BufferMessage = { - val buffers = new ArrayBuffer[ByteBuffer]() - - blockMessages.foreach(blockMessage => { - val bufferMessage = blockMessage.toBufferMessage - logDebug("Adding " + blockMessage) - val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) - sizeBuffer.flip - buffers += sizeBuffer - buffers ++= bufferMessage.buffers - logDebug("Added " + bufferMessage) - }) - - logDebug("Buffer list:") - buffers.foreach((x: ByteBuffer) => logDebug("" + x)) - /* - println() - println("BlockMessageArray: ") - buffers.foreach(b => { - while(b.remaining > 0) { - print(b.get()) - } - b.rewind() - }) - println() - println() - */ - return Message.createBufferMessage(buffers) - } -} - -object BlockMessageArray { - - def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { - val newBlockMessageArray = new BlockMessageArray() - newBlockMessageArray.set(bufferMessage) - 
newBlockMessageArray - } - - def main(args: Array[String]) { - val blockMessages = - (0 until 10).map(i => { - if (i % 2 == 0) { - val buffer = ByteBuffer.allocate(100) - buffer.clear - BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY)) - } else { - BlockMessage.fromGetBlock(GetBlock(i.toString)) - } - }) - val blockMessageArray = new BlockMessageArray(blockMessages) - println("Block message array created") - - val bufferMessage = blockMessageArray.toBufferMessage - println("Converted to buffer message") - - val totalSize = bufferMessage.size - val newBuffer = ByteBuffer.allocate(totalSize) - newBuffer.clear() - bufferMessage.buffers.foreach(buffer => { - newBuffer.put(buffer) - buffer.rewind() - }) - newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) - println("Copied to new buffer message, size = " + newBufferMessage.size) - - val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) - println("Converted back to block message array") - newBlockMessageArray.foreach(blockMessage => { - blockMessage.getType() match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel()) - println(pB) - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId()) - println(gB) - } - } - }) - } -} - - diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala deleted file mode 100644 index 8672a5376ebd057eff95e9e5c5748c429da39296..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/BlockStore.scala +++ /dev/null @@ -1,291 +0,0 @@ -package spark.storage - -import spark.{Utils, Logging, Serializer, SizeEstimator} - -import scala.collection.mutable.ArrayBuffer - -import java.io.{File, RandomAccessFile} -import java.nio.ByteBuffer -import java.nio.channels.FileChannel.MapMode -import java.util.{UUID, LinkedHashMap} -import java.util.concurrent.Executors - -import it.unimi.dsi.fastutil.io._ - -/** - * Abstract class to store blocks - */ -abstract class BlockStore(blockManager: BlockManager) extends Logging { - initLogging() - - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) - - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] - - def getBytes(blockId: String): Option[ByteBuffer] - - def getValues(blockId: String): Option[Iterator[Any]] - - def remove(blockId: String) - - def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values) - - def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes) - - def clear() { } -} - -/** - * Class to store blocks in memory - */ -class MemoryStore(blockManager: BlockManager, maxMemory: Long) - extends BlockStore(blockManager) { - - class Entry(var value: Any, val size: Long, val deserialized: Boolean) - - private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true) - private var currentMemory = 0L - - private val blockDropper = Executors.newSingleThreadExecutor() - - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { - if (level.deserialized) { - bytes.rewind() - val values = dataDeserialize(bytes) - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - ensureFreeSpace(sizeEstimate) - val entry = new Entry(elements, sizeEstimate, true) - 
memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += sizeEstimate - logDebug("Block " + blockId + " stored as values to memory") - } else { - val entry = new Entry(bytes, bytes.array().length, false) - ensureFreeSpace(bytes.array.length) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory") - } - } - - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { - if (level.deserialized) { - val elements = new ArrayBuffer[Any] - elements ++= values - val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) - ensureFreeSpace(sizeEstimate) - val entry = new Entry(elements, sizeEstimate, true) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += sizeEstimate - logDebug("Block " + blockId + " stored as values to memory") - return Left(elements.iterator) - } else { - val bytes = dataSerialize(values) - ensureFreeSpace(bytes.array().length) - val entry = new Entry(bytes, bytes.array().length, false) - memoryStore.synchronized { memoryStore.put(blockId, entry) } - currentMemory += bytes.array().length - logDebug("Block " + blockId + " stored as " + bytes.array.length + " bytes to memory") - return Right(bytes) - } - } - - def getBytes(blockId: String): Option[ByteBuffer] = { - throw new UnsupportedOperationException("Not implemented") - } - - def getValues(blockId: String): Option[Iterator[Any]] = { - val entry = memoryStore.synchronized { memoryStore.get(blockId) } - if (entry == null) { - return None - } - if (entry.deserialized) { - return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator) - } else { - return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) - } - } - - def remove(blockId: String) { - memoryStore.synchronized { - val entry = memoryStore.get(blockId) - if (entry != null) { - memoryStore.remove(blockId) - currentMemory -= entry.size - logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory") - } else { - logWarning("Block " + blockId + " could not be removed as it doesnt exist") - } - } - } - - override def clear() { - memoryStore.synchronized { - memoryStore.clear() - } - blockDropper.shutdown() - logInfo("MemoryStore cleared") - } - - private def drop(blockId: String) { - blockDropper.submit(new Runnable() { - def run() { - blockManager.dropFromMemory(blockId) - } - }) - } - - private def ensureFreeSpace(space: Long) { - logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( - space, currentMemory, maxMemory)) - - val droppedBlockIds = new ArrayBuffer[String]() - var droppedMemory = 0L - - memoryStore.synchronized { - val iter = memoryStore.entrySet().iterator() - while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) { - val pair = iter.next() - val blockId = pair.getKey - droppedBlockIds += blockId - droppedMemory += pair.getValue.size - logDebug("Decided to drop " + blockId) - } - } - - for (blockId <- droppedBlockIds) { - drop(blockId) - } - droppedBlockIds.clear() - } -} - - -/** - * Class to store blocks in disk - */ -class DiskStore(blockManager: BlockManager, rootDirs: String) - extends BlockStore(blockManager) { - - val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - val localDirs = createLocalDirs() - var lastLocalDirUsed = 0 - - addShutdownHook() - - def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { - 
logDebug("Attempting to put block " + blockId) - val startTime = System.currentTimeMillis - val file = createFile(blockId) - if (file != null) { - val channel = new RandomAccessFile(file, "rw").getChannel() - val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length) - buffer.put(bytes.array) - channel.close() - val finishTime = System.currentTimeMillis - logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms") - } else { - logError("File not created for block " + blockId) - } - } - - def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { - val bytes = dataSerialize(values) - logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes") - putBytes(blockId, bytes, level) - return Right(bytes) - } - - def getBytes(blockId: String): Option[ByteBuffer] = { - val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = ByteBuffer.allocate(length) - bytes.put(channel.map(MapMode.READ_WRITE, 0, length)) - return Some(bytes) - } - - def getValues(blockId: String): Option[Iterator[Any]] = { - val file = getFile(blockId) - val length = file.length().toInt - val channel = new RandomAccessFile(file, "r").getChannel() - val bytes = channel.map(MapMode.READ_ONLY, 0, length) - val buffer = dataDeserialize(bytes) - channel.close() - return Some(buffer) - } - - def remove(blockId: String) { - throw new UnsupportedOperationException("Not implemented") - } - - private def createFile(blockId: String): File = { - val file = getFile(blockId) - if (file == null) { - lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size - val newFile = new File(localDirs(lastLocalDirUsed), blockId) - newFile.getParentFile.mkdirs() - return newFile - } else { - logError("File for block " + blockId + " already exists on disk, " + file) - return null - } - } - - private def getFile(blockId: String): File = { - logDebug("Getting file for block " + blockId) - // Search for the file in all the local directories, only one of them should have the file - val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists) - if (files.size > 1) { - throw new Exception("Multiple files for same block " + blockId + " exists: " + - files.map(_.toString).reduceLeft(_ + ", " + _)) - return null - } else if (files.size == 0) { - return null - } else { - logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes") - return files(0) - } - } - - private def createLocalDirs(): Seq[File] = { - logDebug("Creating local directories at root dirs '" + rootDirs + "'") - rootDirs.split("[;,:]").map(rootDir => { - var foundLocalDir: Boolean = false - var localDir: File = null - var localDirUuid: UUID = null - var tries = 0 - while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { - tries += 1 - try { - localDirUuid = UUID.randomUUID() - localDir = new File(rootDir, "spark-local-" + localDirUuid) - if (!localDir.exists) { - localDir.mkdirs() - foundLocalDir = true - } - } catch { - case e: Exception => - logWarning("Attempt " + tries + " to create local dir failed", e) - } - } - if (!foundLocalDir) { - logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + - " attempts to create local dir in " + rootDir) - System.exit(1) - } - logDebug("Created local directory at " + localDir) - localDir - }) - } - - private def addShutdownHook() { - Runtime.getRuntime.addShutdownHook(new Thread("delete 
Spark local dirs") { - override def run() { - logDebug("Shutdown hook called") - localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) - } - }) - } -} diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala deleted file mode 100644 index 693a679c4e79fc7adb788fc705e93f79af76b3f2..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/storage/StorageLevel.scala +++ /dev/null @@ -1,80 +0,0 @@ -package spark.storage - -import java.io._ - -class StorageLevel( - var useDisk: Boolean, - var useMemory: Boolean, - var deserialized: Boolean, - var replication: Int = 1) - extends Externalizable { - - // TODO: Also add fields for caching priority, dataset ID, and flushing. - - def this(booleanInt: Int, replication: Int) { - this(((booleanInt & 4) != 0), - ((booleanInt & 2) != 0), - ((booleanInt & 1) != 0), - replication) - } - - def this() = this(false, true, false) // For deserialization - - override def clone(): StorageLevel = new StorageLevel( - this.useDisk, this.useMemory, this.deserialized, this.replication) - - override def equals(other: Any): Boolean = other match { - case s: StorageLevel => - s.useDisk == useDisk && - s.useMemory == useMemory && - s.deserialized == deserialized && - s.replication == replication - case _ => - false - } - - def isValid() = ((useMemory || useDisk) && (replication > 0)) - - def toInt(): Int = { - var ret = 0 - if (useDisk) { - ret += 4 - } - if (useMemory) { - ret += 2 - } - if (deserialized) { - ret += 1 - } - return ret - } - - override def writeExternal(out: ObjectOutput) { - out.writeByte(toInt().toByte) - out.writeByte(replication.toByte) - } - - override def readExternal(in: ObjectInput) { - val flags = in.readByte() - useDisk = (flags & 4) != 0 - useMemory = (flags & 2) != 0 - deserialized = (flags & 1) != 0 - replication = in.readByte() - } - - override def toString(): String = - "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) -} - -object StorageLevel { - val NONE = new StorageLevel(false, false, false) - val DISK_ONLY = new StorageLevel(true, false, false) - val MEMORY_ONLY = new StorageLevel(false, true, false) - val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2) - val MEMORY_ONLY_DESER = new StorageLevel(false, true, true) - val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2) - val DISK_AND_MEMORY = new StorageLevel(true, true, false) - val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2) - val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true) - val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2) -} diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala deleted file mode 100644 index abe2d99dd8a5f6814aa57c4ee2fc15fb08b09ac2..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/util/ByteBufferInputStream.scala +++ /dev/null @@ -1,30 +0,0 @@ -package spark.util - -import java.io.InputStream -import java.nio.ByteBuffer - -class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { - override def read(): Int = { - if (buffer.remaining() == 0) { - -1 - } else { - buffer.get() - } - } - - override def read(dest: Array[Byte]): Int = { - read(dest, 0, dest.length) - } - - override def read(dest: Array[Byte], offset: Int, length: Int): Int = { - val amountToGet = math.min(buffer.remaining(), length) - buffer.get(dest, offset, amountToGet) - return amountToGet - } - - 
override def skip(bytes: Long): Long = { - val amountToSkip = math.min(bytes, buffer.remaining).toInt - buffer.position(buffer.position + amountToSkip) - return amountToSkip - } -} diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala deleted file mode 100644 index efb1ae75290f5482cb44d46b3222d34b283d9270..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/util/StatCounter.scala +++ /dev/null @@ -1,89 +0,0 @@ -package spark.util - -/** - * A class for tracking the statistics of a set of numbers (count, mean and variance) in a - * numerically robust way. Includes support for merging two StatCounters. Based on Welford and - * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance. - */ -class StatCounter(values: TraversableOnce[Double]) { - private var n: Long = 0 // Running count of our values - private var mu: Double = 0 // Running mean of our values - private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) - - merge(values) - - def this() = this(Nil) - - def merge(value: Double): StatCounter = { - val delta = value - mu - n += 1 - mu += delta / n - m2 += delta * (value - mu) - this - } - - def merge(values: TraversableOnce[Double]): StatCounter = { - values.foreach(v => merge(v)) - this - } - - def merge(other: StatCounter): StatCounter = { - if (other == this) { - merge(other.copy()) // Avoid overwriting fields in a weird order - } else { - val delta = other.mu - mu - if (other.n * 10 < n) { - mu = mu + (delta * other.n) / (n + other.n) - } else if (n * 10 < other.n) { - mu = other.mu - (delta * n) / (n + other.n) - } else { - mu = (mu * n + other.mu * other.n) / (n + other.n) - } - m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) - n += other.n - this - } - } - - def copy(): StatCounter = { - val other = new StatCounter - other.n = n - other.mu = mu - other.m2 = m2 - other - } - - def count: Long = n - - def mean: Double = mu - - def sum: Double = n * mu - - def variance: Double = { - if (n == 0) - Double.NaN - else - m2 / n - } - - def sampleVariance: Double = { - if (n <= 1) - Double.NaN - else - m2 / (n - 1) - } - - def stdev: Double = math.sqrt(variance) - - def sampleStdev: Double = math.sqrt(sampleVariance) - - override def toString: String = { - "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) - } -} - -object StatCounter { - def apply(values: TraversableOnce[Double]) = new StatCounter(values) - - def apply(values: Double*) = new StatCounter(values) -} diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala index 3d170a6e22ef0cec8544454e5622d4432cb0c78c..60290d14cab69427a771004f5e1270f00708eaa4 100644 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ b/core/src/test/scala/spark/CacheTrackerSuite.scala @@ -1,103 +1,95 @@ package spark import org.scalatest.FunSuite - -import scala.collection.mutable.HashMap - -import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ +import collection.mutable.HashMap class CacheTrackerSuite extends FunSuite { test("CacheTrackerActor slave initialization & cache status") { - //System.setProperty("spark.master.port", "1345") + System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) + val tracker = new CacheTrackerActor tracker.start() - tracker !! SlaveCacheStarted("host001", initialSize) + tracker !? 
SlaveCacheStarted("host001", initialSize) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L))) + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L))) - tracker !! StopCacheTracker + tracker !? StopCacheTracker } test("RegisterRDD") { - //System.setProperty("spark.master.port", "1345") + System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) + val tracker = new CacheTrackerActor tracker.start() - tracker !! SlaveCacheStarted("host001", initialSize) + tracker !? SlaveCacheStarted("host001", initialSize) - tracker !! RegisterRDD(1, 3) - tracker !! RegisterRDD(2, 1) + tracker !? RegisterRDD(1, 3) + tracker !? RegisterRDD(2, 1) - assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List()))) + assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List()))) - tracker !! StopCacheTracker + tracker !? StopCacheTracker } test("AddedToCache") { - //System.setProperty("spark.master.port", "1345") + System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) + val tracker = new CacheTrackerActor tracker.start() - tracker !! SlaveCacheStarted("host001", initialSize) + tracker !? SlaveCacheStarted("host001", initialSize) - tracker !! RegisterRDD(1, 2) - tracker !! RegisterRDD(2, 1) + tracker !? RegisterRDD(1, 2) + tracker !? RegisterRDD(2, 1) - tracker !! AddedToCache(1, 0, "host001", 2L << 15) - tracker !! AddedToCache(1, 1, "host001", 2L << 11) - tracker !! AddedToCache(2, 0, "host001", 3L << 10) + tracker !? AddedToCache(1, 0, "host001", 2L << 15) + tracker !? AddedToCache(1, 1, "host001", 2L << 11) + tracker !? AddedToCache(2, 0, "host001", 3L << 10) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L))) + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) + assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - tracker !! StopCacheTracker + tracker !? StopCacheTracker } test("DroppedFromCache") { - //System.setProperty("spark.master.port", "1345") + System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) + val tracker = new CacheTrackerActor tracker.start() - tracker !! SlaveCacheStarted("host001", initialSize) + tracker !? SlaveCacheStarted("host001", initialSize) - tracker !! RegisterRDD(1, 2) - tracker !! RegisterRDD(2, 1) + tracker !? RegisterRDD(1, 2) + tracker !? RegisterRDD(2, 1) - tracker !! AddedToCache(1, 0, "host001", 2L << 15) - tracker !! AddedToCache(1, 1, "host001", 2L << 11) - tracker !! AddedToCache(2, 0, "host001", 3L << 10) + tracker !? AddedToCache(1, 0, "host001", 2L << 15) + tracker !? AddedToCache(1, 1, "host001", 2L << 11) + tracker !? AddedToCache(2, 0, "host001", 3L << 10) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L))) + assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - tracker !! DroppedFromCache(1, 1, "host001", 2L << 11) + tracker !? 
DroppedFromCache(1, 1, "host001", 2L << 11) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L))) - assert(getCacheLocations(tracker) === - Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) + assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L))) + assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - tracker !! StopCacheTracker + tracker !? StopCacheTracker } /** * Helper function to get cacheLocations from CacheTracker */ - def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match { + def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match { case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map { case (i, arr) => (i -> arr.toList) } diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala index 54421225d881e9b9e1f84b0cd1373498e64fa749..0e6820cbdcf31b0135d57283ef6b2b78681a5569 100644 --- a/core/src/test/scala/spark/MesosSchedulerSuite.scala +++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala @@ -2,8 +2,6 @@ package spark import org.scalatest.FunSuite -import spark.scheduler.mesos.MesosScheduler - class MesosSchedulerSuite extends FunSuite { test("memoryStringToMb"){ diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index 00b24464a62bcfd0913391214601665e76b3bfd5..c61cb90f826678a3c5ae070ef3a7a48ec514ee39 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -48,7 +48,7 @@ class ShuffleSuite extends FunSuite { assert(valuesFor2.toList.sorted === List(1)) sc.stop() } - + test("groupByKey with many output partitions") { val sc = new SparkContext("local", "test") val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1))) @@ -189,7 +189,7 @@ class ShuffleSuite extends FunSuite { )) sc.stop() } - + test("zero-partition RDD") { val sc = new SparkContext("local", "test") val emptyDir = Files.createTempDir() @@ -199,5 +199,5 @@ class ShuffleSuite extends FunSuite { // Test that a shuffle on the file works, because this used to be a bug assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil) sc.stop() - } + } } diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala index 1ac4737f046d35294a89e7165692fe10f809c966..f31251e509a9c14460a573f7584f42d206362e4e 100644 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ b/core/src/test/scala/spark/UtilsSuite.scala @@ -2,7 +2,7 @@ package spark import org.scalatest.FunSuite import java.io.{ByteArrayOutputStream, ByteArrayInputStream} -import scala.util.Random +import util.Random class UtilsSuite extends FunSuite { diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala deleted file mode 100644 index 63501f0613ea845a1846f15e8ad057ed5133c74e..0000000000000000000000000000000000000000 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ /dev/null @@ -1,212 +0,0 @@ -package spark.storage - -import spark.KryoSerializer - -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter - -class BlockManagerSuite extends FunSuite with BeforeAndAfter{ - before { - BlockManagerMaster.startBlockManagerMaster(true, true) - } - - test("manager-master interaction") { - val store = new BlockManager(2000, new KryoSerializer) - val a1 = new 
Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - - // Putting a1, a2 and a3 in memory and telling master only about a1 and a2 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_DESER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_DESER, false) - - // Checking whether blocks are in memory - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - - // Checking whether master knows about the blocks or not - assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size > 0, "master was not told about a1") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size > 0, "master was not told about a2") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a3")).size === 0, "master was told about a3") - - // Setting storage level of a1 and a2 to invalid; they should be removed from store and master - store.setLevel("a1", new StorageLevel(false, false, false, 1)) - store.setLevel("a2", new StorageLevel(true, false, false, 0)) - assert(store.getSingle("a1") === None, "a1 not removed from store") - assert(store.getSingle("a2") === None, "a2 not removed from store") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a1")).size === 0, "master did not remove a1") - assert(BlockManagerMaster.mustGetLocations(GetLocations("a2")).size === 0, "master did not remove a2") - } - - test("in-memory LRU storage") { - val store = new BlockManager(1000, new KryoSerializer) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_DESER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_DESER) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - Thread.sleep(100) - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER) - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - Thread.sleep(100) - assert(store.getSingle("a3") === None, "a3 was in store") - } - - test("in-memory LRU storage with serialization") { - val store = new BlockManager(1000, new KryoSerializer) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) - Thread.sleep(100) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_DESER) - Thread.sleep(100) - assert(store.getSingle("a1") != None, "a1 was not in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") === None, "a1 was in store") - } - - test("on-disk storage") { - val store = new 
BlockManager(1000, new KryoSerializer) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_ONLY) - store.putSingle("a2", a2, StorageLevel.DISK_ONLY) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") != None, "a1 was not in store") - } - - test("disk and memory storage") { - val store = new BlockManager(1000, new KryoSerializer) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_AND_MEMORY_DESER) - store.putSingle("a2", a2, StorageLevel.DISK_AND_MEMORY_DESER) - store.putSingle("a3", a3, StorageLevel.DISK_AND_MEMORY_DESER) - Thread.sleep(100) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") != None, "a1 was not in store") - } - - test("disk and memory storage with serialization") { - val store = new BlockManager(1000, new KryoSerializer) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - store.putSingle("a1", a1, StorageLevel.DISK_AND_MEMORY) - store.putSingle("a2", a2, StorageLevel.DISK_AND_MEMORY) - store.putSingle("a3", a3, StorageLevel.DISK_AND_MEMORY) - Thread.sleep(100) - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a1") != None, "a1 was not in store") - } - - test("LRU with mixed storage levels") { - val store = new BlockManager(1000, new KryoSerializer) - val a1 = new Array[Byte](400) - val a2 = new Array[Byte](400) - val a3 = new Array[Byte](400) - val a4 = new Array[Byte](400) - // First store a1 and a2, both in memory, and a3, on disk only - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.DISK_ONLY) - // At this point LRU should not kick in because a3 is only on disk - assert(store.getSingle("a1") != None, "a2 was not in store") - assert(store.getSingle("a2") != None, "a3 was not in store") - assert(store.getSingle("a3") != None, "a1 was not in store") - assert(store.getSingle("a1") != None, "a2 was not in store") - assert(store.getSingle("a2") != None, "a3 was not in store") - assert(store.getSingle("a3") != None, "a1 was not in store") - // Now let's add in a4, which uses both disk and memory; a1 should drop out - store.putSingle("a4", a4, StorageLevel.DISK_AND_MEMORY) - Thread.sleep(100) - assert(store.getSingle("a1") == None, "a1 was in store") - assert(store.getSingle("a2") != None, "a2 was not in store") - assert(store.getSingle("a3") != None, "a3 was not in store") - assert(store.getSingle("a4") != None, "a4 was not in store") - } - - test("in-memory LRU with streams") { - val store = new BlockManager(1000, new KryoSerializer) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_DESER) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY_DESER) - store.put("list3", list3.iterator, StorageLevel.MEMORY_ONLY_DESER) - Thread.sleep(100) - assert(store.get("list2") != None, "list2 was not in 
store") - assert(store.get("list2").get.size == 2) - assert(store.get("list3") != None, "list3 was not in store") - assert(store.get("list3").get.size == 2) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - // At this point list2 was gotten last, so LRU will getSingle rid of list3 - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY_DESER) - Thread.sleep(100) - assert(store.get("list1") != None, "list1 was not in store") - assert(store.get("list1").get.size == 2) - assert(store.get("list2") != None, "list2 was not in store") - assert(store.get("list2").get.size == 2) - assert(store.get("list3") === None, "list1 was in store") - } - - test("LRU with mixed storage levels and streams") { - val store = new BlockManager(1000, new KryoSerializer) - val list1 = List(new Array[Byte](200), new Array[Byte](200)) - val list2 = List(new Array[Byte](200), new Array[Byte](200)) - val list3 = List(new Array[Byte](200), new Array[Byte](200)) - val list4 = List(new Array[Byte](200), new Array[Byte](200)) - // First store list1 and list2, both in memory, and list3, on disk only - store.put("list1", list1.iterator, StorageLevel.MEMORY_ONLY) - store.put("list2", list2.iterator, StorageLevel.MEMORY_ONLY) - store.put("list3", list3.iterator, StorageLevel.DISK_ONLY) - Thread.sleep(100) - // At this point LRU should not kick in because list3 is only on disk - assert(store.get("list1") != None, "list2 was not in store") - assert(store.get("list1").get.size === 2) - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - assert(store.get("list1") != None, "list2 was not in store") - assert(store.get("list1").get.size === 2) - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - // Now let's add in list4, which uses both disk and memory; list1 should drop out - store.put("list4", list4.iterator, StorageLevel.DISK_AND_MEMORY) - assert(store.get("list1") === None, "list1 was in store") - assert(store.get("list2") != None, "list3 was not in store") - assert(store.get("list2").get.size === 2) - assert(store.get("list3") != None, "list1 was not in store") - assert(store.get("list3").get.size === 2) - assert(store.get("list4") != None, "list4 was not in store") - assert(store.get("list4").get.size === 2) - } -} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 3ce6a086c1686ef407623a41d31554d0c2f1e192..caaf5ebc681aa84494823453ce06abed6b877b53 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -33,7 +33,6 @@ object SparkBuild extends Build { "org.scalatest" %% "scalatest" % "1.6.1" % "test", "org.scala-tools.testing" %% "scalacheck" % "1.9" % "test" ), - parallelExecution in Test := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) } @@ -58,12 +57,8 @@ object SparkBuild extends Build { "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.9", - "se.scalablesolutions.akka" % 
"akka-actor" % "1.3.1", - "se.scalablesolutions.akka" % "akka-remote" % "1.3.1", - "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1", "org.jboss.netty" % "netty" % "3.2.6.Final", - "it.unimi.dsi" % "fastutil" % "6.4.4", - "colt" % "colt" % "1.2.0" + "it.unimi.dsi" % "fastutil" % "6.4.2" ) ) ++ assemblySettings ++ Seq(test in assembly := {}) @@ -73,7 +68,8 @@ object SparkBuild extends Build { ) ++ assemblySettings ++ Seq(test in assembly := {}) def examplesSettings = sharedSettings ++ Seq( - name := "spark-examples" + name := "spark-examples", + libraryDependencies += "colt" % "colt" % "1.2.0" ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") diff --git a/sbt/sbt b/sbt/sbt index fab996728686a59cee0c0d23619b986eb15066b9..714e3d15d7b50d06a84a84dd69352f5bad72bf53 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -4,4 +4,4 @@ if [ "$MESOS_HOME" != "" ]; then EXTRA_ARGS="-Djava.library.path=$MESOS_HOME/lib/java" fi export SPARK_HOME=$(cd "$(dirname $0)/.."; pwd) -java -Xmx1200M -XX:MaxPermSize=200m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" +java -Xmx800M -XX:MaxPermSize=150m $EXTRA_ARGS -jar $SPARK_HOME/sbt/sbt-launch-*.jar "$@" diff --git a/sbt/sbt-launch-0.11.1.jar b/sbt/sbt-launch-0.11.1.jar new file mode 100644 index 0000000000000000000000000000000000000000..59d325ecfe8bf3422394496d36f740263bbacb7e Binary files /dev/null and b/sbt/sbt-launch-0.11.1.jar differ diff --git a/sbt/sbt-launch-0.11.3-2.jar b/sbt/sbt-launch-0.11.3-2.jar deleted file mode 100644 index 23e5c3f31149bbf2bddbf1ae8d1fd02aba7910ad..0000000000000000000000000000000000000000 Binary files a/sbt/sbt-launch-0.11.3-2.jar and /dev/null differ