diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala index 7084ff97d90d28e20785483ca8d450db73e7eced..4c18cb913442b71463bfbea3cb32c00da31d2bf9 100644 --- a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -11,6 +11,7 @@ import scala.xml.{XML,NodeSeq} import scala.collection.mutable.ArrayBuffer import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} +import java.nio.ByteBuffer object WikipediaPageRankStandalone { def main(args: Array[String]) { @@ -118,23 +119,23 @@ class WPRSerializer extends spark.Serializer { } class WPRSerializerInstance extends SerializerInstance { - def serialize[T](t: T): Array[Byte] = { + def serialize[T](t: T): ByteBuffer = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: Array[Byte]): T = { + def deserialize[T](bytes: ByteBuffer): T = { throw new UnsupportedOperationException() } - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { throw new UnsupportedOperationException() } - def outputStream(s: OutputStream): SerializationStream = { + def serializeStream(s: OutputStream): SerializationStream = { new WPRSerializationStream(s) } - def inputStream(s: InputStream): DeserializationStream = { + def deserializeStream(s: InputStream): DeserializationStream = { new WPRDeserializationStream(s) } } diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala new file mode 100644 index 0000000000000000000000000000000000000000..e00a0d80fa25a15e4bf884912613566acba5ab63 --- /dev/null +++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala @@ -0,0 +1,70 @@ +package spark + +import java.io.EOFException +import java.net.URL + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap + +import spark.storage.BlockException +import spark.storage.BlockManagerId + +import it.unimi.dsi.fastutil.io.FastBufferedInputStream + + +class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging { + def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { + logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) + val ser = SparkEnv.get.serializer.newInstance() + val blockManager = SparkEnv.get.blockManager + + val startTime = System.currentTimeMillis + val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId) + logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( + shuffleId, reduceId, System.currentTimeMillis - startTime)) + + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]] + for ((address, index) <- addresses.zipWithIndex) { + splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index + } + + val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map { + case (address, splits) => + (address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId))) + } + + try { + val blockOptions = blockManager.get(blocksByAddress) + logDebug("Fetching map output blocks for shuffle %d, reduce %d took %d ms".format( + shuffleId, reduceId, System.currentTimeMillis - startTime)) + blockOptions.foreach(x => { + val (blockId, blockOption) = x + blockOption match { + case Some(block) => { + val values = 
block.asInstanceOf[Iterator[Any]] + for(value <- values) { + val v = value.asInstanceOf[(K, V)] + func(v._1, v._2) + } + } + case None => { + throw new BlockException(blockId, "Did not get block " + blockId) + } + } + }) + } catch { + case be: BlockException => { + val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r + be.blockId match { + case regex(sId, mId, rId) => { + val address = addresses(mId.toInt) + throw new FetchFailedException(address, sId.toInt, mId.toInt, rId.toInt, be) + } + case _ => { + throw be + } + } + } + } + } +} diff --git a/core/src/main/scala/spark/BoundedMemoryCache.scala b/core/src/main/scala/spark/BoundedMemoryCache.scala index 1162e34ab03340c763e943b696a611ba9cb5d8d8..fa5dcee7bbf0c4cd66a1d2f0bd363799e4c9eaff 100644 --- a/core/src/main/scala/spark/BoundedMemoryCache.scala +++ b/core/src/main/scala/spark/BoundedMemoryCache.scala @@ -90,7 +90,8 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging { protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { logInfo("Dropping key (%s, %d) of size %d to make space".format(datasetId, partition, entry.size)) - SparkEnv.get.cacheTracker.dropEntry(datasetId, partition) + // TODO: remove BoundedMemoryCache + SparkEnv.get.cacheTracker.dropEntry(datasetId.asInstanceOf[(Int, Int)]._2, partition) } } diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 4867829c17ac6519dba55aed2a47af78f54fe85f..64b4af0ae20e327b90abd36df6ea9a33969a64ed 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -1,11 +1,17 @@ package spark -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ +import akka.util.duration._ + +import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet +import spark.storage.BlockManager +import spark.storage.StorageLevel + sealed trait CacheTrackerMessage case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L) extends CacheTrackerMessage @@ -18,8 +24,8 @@ case object GetCacheStatus extends CacheTrackerMessage case object GetCacheLocations extends CacheTrackerMessage case object StopCacheTracker extends CacheTrackerMessage - -class CacheTrackerActor extends DaemonActor with Logging { +class CacheTrackerActor extends Actor with Logging { + // TODO: Should probably store (String, CacheType) tuples private val locs = new HashMap[Int, Array[List[String]]] /** @@ -28,109 +34,93 @@ class CacheTrackerActor extends DaemonActor with Logging { private val slaveCapacity = new HashMap[String, Long] private val slaveUsage = new HashMap[String, Long] - // TODO: Should probably store (String, CacheType) tuples - - private def getCacheUsage(host: String): Long = slaveUsage.getOrElse(host, 0L) private def getCacheCapacity(host: String): Long = slaveCapacity.getOrElse(host, 0L) private def getCacheAvailable(host: String): Long = getCacheCapacity(host) - getCacheUsage(host) - def act() { - val port = System.getProperty("spark.master.port").toInt - RemoteActor.alive(port) - RemoteActor.register('CacheTracker, self) - logInfo("Registered actor on port " + port) - - loop { - react { - case SlaveCacheStarted(host: String, size: Long) => - logInfo("Started slave cache (size %s) on %s".format( - Utils.memoryBytesToString(size), host)) - slaveCapacity.put(host, size) - slaveUsage.put(host, 0) - reply('OK)
- - case RegisterRDD(rddId: Int, numPartitions: Int) => - logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") - locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - reply('OK) - - case AddedToCache(rddId, partition, host, size) => - if (size > 0) { - slaveUsage.put(host, getCacheUsage(host) + size) - logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format( - rddId, partition, host, Utils.memoryBytesToString(size), - Utils.memoryBytesToString(getCacheAvailable(host)))) - } else { - logInfo("Cache entry added: (%s, %s) on %s".format(rddId, partition, host)) - } - locs(rddId)(partition) = host :: locs(rddId)(partition) - reply('OK) - - case DroppedFromCache(rddId, partition, host, size) => - if (size > 0) { - logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( - rddId, partition, host, Utils.memoryBytesToString(size), - Utils.memoryBytesToString(getCacheAvailable(host)))) - slaveUsage.put(host, getCacheUsage(host) - size) - - // Do a sanity check to make sure usage is greater than 0. - val usage = getCacheUsage(host) - if (usage < 0) { - logError("Cache usage on %s is negative (%d)".format(host, usage)) - } - } else { - logInfo("Cache entry removed: (%s, %s) on %s".format(rddId, partition, host)) - } - locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - reply('OK) + def receive = { + case SlaveCacheStarted(host: String, size: Long) => + logInfo("Started slave cache (size %s) on %s".format( + Utils.memoryBytesToString(size), host)) + slaveCapacity.put(host, size) + slaveUsage.put(host, 0) + self.reply(true) - case MemoryCacheLost(host) => - logInfo("Memory cache lost on " + host) - // TODO: Drop host from the memory locations list of all RDDs - - case GetCacheLocations => - logInfo("Asked for current cache locations") - reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())}) - - case GetCacheStatus => - val status = slaveCapacity.map { case (host,capacity) => - (host, capacity, getCacheUsage(host)) - }.toSeq - reply(status) - - case StopCacheTracker => - reply('OK) - exit() + case RegisterRDD(rddId: Int, numPartitions: Int) => + logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") + locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) + self.reply(true) + + case AddedToCache(rddId, partition, host, size) => + slaveUsage.put(host, getCacheUsage(host) + size) + logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format( + rddId, partition, host, Utils.memoryBytesToString(size), + Utils.memoryBytesToString(getCacheAvailable(host)))) + locs(rddId)(partition) = host :: locs(rddId)(partition) + self.reply(true) + + case DroppedFromCache(rddId, partition, host, size) => + logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( + rddId, partition, host, Utils.memoryBytesToString(size), + Utils.memoryBytesToString(getCacheAvailable(host)))) + slaveUsage.put(host, getCacheUsage(host) - size) + // Do a sanity check to make sure usage is greater than 0. 
+ val usage = getCacheUsage(host) + if (usage < 0) { + logError("Cache usage on %s is negative (%d)".format(host, usage)) } - } - } -} + locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) + self.reply(true) + case MemoryCacheLost(host) => + logInfo("Memory cache lost on " + host) + for ((id, locations) <- locs) { + for (i <- 0 until locations.length) { + locations(i) = locations(i).filterNot(_ == host) + } + } + self.reply(true) -class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { - // Tracker actor on the master, or remote reference to it on workers - var trackerActor: AbstractActor = null + case GetCacheLocations => + logInfo("Asked for current cache locations") + self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())}) - val registeredRddIds = new HashSet[Int] + case GetCacheStatus => + val status = slaveCapacity.map { case (host, capacity) => + (host, capacity, getCacheUsage(host)) + }.toSeq + self.reply(status) - // Stores map results for various splits locally - val cache = theCache.newKeySpace() + case StopCacheTracker => + logInfo("CacheTrackerActor Server stopped!") + self.reply(true) + self.exit() + } +} +class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging { + // Tracker actor on the master, or remote reference to it on workers + val ip: String = System.getProperty("spark.master.host", "localhost") + val port: Int = System.getProperty("spark.master.port", "7077").toInt + val aName: String = "CacheTracker" + if (isMaster) { - val tracker = new CacheTrackerActor - tracker.start() - trackerActor = tracker + } + + var trackerActor: ActorRef = if (isMaster) { + val actor = actorOf(new CacheTrackerActor) + remote.register(aName, actor) + actor.start() + logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port) + actor } else { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt - trackerActor = RemoteActor.select(Node(host, port), 'CacheTracker) + remote.actorFor(aName, ip, port) } - // Report the cache being started. - trackerActor !? SlaveCacheStarted(Utils.getHost, cache.getCapacity) + val registeredRddIds = new HashSet[Int] // Remembers which splits are currently being loaded (on worker nodes) - val loading = new HashSet[(Int, Int)] + val loading = new HashSet[String] // Registers an RDD (on master only) def registerRDD(rddId: Int, numPartitions: Int) { @@ -138,24 +128,33 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { if (!registeredRddIds.contains(rddId)) { logInfo("Registering RDD ID " + rddId + " with cache") registeredRddIds += rddId - trackerActor !? RegisterRDD(rddId, numPartitions) + (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match { + case Some(true) => + logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.") + case Some(oops) => + logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops) + case None => + logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions)) + throw new SparkException("Internal error: CacheTracker registerRDD None.") + } } } } - - // Get a snapshot of the currently known locations - def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - (trackerActor !? 
GetCacheLocations) match { - case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]] - - case _ => throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") + + // For BlockManager.scala only + def cacheLost(host: String) { + (trackerActor ? MemoryCacheLost(host)).as[Any] match { + case Some(true) => + logInfo("CacheTracker successfully removed entries on " + host) + case _ => + logError("CacheTracker did not reply to MemoryCacheLost") } } // Get the usage status of slave caches. Each tuple in the returned sequence // is in the form of (host name, capacity, usage). def getCacheStatus(): Seq[(String, Long, Long)] = { - (trackerActor !? GetCacheStatus) match { + (trackerActor ? GetCacheStatus) match { case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]] case _ => @@ -164,75 +163,94 @@ class CacheTracker(isMaster: Boolean, theCache: Cache) extends Logging { } } + // For BlockManager.scala only + def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) { + (trackerActor ? t).as[Any] match { + case Some(true) => + logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.") + case Some(oops) => + logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops) + case None => + logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.") + } + } + + // Get a snapshot of the currently known locations + def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { + (trackerActor ? GetCacheLocations).as[Any] match { + case Some(h: HashMap[_, _]) => + h.asInstanceOf[HashMap[Int, Array[List[String]]]] + + case _ => + throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") + } + } + // Gets or computes an RDD split - def getOrCompute[T](rdd: RDD[T], split: Split)(implicit m: ClassManifest[T]): Iterator[T] = { - logInfo("Looking for RDD partition %d:%d".format(rdd.id, split.index)) - val cachedVal = cache.get(rdd.id, split.index) - if (cachedVal != null) { - // Split is in cache, so just return its values - logInfo("Found partition in cache!") - return cachedVal.asInstanceOf[Array[T]].iterator - } else { - // Mark the split as loading (unless someone else marks it first) - val key = (rdd.id, split.index) - loading.synchronized { - while (loading.contains(key)) { - // Someone else is loading it; let's wait for them - try { loading.wait() } catch { case _ => } - } - // See whether someone else has successfully loaded it. The main way this would fail - // is for the RDD-level cache eviction policy if someone else has loaded the same RDD - // partition but we didn't want to make space for it. However, that case is unlikely - // because it's unlikely that two threads would work on the same RDD partition. One - // downside of the current code is that threads wait serially if this does happen. 
- val cachedVal = cache.get(rdd.id, split.index) - if (cachedVal != null) { - return cachedVal.asInstanceOf[Array[T]].iterator - } - // Nobody's loading it and it's not in the cache; let's load it ourselves - loading.add(key) - } - // If we got here, we have to load the split - // Tell the master that we're doing so - - // TODO: fetch any remote copy of the split that may be available - logInfo("Computing partition " + split) - var array: Array[T] = null - var putResponse: CachePutResponse = null - try { - array = rdd.compute(split).toArray(m) - putResponse = cache.put(rdd.id, split.index, array) - } finally { - // Tell other threads that we've finished our attempt to load the key (whether or not - // we've actually succeeded to put it in the map) + def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = { + val key = "rdd:%d:%d".format(rdd.id, split.index) + logInfo("Cache key is " + key) + blockManager.get(key) match { + case Some(cachedValues) => + // Split is in cache, so just return its values + logInfo("Found partition in cache!") + return cachedValues.asInstanceOf[Iterator[T]] + + case None => + // Mark the split as loading (unless someone else marks it first) loading.synchronized { - loading.remove(key) - loading.notifyAll() + if (loading.contains(key)) { + logInfo("Loading contains " + key + ", waiting...") + while (loading.contains(key)) { + try {loading.wait()} catch {case _ =>} + } + logInfo("Loading no longer contains " + key + ", so returning cached result") + // See whether someone else has successfully loaded it. The main way this would fail + // is for the RDD-level cache eviction policy if someone else has loaded the same RDD + // partition but we didn't want to make space for it. However, that case is unlikely + // because it's unlikely that two threads would work on the same RDD partition. One + // downside of the current code is that threads wait serially if this does happen. + blockManager.get(key) match { + case Some(values) => + return values.asInstanceOf[Iterator[T]] + case None => + logInfo("Whoever was loading " + key + " failed; we'll try it ourselves") + loading.add(key) + } + } else { + loading.add(key) + } } - } - - putResponse match { - case CachePutSuccess(size) => { - // Tell the master that we added the entry. Don't return until it - // replies so it can properly schedule future tasks that use this RDD. - trackerActor !? AddedToCache(rdd.id, split.index, Utils.getHost, size) + // If we got here, we have to load the split + // Tell the master that we're doing so + //val host = System.getProperty("spark.hostname", Utils.localHostName) + //val future = trackerActor !! AddedToCache(rdd.id, split.index, host) + // TODO: fetch any remote copy of the split that may be available + // TODO: also register a listener for when it unloads + logInfo("Computing partition " + split) + try { + val values = new ArrayBuffer[Any] + values ++= rdd.compute(split) + blockManager.put(key, values.iterator, storageLevel, false) + //future.apply() // Wait for the reply from the cache tracker + return values.iterator.asInstanceOf[Iterator[T]] + } finally { + loading.synchronized { + loading.remove(key) + loading.notifyAll() + } } - case _ => null - } - return array.iterator } } // Called by the Cache to report that an entry has been dropped from it - def dropEntry(datasetId: Any, partition: Int) { - datasetId match { - //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. 
- case (cache.keySpaceId, rddId: Int) => trackerActor !! DroppedFromCache(rddId, partition, Utils.getHost) - } + def dropEntry(rddId: Int, partition: Int) { + //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. + trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName()) } def stop() { - trackerActor !? StopCacheTracker + trackerActor !! StopCacheTracker registeredRddIds.clear() trackerActor = null } diff --git a/core/src/main/scala/spark/CoGroupedRDD.scala b/core/src/main/scala/spark/CoGroupedRDD.scala index 93f453bc5e4341bcf74de43ee22d332cdeaf4e1a..3543c8afa8a081f201f630df60dcb6f915c01115 100644 --- a/core/src/main/scala/spark/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/CoGroupedRDD.scala @@ -22,11 +22,12 @@ class CoGroupAggregator { (b1, b2) => b1 ++ b2 }) with Serializable -class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner) +class CoGroupedRDD[K](@transient rdds: Seq[RDD[(_, _)]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context) with Logging { val aggr = new CoGroupAggregator + @transient override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] for ((rdd, index) <- rdds.zipWithIndex) { @@ -67,9 +68,10 @@ class CoGroupedRDD[K](rdds: Seq[RDD[(_, _)]], part: Partitioner) override def compute(s: Split): Iterator[(K, Seq[Seq[_]])] = { val split = s.asInstanceOf[CoGroupSplit] + val numRdds = split.deps.size val map = new HashMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.getOrElseUpdate(k, Array.fill(rdds.size)(new ArrayBuffer[Any])) + map.getOrElseUpdate(k, Array.fill(numRdds)(new ArrayBuffer[Any])) } for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, itsSplit) => { diff --git a/core/src/main/scala/spark/DAGScheduler.scala b/core/src/main/scala/spark/DAGScheduler.scala deleted file mode 100644 index 1b4af9d84c6d2159eb05084e2587ddef62a6bed1..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/DAGScheduler.scala +++ /dev/null @@ -1,374 +0,0 @@ -package spark - -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map} - -/** - * A task created by the DAG scheduler. Knows its stage ID and map ouput tracker generation. - */ -abstract class DAGTask[T](val runId: Int, val stageId: Int) extends Task[T] { - val gen = SparkEnv.get.mapOutputTracker.getGeneration - override def generation: Option[Long] = Some(gen) -} - -/** - * A completion event passed by the underlying task scheduler to the DAG scheduler. - */ -case class CompletionEvent( - task: DAGTask[_], - reason: TaskEndReason, - result: Any, - accumUpdates: Map[Long, Any]) - -/** - * Various possible reasons why a DAG task ended. The underlying scheduler is supposed to retry - * tasks several times for "ephemeral" failures, and only report back failures that require some - * old stages to be resubmitted, such as shuffle map fetch failures. - */ -sealed trait TaskEndReason -case object Success extends TaskEndReason -case class FetchFailed(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason -case class ExceptionFailure(exception: Throwable) extends TaskEndReason -case class OtherFailure(message: String) extends TaskEndReason - -/** - * A Scheduler subclass that implements stage-oriented scheduling. 
It computes a DAG of stages for - * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal - * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster - * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). - */ -private trait DAGScheduler extends Scheduler with Logging { - // Must be implemented by subclasses to start running a set of tasks. The subclass should also - // attempt to run different sets of tasks in the order given by runId (lower values first). - def submitTasks(tasks: Seq[Task[_]], runId: Int): Unit - - // Must be called by subclasses to report task completions or failures. - def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]) { - lock.synchronized { - val dagTask = task.asInstanceOf[DAGTask[_]] - eventQueues.get(dagTask.runId) match { - case Some(queue) => - queue += CompletionEvent(dagTask, reason, result, accumUpdates) - lock.notifyAll() - case None => - logInfo("Ignoring completion event for DAG job " + dagTask.runId + " because it's gone") - } - } - } - - // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; - // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one - // as more failure events come in - val RESUBMIT_TIMEOUT = 2000L - - // The time, in millis, to wake up between polls of the completion queue in order to potentially - // resubmit failed stages - val POLL_TIMEOUT = 500L - - private val lock = new Object // Used for access to the entire DAGScheduler - - private val eventQueues = new HashMap[Int, Queue[CompletionEvent]] // Indexed by run ID - - val nextRunId = new AtomicInteger(0) - - val nextStageId = new AtomicInteger(0) - - val idToStage = new HashMap[Int, Stage] - - val shuffleToMapStage = new HashMap[Int, Stage] - - var cacheLocs = new HashMap[Int, Array[List[String]]] - - val env = SparkEnv.get - val cacheTracker = env.cacheTracker - val mapOutputTracker = env.mapOutputTracker - - def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { - cacheLocs(rdd.id) - } - - def updateCacheLocs() { - cacheLocs = cacheTracker.getLocationsSnapshot() - } - - def getShuffleMapStage(shuf: ShuffleDependency[_,_,_]): Stage = { - shuffleToMapStage.get(shuf.shuffleId) match { - case Some(stage) => stage - case None => - val stage = newStage(shuf.rdd, Some(shuf)) - shuffleToMapStage(shuf.shuffleId) = stage - stage - } - } - - def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]]): Stage = { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of splits is unknown - cacheTracker.registerRDD(rdd.id, rdd.splits.size) - if (shuffleDep != None) { - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) - } - val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd)) - idToStage(id) = stage - stage - } - - def getParentStages(rdd: RDD[_]): List[Stage] = { - val parents = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(r: RDD[_]) { - if (!visited(r)) { - visited += r - // Kind of ugly: need to register RDDs with the cache here since - // we can't do it in its constructor because # of splits is unknown - cacheTracker.registerRDD(r.id, r.splits.size) - for (dep <- r.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_,_] => - parents 
+= getShuffleMapStage(shufDep) - case _ => - visit(dep.rdd) - } - } - } - } - visit(rdd) - parents.toList - } - - def getMissingParentStages(stage: Stage): List[Stage] = { - val missing = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - def visit(rdd: RDD[_]) { - if (!visited(rdd)) { - visited += rdd - val locs = getCacheLocs(rdd) - for (p <- 0 until rdd.splits.size) { - if (locs(p) == Nil) { - for (dep <- rdd.dependencies) { - dep match { - case shufDep: ShuffleDependency[_,_,_] => - val stage = getShuffleMapStage(shufDep) - if (!stage.isAvailable) { - missing += stage - } - case narrowDep: NarrowDependency[_] => - visit(narrowDep.rdd) - } - } - } - } - } - } - visit(stage.rdd) - missing.toList - } - - override def runJob[T, U]( - finalRdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean) - (implicit m: ClassManifest[U]): Array[U] = { - lock.synchronized { - val runId = nextRunId.getAndIncrement() - - val outputParts = partitions.toArray - val numOutputParts: Int = partitions.size - val finalStage = newStage(finalRdd, None) - val results = new Array[U](numOutputParts) - val finished = new Array[Boolean](numOutputParts) - var numFinished = 0 - - val waiting = new HashSet[Stage] // stages we need to run whose parents aren't done - val running = new HashSet[Stage] // stages we are running right now - val failed = new HashSet[Stage] // stages that must be resubmitted due to fetch failures - val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // missing tasks from each stage - var lastFetchFailureTime: Long = 0 // used to wait a bit to avoid repeated resubmits - - SparkEnv.set(env) - - updateCacheLocs() - - logInfo("Final stage: " + finalStage) - logInfo("Parents of final stage: " + finalStage.parents) - logInfo("Missing parents: " + getMissingParentStages(finalStage)) - - // Optimization for short actions like first() and take() that can be computed locally - // without shipping tasks to the cluster. 
- if (allowLocal && finalStage.parents.size == 0 && numOutputParts == 1) { - logInfo("Computing the requested partition locally") - val split = finalRdd.splits(outputParts(0)) - val taskContext = new TaskContext(finalStage.id, outputParts(0), 0) - return Array(func(taskContext, finalRdd.iterator(split))) - } - - // Register the job ID so that we can get completion events for it - eventQueues(runId) = new Queue[CompletionEvent] - - def submitStage(stage: Stage) { - if (!waiting(stage) && !running(stage)) { - val missing = getMissingParentStages(stage) - if (missing == Nil) { - logInfo("Submitting " + stage + ", which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) - } - waiting += stage - } - } - } - - def submitMissingTasks(stage: Stage) { - // Get our pending tasks and remember them in our pendingTasks entry - val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) - var tasks = ArrayBuffer[Task[_]]() - if (stage == finalStage) { - for (id <- 0 until numOutputParts if (!finished(id))) { - val part = outputParts(id) - val locs = getPreferredLocs(finalRdd, part) - tasks += new ResultTask(runId, finalStage.id, finalRdd, func, part, locs, id) - } - } else { - for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { - val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(runId, stage.id, stage.rdd, stage.shuffleDep.get, p, locs) - } - } - myPending ++= tasks - submitTasks(tasks, runId) - } - - submitStage(finalStage) - - while (numFinished != numOutputParts) { - val eventOption = waitForEvent(runId, POLL_TIMEOUT) - val time = System.currentTimeMillis // TODO: use a pluggable clock for testability - - // If we got an event off the queue, mark the task done or react to a fetch failure - if (eventOption != None) { - val evt = eventOption.get - val stage = idToStage(evt.task.stageId) - pendingTasks(stage) -= evt.task - if (evt.reason == Success) { - // A task ended - logInfo("Completed " + evt.task) - Accumulators.add(evt.accumUpdates) - evt.task match { - case rt: ResultTask[_, _] => - results(rt.outputId) = evt.result.asInstanceOf[U] - finished(rt.outputId) = true - numFinished += 1 - case smt: ShuffleMapTask => - val stage = idToStage(smt.stageId) - stage.addOutputLoc(smt.partition, evt.result.asInstanceOf[String]) - if (running.contains(stage) && pendingTasks(stage).isEmpty) { - logInfo(stage + " finished; looking for newly runnable stages") - running -= stage - if (stage.shuffleDep != None) { - mapOutputTracker.registerMapOutputs( - stage.shuffleDep.get.shuffleId, - stage.outputLocs.map(_.head).toArray) - } - updateCacheLocs() - val newlyRunnable = new ArrayBuffer[Stage] - for (stage <- waiting if getMissingParentStages(stage) == Nil) { - newlyRunnable += stage - } - waiting --= newlyRunnable - running ++= newlyRunnable - for (stage <- newlyRunnable) { - submitMissingTasks(stage) - } - } - } - } else { - evt.reason match { - case FetchFailed(serverUri, shuffleId, mapId, reduceId) => - // Mark the stage that the reducer was in as unrunnable - val failedStage = idToStage(evt.task.stageId) - running -= failedStage - failed += failedStage - // TODO: Cancel running tasks in the stage - logInfo("Marking " + failedStage + " for resubmision due to a fetch failure") - // Mark the map whose fetch failed as broken in the map stage - val mapStage = shuffleToMapStage(shuffleId) - mapStage.removeOutputLoc(mapId, serverUri) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, 
serverUri) - logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission") - failed += mapStage - // Remember that a fetch failed now; this is used to resubmit the broken - // stages later, after a small wait (to give other tasks the chance to fail) - lastFetchFailureTime = time - // TODO: If there are a lot of fetch failures on the same node, maybe mark all - // outputs on the node as dead. - case _ => - // Non-fetch failure -- probably a bug in the job, so bail out - throw new SparkException("Task failed: " + evt.task + ", reason: " + evt.reason) - // TODO: Cancel all tasks that are still running - } - } - } // end if (evt != null) - - // If fetches have failed recently and we've waited for the right timeout, - // resubmit all the failed stages - if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { - logInfo("Resubmitting failed stages") - updateCacheLocs() - for (stage <- failed) { - submitStage(stage) - } - failed.clear() - } - } - - eventQueues -= runId - return results - } - } - - def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { - // If the partition is cached, return the cache locations - val cached = getCacheLocs(rdd)(partition) - if (cached != Nil) { - return cached - } - // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList - if (rddPrefs != Nil) { - return rddPrefs - } - // If the RDD has narrow dependencies, pick the first partition of the first narrow dep - // that has any placement preferences. Ideally we would choose based on transfer sizes, - // but this will do for now. - rdd.dependencies.foreach(_ match { - case n: NarrowDependency[_] => - for (inPart <- n.getParents(partition)) { - val locs = getPreferredLocs(n.rdd, inPart) - if (locs != Nil) - return locs; - } - case _ => - }) - return Nil - } - - // Assumes that lock is held on entrance, but will release it to wait for the next event. 
- def waitForEvent(runId: Int, timeout: Long): Option[CompletionEvent] = { - val endTime = System.currentTimeMillis() + timeout // TODO: Use pluggable clock for testing - while (eventQueues(runId).isEmpty) { - val time = System.currentTimeMillis() - if (time >= endTime) { - return None - } else { - lock.wait(endTime - time) - } - } - return Some(eventQueues(runId).dequeue()) - } -} diff --git a/core/src/main/scala/spark/Dependency.scala b/core/src/main/scala/spark/Dependency.scala index d93c84924a5038fb202157b907092591b1343ac8..c0ff94acc6266b3e25e1988d700680100affec24 100644 --- a/core/src/main/scala/spark/Dependency.scala +++ b/core/src/main/scala/spark/Dependency.scala @@ -8,7 +8,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) { class ShuffleDependency[K, V, C]( val shuffleId: Int, - rdd: RDD[(K, V)], + @transient rdd: RDD[(K, V)], val aggregator: Aggregator[K, V, C], val partitioner: Partitioner) extends Dependency(rdd, true) diff --git a/core/src/main/scala/spark/DiskSpillingCache.scala b/core/src/main/scala/spark/DiskSpillingCache.scala deleted file mode 100644 index e11466eb64eec01e923bde295867653d88bb7706..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/DiskSpillingCache.scala +++ /dev/null @@ -1,75 +0,0 @@ -package spark - -import java.io.File -import java.io.{FileOutputStream,FileInputStream} -import java.io.IOException -import java.util.LinkedHashMap -import java.util.UUID - -// TODO: cache into a separate directory using Utils.createTempDir -// TODO: clean up disk cache afterwards -class DiskSpillingCache extends BoundedMemoryCache { - private val diskMap = new LinkedHashMap[(Any, Int), File](32, 0.75f, true) - - override def get(datasetId: Any, partition: Int): Any = { - synchronized { - val ser = SparkEnv.get.serializer.newInstance() - super.get(datasetId, partition) match { - case bytes: Any => // found in memory - ser.deserialize(bytes.asInstanceOf[Array[Byte]]) - - case _ => diskMap.get((datasetId, partition)) match { - case file: Any => // found on disk - try { - val startTime = System.currentTimeMillis - val bytes = new Array[Byte](file.length.toInt) - new FileInputStream(file).read(bytes) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Reading key (%s, %d) of size %d bytes from disk took %d ms".format( - datasetId, partition, file.length, timeTaken)) - super.put(datasetId, partition, bytes) - ser.deserialize(bytes.asInstanceOf[Array[Byte]]) - } catch { - case e: IOException => - logWarning("Failed to read key (%s, %d) from disk at %s: %s".format( - datasetId, partition, file.getPath(), e.getMessage())) - diskMap.remove((datasetId, partition)) // remove dead entry - null - } - - case _ => // not found - null - } - } - } - } - - override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { - var ser = SparkEnv.get.serializer.newInstance() - super.put(datasetId, partition, ser.serialize(value)) - } - - /** - * Spill the given entry to disk. Assumes that a lock is held on the - * DiskSpillingCache. Assumes that entry.value is a byte array. 
- */ - override protected def reportEntryDropped(datasetId: Any, partition: Int, entry: Entry) { - logInfo("Spilling key (%s, %d) of size %d to make space".format( - datasetId, partition, entry.size)) - val cacheDir = System.getProperty( - "spark.diskSpillingCache.cacheDir", - System.getProperty("java.io.tmpdir")) - val file = new File(cacheDir, "spark-dsc-" + UUID.randomUUID.toString) - try { - val stream = new FileOutputStream(file) - stream.write(entry.value.asInstanceOf[Array[Byte]]) - stream.close() - diskMap.put((datasetId, partition), file) - } catch { - case e: IOException => - logWarning("Failed to spill key (%s, %d) to disk at %s: %s".format( - datasetId, partition, file.getPath(), e.getMessage())) - // Do nothing and let the entry be discarded - } - } -} diff --git a/core/src/main/scala/spark/DoubleRDDFunctions.scala b/core/src/main/scala/spark/DoubleRDDFunctions.scala new file mode 100644 index 0000000000000000000000000000000000000000..1fbf66b7ded3c2e16ed708159be075e12ea0e8e3 --- /dev/null +++ b/core/src/main/scala/spark/DoubleRDDFunctions.scala @@ -0,0 +1,39 @@ +package spark + +import spark.partial.BoundedDouble +import spark.partial.MeanEvaluator +import spark.partial.PartialResult +import spark.partial.SumEvaluator + +import spark.util.StatCounter + +/** + * Extra functions available on RDDs of Doubles through an implicit conversion. + */ +class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { + def sum(): Double = { + self.reduce(_ + _) + } + + def stats(): StatCounter = { + self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) + } + + def mean(): Double = stats().mean + + def variance(): Double = stats().variance + + def stdev(): Double = stats().stdev + + def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) + val evaluator = new MeanEvaluator(self.splits.size, confidence) + self.context.runApproximateJob(self, processPartition, evaluator, timeout) + } + + def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) + val evaluator = new SumEvaluator(self.splits.size, confidence) + self.context.runApproximateJob(self, processPartition, evaluator, timeout) + } +} diff --git a/core/src/main/scala/spark/Executor.scala b/core/src/main/scala/spark/Executor.scala index c795b6c3519332a6ea3fe0a9193918a32ec69b99..af9eb9c878ede5fd39441c413bf72c56524b0b5f 100644 --- a/core/src/main/scala/spark/Executor.scala +++ b/core/src/main/scala/spark/Executor.scala @@ -10,9 +10,10 @@ import scala.collection.mutable.ArrayBuffer import com.google.protobuf.ByteString import org.apache.mesos._ -import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} import spark.broadcast._ +import spark.scheduler._ /** * The Mesos executor for Spark. 
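Reviewer note: the stats() method added to DoubleRDDFunctions above computes a per-partition summary and then merges the summaries, so mean and variance come out of a single pass over the data. The sketch below illustrates that merge step with a simplified stand-in class; Summary and its fields are made-up names, not the real spark.util.StatCounter.

// Simplified, mergeable running-statistics sketch (a stand-in, not spark.util.StatCounter).
class Summary(var n: Long, var mean: Double, var m2: Double) {
  // Combine two partial summaries using the parallel mean/variance update.
  def merge(other: Summary): Summary = {
    if (other.n > 0) {
      val delta = other.mean - mean
      val total = n + other.n
      val newMean = mean + delta * other.n / total
      m2 = m2 + other.m2 + delta * delta * n * other.n / total  // uses the old n on purpose
      mean = newMean
      n = total
    }
    this
  }
  def variance: Double = if (n == 0) Double.NaN else m2 / n
}

object Summary {
  def apply(nums: Iterator[Double]): Summary =
    nums.foldLeft(new Summary(0, 0.0, 0.0)) { (s, x) => s.merge(new Summary(1, x, 0.0)) }
}

// e.g. partitions.map(p => Summary(p.iterator)).reduce(_ merge _).variance,
// mirroring mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b)) in the patch.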
@@ -29,6 +30,9 @@ class Executor extends org.apache.mesos.Executor with Logging { executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { + // Make sure the local hostname we report matches Mesos's name for this host + Utils.setCustomHostname(slaveInfo.getHostname()) + // Read spark.* system properties from executor arg val props = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) for ((key, value) <- props) { @@ -39,7 +43,7 @@ class Executor extends org.apache.mesos.Executor with Logging { RemoteActor.classLoader = getClass.getClassLoader // Initialize Spark environment (using system properties read above) - env = SparkEnv.createFromSystemProperties(false) + env = SparkEnv.createFromSystemProperties(false, false) SparkEnv.set(env) // Old stuff that isn't yet using env Broadcast.initialize(false) @@ -57,11 +61,11 @@ class Executor extends org.apache.mesos.Executor with Logging { override def reregistered(d: ExecutorDriver, s: SlaveInfo) {} - override def launchTask(d: ExecutorDriver, task: TaskInfo) { + override def launchTask(d: ExecutorDriver, task: MTaskInfo) { threadPool.execute(new TaskRunner(task, d)) } - class TaskRunner(info: TaskInfo, d: ExecutorDriver) + class TaskRunner(info: MTaskInfo, d: ExecutorDriver) extends Runnable { override def run() = { val tid = info.getTaskId.getValue @@ -74,11 +78,11 @@ class Executor extends org.apache.mesos.Executor with Logging { .setState(TaskState.TASK_RUNNING) .build()) try { + SparkEnv.set(env) + Thread.currentThread.setContextClassLoader(classLoader) Accumulators.clear - val task = ser.deserialize[Task[Any]](info.getData.toByteArray, classLoader) - for (gen <- task.generation) {// Update generation if any is set - env.mapOutputTracker.updateGeneration(gen) - } + val task = ser.deserialize[Task[Any]](info.getData.asReadOnlyByteBuffer, classLoader) + env.mapOutputTracker.updateGeneration(task.generation) val value = task.run(tid.toInt) val accumUpdates = Accumulators.values val result = new TaskResult(value, accumUpdates) @@ -105,9 +109,11 @@ class Executor extends org.apache.mesos.Executor with Logging { .setData(ByteString.copyFrom(ser.serialize(reason))) .build()) - // TODO: Handle errors in tasks less dramatically + // TODO: Should we exit the whole executor here? On the one hand, the failed task may + // have left some weird state around depending on when the exception was thrown, but on + // the other hand, maybe we could detect that when future tasks fail and exit then. 
logError("Exception in task ID " + tid, t) - System.exit(1) + //System.exit(1) } } } diff --git a/core/src/main/scala/spark/FetchFailedException.scala b/core/src/main/scala/spark/FetchFailedException.scala index a3c4e7873d7ac5b11468320008252c1d2b84a549..55512f4481af231aa13c7c4b629ccdcc6bd556b5 100644 --- a/core/src/main/scala/spark/FetchFailedException.scala +++ b/core/src/main/scala/spark/FetchFailedException.scala @@ -1,7 +1,9 @@ package spark +import spark.storage.BlockManagerId + class FetchFailedException( - val serverUri: String, + val bmAddress: BlockManagerId, val shuffleId: Int, val mapId: Int, val reduceId: Int, @@ -9,10 +11,10 @@ class FetchFailedException( extends Exception { override def getMessage(): String = - "Fetch failed: %s %d %d %d".format(serverUri, shuffleId, mapId, reduceId) + "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) override def getCause(): Throwable = cause def toTaskEndReason: TaskEndReason = - FetchFailed(serverUri, shuffleId, mapId, reduceId) + FetchFailed(bmAddress, shuffleId, mapId, reduceId) } diff --git a/core/src/main/scala/spark/JavaSerializer.scala b/core/src/main/scala/spark/JavaSerializer.scala index 80f615eeb0a942f183d63128a3adec119101fcbe..ec5c33d1df0f639289401f0c9d5891f9bc57d9be 100644 --- a/core/src/main/scala/spark/JavaSerializer.scala +++ b/core/src/main/scala/spark/JavaSerializer.scala @@ -1,6 +1,7 @@ package spark import java.io._ +import java.nio.ByteBuffer class JavaSerializationStream(out: OutputStream) extends SerializationStream { val objOut = new ObjectOutputStream(out) @@ -9,10 +10,11 @@ class JavaSerializationStream(out: OutputStream) extends SerializationStream { def close() { objOut.close() } } -class JavaDeserializationStream(in: InputStream) extends DeserializationStream { +class JavaDeserializationStream(in: InputStream, loader: ClassLoader) +extends DeserializationStream { val objIn = new ObjectInputStream(in) { override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader) + Class.forName(desc.getName, false, loader) } def readObject[T](): T = objIn.readObject().asInstanceOf[T] @@ -20,35 +22,36 @@ class JavaDeserializationStream(in: InputStream) extends DeserializationStream { } class JavaSerializerInstance extends SerializerInstance { - def serialize[T](t: T): Array[Byte] = { + def serialize[T](t: T): ByteBuffer = { val bos = new ByteArrayOutputStream() - val out = outputStream(bos) + val out = serializeStream(bos) out.writeObject(t) out.close() - bos.toByteArray + ByteBuffer.wrap(bos.toByteArray) } - def deserialize[T](bytes: Array[Byte]): T = { - val bis = new ByteArrayInputStream(bytes) - val in = inputStream(bis) + def deserialize[T](bytes: ByteBuffer): T = { + val bis = new ByteArrayInputStream(bytes.array()) + val in = deserializeStream(bis) in.readObject().asInstanceOf[T] } - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) - val ois = new ObjectInputStream(bis) { - override def resolveClass(desc: ObjectStreamClass) = - Class.forName(desc.getName, false, loader) - } - return ois.readObject.asInstanceOf[T] + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val bis = new ByteArrayInputStream(bytes.array()) + val in = deserializeStream(bis, loader) + in.readObject().asInstanceOf[T] } - def outputStream(s: OutputStream): SerializationStream = { + def serializeStream(s: OutputStream): SerializationStream = { new JavaSerializationStream(s) } - def 
inputStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s) + def deserializeStream(s: InputStream): DeserializationStream = { + new JavaDeserializationStream(s, currentThread.getContextClassLoader) + } + + def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { + new JavaDeserializationStream(s, loader) } } diff --git a/core/src/main/scala/spark/Job.scala b/core/src/main/scala/spark/Job.scala deleted file mode 100644 index b7b0361c62c34c0377737b0328fe131a35d772e7..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/Job.scala +++ /dev/null @@ -1,16 +0,0 @@ -package spark - -import org.apache.mesos._ -import org.apache.mesos.Protos._ - -/** - * Class representing a parallel job in MesosScheduler. Schedules the job by implementing various - * callbacks. - */ -abstract class Job(val runId: Int, val jobId: Int) { - def slaveOffer(s: Offer, availableCpus: Double): Option[TaskInfo] - - def statusUpdate(t: TaskStatus): Unit - - def error(message: String): Unit -} diff --git a/core/src/main/scala/spark/KryoSerializer.scala b/core/src/main/scala/spark/KryoSerializer.scala index 5693613d6d45804767aeeab09c8990cb43babf43..65d0532bd58dddaea498fd4d9169eecfc4dea470 100644 --- a/core/src/main/scala/spark/KryoSerializer.scala +++ b/core/src/main/scala/spark/KryoSerializer.scala @@ -12,6 +12,8 @@ import com.esotericsoftware.kryo.{Serializer => KSerializer} import com.esotericsoftware.kryo.serialize.ClassSerializer import de.javakaffee.kryoserializers.KryoReflectionFactorySupport +import spark.storage._ + /** * Zig-zag encoder used to write object sizes to serialization streams. * Based on Kryo's integer encoder. @@ -64,57 +66,90 @@ object ZigZag { } } -class KryoSerializationStream(kryo: Kryo, buf: ByteBuffer, out: OutputStream) +class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream) extends SerializationStream { val channel = Channels.newChannel(out) def writeObject[T](t: T) { - kryo.writeClassAndObject(buf, t) - ZigZag.writeInt(buf.position(), out) - buf.flip() - channel.write(buf) - buf.clear() + kryo.writeClassAndObject(threadBuffer, t) + ZigZag.writeInt(threadBuffer.position(), out) + threadBuffer.flip() + channel.write(threadBuffer) + threadBuffer.clear() } def flush() { out.flush() } def close() { out.close() } } -class KryoDeserializationStream(buf: ObjectBuffer, in: InputStream) +class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream) extends DeserializationStream { def readObject[T](): T = { val len = ZigZag.readInt(in) - buf.readClassAndObject(in, len).asInstanceOf[T] + objectBuffer.readClassAndObject(in, len).asInstanceOf[T] } def close() { in.close() } } class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance { - val buf = ks.threadBuf.get() + val kryo = ks.kryo + val threadBuffer = ks.threadBuffer.get() + val objectBuffer = ks.objectBuffer.get() - def serialize[T](t: T): Array[Byte] = { - buf.writeClassAndObject(t) + def serialize[T](t: T): ByteBuffer = { + // Write it to our thread-local scratch buffer first to figure out the size, then return a new + // ByteBuffer of the appropriate size + threadBuffer.clear() + kryo.writeClassAndObject(threadBuffer, t) + val newBuf = ByteBuffer.allocate(threadBuffer.position) + threadBuffer.flip() + newBuf.put(threadBuffer) + newBuf.flip() + newBuf } - def deserialize[T](bytes: Array[Byte]): T = { - buf.readClassAndObject(bytes).asInstanceOf[T] + def deserialize[T](bytes: ByteBuffer): T = { + 
kryo.readClassAndObject(bytes).asInstanceOf[T] } - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val oldClassLoader = ks.kryo.getClassLoader - ks.kryo.setClassLoader(loader) - val obj = buf.readClassAndObject(bytes).asInstanceOf[T] - ks.kryo.setClassLoader(oldClassLoader) + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val oldClassLoader = kryo.getClassLoader + kryo.setClassLoader(loader) + val obj = kryo.readClassAndObject(bytes).asInstanceOf[T] + kryo.setClassLoader(oldClassLoader) obj } - def outputStream(s: OutputStream): SerializationStream = { - new KryoSerializationStream(ks.kryo, ks.threadByteBuf.get(), s) + def serializeStream(s: OutputStream): SerializationStream = { + threadBuffer.clear() + new KryoSerializationStream(kryo, threadBuffer, s) + } + + def deserializeStream(s: InputStream): DeserializationStream = { + new KryoDeserializationStream(objectBuffer, s) } - def inputStream(s: InputStream): DeserializationStream = { - new KryoDeserializationStream(buf, s) + override def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { + threadBuffer.clear() + while (iterator.hasNext) { + val element = iterator.next() + // TODO: Do we also want to write the object's size? Doesn't seem necessary. + kryo.writeClassAndObject(threadBuffer, element) + } + val newBuf = ByteBuffer.allocate(threadBuffer.position) + threadBuffer.flip() + newBuf.put(threadBuffer) + newBuf.flip() + newBuf + } + + override def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { + buffer.rewind() + new Iterator[Any] { + override def hasNext: Boolean = buffer.remaining > 0 + override def next(): Any = kryo.readClassAndObject(buffer) + } } } @@ -126,20 +161,17 @@ trait KryoRegistrator { class KryoSerializer extends Serializer with Logging { val kryo = createKryo() - val bufferSize = - System.getProperty("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024 - val threadBuf = new ThreadLocal[ObjectBuffer] { + val objectBuffer = new ThreadLocal[ObjectBuffer] { override def initialValue = new ObjectBuffer(kryo, bufferSize) } - val threadByteBuf = new ThreadLocal[ByteBuffer] { + val threadBuffer = new ThreadLocal[ByteBuffer] { override def initialValue = ByteBuffer.allocate(bufferSize) } def createKryo(): Kryo = { - // This is used so we can serialize/deserialize objects without a zero-arg - // constructor. 
val kryo = new KryoReflectionFactorySupport() // Register some commonly used classes @@ -148,14 +180,20 @@ class KryoSerializer extends Serializer with Logging { Array(1), Array(1.0), Array(1.0f), Array(1L), Array(""), Array(("", "")), Array(new java.lang.Object), Array(1.toByte), Array(true), Array('c'), // Specialized Tuple2s - ("", ""), (1, 1), (1.0, 1.0), (1L, 1L), + ("", ""), ("", 1), (1, 1), (1.0, 1.0), (1L, 1L), (1, 1.0), (1.0, 1), (1L, 1.0), (1.0, 1L), (1, 1L), (1L, 1), // Scala collections List(1), mutable.ArrayBuffer(1), // Options and Either Some(1), Left(1), Right(1), // Higher-dimensional tuples - (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1) + (1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1), + None, + ByteBuffer.allocate(1), + StorageLevel.MEMORY_ONLY_DESER, + PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER), + GotBlock("1", ByteBuffer.allocate(1)), + GetBlock("1") ) for (obj <- toRegister) { kryo.register(obj.getClass) diff --git a/core/src/main/scala/spark/Logging.scala b/core/src/main/scala/spark/Logging.scala index 0d11ab9cbd836a5495f5392b942cb39ffd60e385..54bd57f6d3c94d2c17160f3ddaf38b1485f12e50 100644 --- a/core/src/main/scala/spark/Logging.scala +++ b/core/src/main/scala/spark/Logging.scala @@ -28,9 +28,11 @@ trait Logging { } // Log methods that take only a String - def logInfo(msg: => String) = if (log.isInfoEnabled) log.info(msg) + def logInfo(msg: => String) = if (log.isInfoEnabled /*&& msg.contains("job finished in")*/) log.info(msg) def logDebug(msg: => String) = if (log.isDebugEnabled) log.debug(msg) + + def logTrace(msg: => String) = if (log.isTraceEnabled) log.trace(msg) def logWarning(msg: => String) = if (log.isWarnEnabled) log.warn(msg) @@ -43,6 +45,9 @@ trait Logging { def logDebug(msg: => String, throwable: Throwable) = if (log.isDebugEnabled) log.debug(msg) + def logTrace(msg: => String, throwable: Throwable) = + if (log.isTraceEnabled) log.trace(msg) + def logWarning(msg: => String, throwable: Throwable) = if (log.isWarnEnabled) log.warn(msg, throwable) diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index a934c5a02fe30706ddb9d6ce7194743c91c40ca1..d938a6eb629867b0a45c9a4abbe24233a5947b5b 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -2,80 +2,80 @@ package spark import java.util.concurrent.ConcurrentHashMap -import scala.actors._ -import scala.actors.Actor._ -import scala.actors.remote._ +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ +import akka.util.duration._ + import scala.collection.mutable.HashSet +import spark.storage.BlockManagerId + sealed trait MapOutputTrackerMessage case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage case object StopMapOutputTracker extends MapOutputTrackerMessage -class MapOutputTrackerActor(serverUris: ConcurrentHashMap[Int, Array[String]]) -extends DaemonActor with Logging { - def act() { - val port = System.getProperty("spark.master.port").toInt - RemoteActor.alive(port) - RemoteActor.register('MapOutputTracker, self) - logInfo("Registered actor on port " + port) - - loop { - react { - case GetMapOutputLocations(shuffleId: Int) => - logInfo("Asked to get map output locations for shuffle " + shuffleId) - reply(serverUris.get(shuffleId)) - - case StopMapOutputTracker => - reply('OK) - exit() - } - } +class MapOutputTrackerActor(bmAddresses: ConcurrentHashMap[Int, Array[BlockManagerId]]) +extends Actor with Logging { + def 
receive = { + case GetMapOutputLocations(shuffleId: Int) => + logInfo("Asked to get map output locations for shuffle " + shuffleId) + self.reply(bmAddresses.get(shuffleId)) + + case StopMapOutputTracker => + logInfo("MapOutputTrackerActor stopped!") + self.reply(true) + self.exit() } } class MapOutputTracker(isMaster: Boolean) extends Logging { - var trackerActor: AbstractActor = null + val ip: String = System.getProperty("spark.master.host", "localhost") + val port: Int = System.getProperty("spark.master.port", "7077").toInt + val aName: String = "MapOutputTracker" - private var serverUris = new ConcurrentHashMap[Int, Array[String]] + private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] // Incremented every time a fetch fails so that client nodes know to clear // their cache of map output locations if this happens. private var generation: Long = 0 private var generationLock = new java.lang.Object - - if (isMaster) { - val tracker = new MapOutputTrackerActor(serverUris) - tracker.start() - trackerActor = tracker + + var trackerActor: ActorRef = if (isMaster) { + val actor = actorOf(new MapOutputTrackerActor(bmAddresses)) + remote.register(aName, actor) + logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port) + actor } else { - val host = System.getProperty("spark.master.host") - val port = System.getProperty("spark.master.port").toInt - trackerActor = RemoteActor.select(Node(host, port), 'MapOutputTracker) + remote.actorFor(aName, ip, port) } def registerShuffle(shuffleId: Int, numMaps: Int) { - if (serverUris.get(shuffleId) != null) { + if (bmAddresses.get(shuffleId) != null) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } - serverUris.put(shuffleId, new Array[String](numMaps)) + bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps)) } - def registerMapOutput(shuffleId: Int, mapId: Int, serverUri: String) { - var array = serverUris.get(shuffleId) + def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + var array = bmAddresses.get(shuffleId) array.synchronized { - array(mapId) = serverUri + array(mapId) = bmAddress } } - def registerMapOutputs(shuffleId: Int, locs: Array[String]) { - serverUris.put(shuffleId, Array[String]() ++ locs) + def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) { + bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs) + if (changeGeneration) { + incrementGeneration() + } } - def unregisterMapOutput(shuffleId: Int, mapId: Int, serverUri: String) { - var array = serverUris.get(shuffleId) + def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + var array = bmAddresses.get(shuffleId) if (array != null) { array.synchronized { - if (array(mapId) == serverUri) { + if (array(mapId) == bmAddress) { array(mapId) = null } } @@ -89,10 +89,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { val fetching = new HashSet[Int] // Called on possibly remote nodes to get the server URIs for a given shuffle - def getServerUris(shuffleId: Int): Array[String] = { - val locs = serverUris.get(shuffleId) + def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = { + val locs = bmAddresses.get(shuffleId) if (locs == null) { - logInfo("Don't have map outputs for " + shuffleId + ", fetching them") + logInfo("Don't have map outputs for shuffe " + shuffleId + ", fetching them") fetching.synchronized { if (fetching.contains(shuffleId)) { // Someone else is fetching it; wait for 
them to be done @@ -103,15 +103,17 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { case _ => } } - return serverUris.get(shuffleId) + return bmAddresses.get(shuffleId) } else { fetching += shuffleId } } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = (trackerActor !? GetMapOutputLocations(shuffleId)).asInstanceOf[Array[String]] - serverUris.put(shuffleId, fetched) + val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get + + logInfo("Got the output locations") + bmAddresses.put(shuffleId, fetched) fetching.synchronized { fetching -= shuffleId fetching.notifyAll() @@ -121,14 +123,10 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { return locs } } - - def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = { - "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId) - } def stop() { - trackerActor !? StopMapOutputTracker - serverUris.clear() + trackerActor !! StopMapOutputTracker + bmAddresses.clear() trackerActor = null } @@ -153,7 +151,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { generationLock.synchronized { if (newGen > generation) { logInfo("Updating generation to " + newGen + " and clearing cache") - serverUris = new ConcurrentHashMap[Int, Array[String]] + bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] generation = newGen } } diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 8b63d1aba1eeff4fd9a0c1fc99f37a87d0a9a7ec..ff6764e0a21d6f84d4dd1f0b7581451c12969565 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -4,14 +4,14 @@ import java.io.EOFException import java.net.URL import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong -import java.util.HashSet -import java.util.Random +import java.util.{HashMap => JHashMap} import java.util.Date import java.text.SimpleDateFormat +import scala.collection.Map import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.Map import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ import org.apache.hadoop.fs.Path import org.apache.hadoop.io.BytesWritable @@ -34,7 +34,9 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob} import org.apache.hadoop.mapreduce.TaskAttemptID import org.apache.hadoop.mapreduce.TaskAttemptContext -import SparkContext._ +import spark.SparkContext._ +import spark.partial.BoundedDouble +import spark.partial.PartialResult /** * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. 
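
The rewritten getServerAddresses above keeps the old double-checked fetch protocol: the first reducer that misses in bmAddresses adds the shuffle id to the fetching set and asks the tracker actor, while any later caller waits on that set and simply re-reads the map once it is notified. The standalone sketch below is not part of the patch; it restates that protocol with invented names (OnceCache, compute) so the locking is easier to follow.

    import scala.collection.mutable

    // Standalone restatement of the "one caller fetches, everyone else waits" pattern
    // used by MapOutputTracker.getServerAddresses; all names here are illustrative.
    class OnceCache[K, V](compute: K => V) {
      private val results  = new mutable.HashMap[K, V]
      private val fetching = new mutable.HashSet[K]

      def get(key: K): V = {
        fetching.synchronized {
          while (fetching.contains(key)) {
            fetching.wait()                   // someone else is fetching it; wait for them
          }
          results.get(key) match {
            case Some(v) => return v          // it was fetched while we were waiting
            case None    => fetching += key   // we won the race; we do the fetch
          }
        }
        val value = compute(key)              // the expensive call runs outside the lock
        fetching.synchronized {
          results(key) = value
          fetching -= key
          fetching.notifyAll()                // wake up any waiters blocked above
        }
        value                                 // (failure handling omitted for brevity)
      }
    }

In the tracker itself the role of compute is played by the (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get ask, and the two collections correspond to bmAddresses and fetching.
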
@@ -43,19 +45,6 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( self: RDD[(K, V)]) extends Logging with Serializable { - - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = { - def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = { - for ((k, v) <- m2) { - m1.get(k) match { - case None => m1(k) = v - case Some(w) => m1(k) = func(w, v) - } - } - return m1 - } - self.map(pair => HashMap(pair)).reduce(mergeMaps) - } def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, @@ -77,6 +66,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest]( def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = { combineByKey[V]((v: V) => v, func, func, numSplits) } + + def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = { + def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = { + val map = new JHashMap[K, V] + for ((k, v) <- iter) { + val old = map.get(k) + map.put(k, if (old == null) v else func(old, v)) + } + Iterator(map) + } + + def mergeMaps(m1: JHashMap[K, V], m2: JHashMap[K, V]): JHashMap[K, V] = { + for ((k, v) <- m2) { + val old = m1.get(k) + m1.put(k, if (old == null) v else func(old, v)) + } + return m1 + } + + self.mapPartitions(reducePartition).reduce(mergeMaps) + } + + // Alias for backwards compatibility + def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) + + // TODO: This should probably be a distributed version + def countByKey(): Map[K, Long] = self.map(_._1).countByValue() + + // TODO: This should probably be a distributed version + def countByKeyApprox(timeout: Long, confidence: Double = 0.95) + : PartialResult[Map[K, BoundedDouble]] = { + self.map(_._1).countByValueApprox(timeout, confidence) + } def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = { def createCombiner(v: V) = ArrayBuffer(v) diff --git a/core/src/main/scala/spark/ParallelShuffleFetcher.scala b/core/src/main/scala/spark/ParallelShuffleFetcher.scala deleted file mode 100644 index 19eb288e8460e599b501091f75b178cea388501a..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/ParallelShuffleFetcher.scala +++ /dev/null @@ -1,119 +0,0 @@ -package spark - -import java.io.ByteArrayInputStream -import java.io.EOFException -import java.net.URL -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.atomic.AtomicReference - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import it.unimi.dsi.fastutil.io.FastBufferedInputStream - - -class ParallelShuffleFetcher extends ShuffleFetcher with Logging { - val parallelFetches = System.getProperty("spark.parallel.fetches", "3").toInt - - def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { - logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - - // Figure out a list of input IDs (mapper IDs) for each server - val ser = SparkEnv.get.serializer.newInstance() - val inputsByUri = new HashMap[String, ArrayBuffer[Int]] - val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId) - for ((serverUri, index) <- serverUris.zipWithIndex) { - inputsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index - } - - // Randomize them and put them in a LinkedBlockingQueue - val serverQueue = new LinkedBlockingQueue[(String, ArrayBuffer[Int])] - for (pair <- Utils.randomize(inputsByUri)) { - serverQueue.put(pair) - 
} - - // Create a queue to hold the fetched data - val resultQueue = new LinkedBlockingQueue[Array[Byte]] - - // Atomic variables to communicate failures and # of fetches done - var failure = new AtomicReference[FetchFailedException](null) - - // Start multiple threads to do the fetching (TODO: may be possible to do it asynchronously) - for (i <- 0 until parallelFetches) { - new Thread("Fetch thread " + i + " for reduce " + reduceId) { - override def run() { - while (true) { - val pair = serverQueue.poll() - if (pair == null) - return - val (serverUri, inputIds) = pair - //logInfo("Pulled out server URI " + serverUri) - for (i <- inputIds) { - if (failure.get != null) - return - val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId) - logInfo("Starting HTTP request for " + url) - try { - val conn = new URL(url).openConnection() - conn.connect() - val len = conn.getContentLength() - if (len == -1) { - throw new SparkException("Content length was not specified by server") - } - val buf = new Array[Byte](len) - val in = new FastBufferedInputStream(conn.getInputStream()) - var pos = 0 - while (pos < len) { - val n = in.read(buf, pos, len-pos) - if (n == -1) { - throw new SparkException("EOF before reading the expected " + len + " bytes") - } else { - pos += n - } - } - // Done reading everything - resultQueue.put(buf) - in.close() - } catch { - case e: Exception => - logError("Fetch failed from " + url, e) - failure.set(new FetchFailedException(serverUri, shuffleId, i, reduceId, e)) - return - } - } - //logInfo("Done with server URI " + serverUri) - } - } - }.start() - } - - // Wait for results from the threads (either a failure or all servers done) - var resultsDone = 0 - var totalResults = inputsByUri.map{case (uri, inputs) => inputs.size}.sum - while (failure.get == null && resultsDone < totalResults) { - try { - val result = resultQueue.poll(100, TimeUnit.MILLISECONDS) - if (result != null) { - //logInfo("Pulled out a result") - val in = ser.inputStream(new ByteArrayInputStream(result)) - try { - while (true) { - val pair = in.readObject().asInstanceOf[(K, V)] - func(pair._1, pair._2) - } - } catch { - case e: EOFException => {} // TODO: cleaner way to detect EOF, such as a sentinel - } - resultsDone += 1 - //logInfo("Results done = " + resultsDone) - } - } catch { case e: InterruptedException => {} } - } - if (failure.get != null) { - throw failure.get - } - } -} diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala index ac61fe3b54526da22a0d812a485da167651a686e..8f3f0f5e15beca7c662feebb9f084120c0cd553f 100644 --- a/core/src/main/scala/spark/Partitioner.scala +++ b/core/src/main/scala/spark/Partitioner.scala @@ -70,4 +70,3 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( false } } - diff --git a/core/src/main/scala/spark/PipedRDD.scala b/core/src/main/scala/spark/PipedRDD.scala index 8a5de3d7e96055ca839b476e661b0a9ed10035ad..9e0a01b5f9fb0357ca5eb0f599ccc2e567aef83b 100644 --- a/core/src/main/scala/spark/PipedRDD.scala +++ b/core/src/main/scala/spark/PipedRDD.scala @@ -3,6 +3,7 @@ package spark import java.io.PrintWriter import java.util.StringTokenizer +import scala.collection.Map import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index fa53d9be2c045de5bd0ba15a0597fdeb75761b74..22dcc27bad5ea303bc2a649a8dc6fc511334f91f 100644 --- a/core/src/main/scala/spark/RDD.scala 
+++ b/core/src/main/scala/spark/RDD.scala @@ -4,11 +4,14 @@ import java.io.EOFException import java.net.URL import java.io.ObjectInputStream import java.util.concurrent.atomic.AtomicLong -import java.util.HashSet import java.util.Random import java.util.Date +import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer +import scala.collection.Map +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions.mapAsScalaMap import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.NullWritable @@ -22,6 +25,14 @@ import org.apache.hadoop.mapred.OutputFormat import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.mapred.TextOutputFormat +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +import spark.partial.BoundedDouble +import spark.partial.CountEvaluator +import spark.partial.GroupedCountEvaluator +import spark.partial.PartialResult +import spark.storage.StorageLevel + import SparkContext._ /** @@ -61,19 +72,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial // Get a unique ID for this RDD val id = sc.newRddId() - // Variables relating to caching - private var shouldCache = false + // Variables relating to persistence + private var storageLevel: StorageLevel = StorageLevel.NONE - // Change this RDD's caching - def cache(): RDD[T] = { - shouldCache = true + // Change this RDD's storage level + def persist(newLevel: StorageLevel): RDD[T] = { + // TODO: Handle changes of StorageLevel + if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) { + throw new UnsupportedOperationException( + "Cannot change storage level of an RDD after it was already assigned a level") + } + storageLevel = newLevel this } + + // Turn on the default caching level for this RDD + def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER) + + // Turn on the default caching level for this RDD + def cache(): RDD[T] = persist() + + def getStorageLevel = storageLevel // Read this RDD; will read from cache if applicable, or otherwise compute final def iterator(split: Split): Iterator[T] = { - if (shouldCache) { - SparkEnv.get.cacheTracker.getOrCompute[T](this, split) + if (storageLevel != StorageLevel.NONE) { + SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel) } else { compute(split) } @@ -162,6 +186,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial Array.concat(results: _*) } + def toArray(): Array[T] = collect() + def reduce(f: (T, T) => T): T = { val cleanF = sc.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { @@ -222,7 +248,67 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial }).sum } - def toArray(): Array[T] = collect() + /** + * Approximate version of count() that returns a potentially incomplete result after a timeout. + */ + def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { + val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next + } + result + } + val evaluator = new CountEvaluator(splits.size, confidence) + sc.runApproximateJob(this, countElements, evaluator, timeout) + } + + /** + * Count elements equal to each value, returning a map of (value, count) pairs. The final combine + * step happens locally on the master, equivalent to running a single reduce task. + * + * TODO: This should perhaps be distributed by default. 
+ */ + def countByValue(): Map[T, Long] = { + def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { + val map = new OLMap[T] + while (iter.hasNext) { + val v = iter.next() + map.put(v, map.getLong(v) + 1L) + } + Iterator(map) + } + def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { + val iter = m2.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) + } + return m1 + } + val myResult = mapPartitions(countPartition).reduce(mergeMaps) + myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map + } + + /** + * Approximate version of countByValue(). + */ + def countByValueApprox( + timeout: Long, + confidence: Double = 0.95 + ): PartialResult[Map[T, BoundedDouble]] = { + val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => + val map = new OLMap[T] + while (iter.hasNext) { + val v = iter.next() + map.put(v, map.getLong(v) + 1L) + } + map + } + val evaluator = new GroupedCountEvaluator[T](splits.size, confidence) + sc.runApproximateJob(this, countPartition, evaluator, timeout) + } /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so diff --git a/core/src/main/scala/spark/Scheduler.scala b/core/src/main/scala/spark/Scheduler.scala deleted file mode 100644 index 6c7e569313b9f6a325b39c1606700715b90c56d9..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/Scheduler.scala +++ /dev/null @@ -1,27 +0,0 @@ -package spark - -/** - * Scheduler trait, implemented by both MesosScheduler and LocalScheduler. - */ -private trait Scheduler { - def start() - - // Wait for registration with Mesos. - def waitForRegister() - - /** - * Run a function on some partitions of an RDD, returning an array of results. The allowLocal - * flag specifies whether the scheduler is allowed to run the job on the master machine rather - * than shipping it to the cluster, for actions that create short jobs such as first() and take(). - */ - def runJob[T, U: ClassManifest]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean): Array[U] - - def stop() - - // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. 
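
reduceByKeyLocally (added to PairRDDFunctions earlier in this patch) and the new countByValue above share one shape: build a small hash map per partition inside mapPartitions, then merge the per-partition maps on the driver with reduce. The snippet below is an illustrative restatement of that shape using ordinary Scala collections in place of JHashMap and Object2LongOpenHashMap; the names are invented for the example.

    import scala.collection.mutable

    // Count occurrences the way countByValue does: one map per partition, then merge.
    def countPartition[T](iter: Iterator[T]): mutable.HashMap[T, Long] = {
      val counts = new mutable.HashMap[T, Long]
      for (v <- iter) {
        counts(v) = counts.getOrElse(v, 0L) + 1L
      }
      counts
    }

    def mergeMaps[T](m1: mutable.HashMap[T, Long], m2: mutable.HashMap[T, Long]): mutable.HashMap[T, Long] = {
      for ((k, c) <- m2) {
        m1(k) = m1.getOrElse(k, 0L) + c
      }
      m1                                      // merge into m1 to avoid allocating a third map
    }

    // Pretend each Iterator is one RDD partition.
    val partitions = Seq(Iterator("a", "b", "a"), Iterator("b", "c"))
    val counts = partitions.map(p => countPartition(p)).reduce((a, b) => mergeMaps(a, b))
    // counts == Map("a" -> 2, "b" -> 2, "c" -> 1)

reduceByKeyLocally is the same pattern with a user-supplied func in place of the + 1L, and countByValueApprox reuses the per-partition counting but hands the maps to a GroupedCountEvaluator through runApproximateJob instead of merging them exactly, which is what lets it return a PartialResult within the timeout.
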
- def defaultParallelism(): Int -} diff --git a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala index b213ca9dcbde6c70ad6ef03ca4c2150a84a1390f..9da73c4b028c8f70a085f5ec22a5891516c575d7 100644 --- a/core/src/main/scala/spark/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/spark/SequenceFileRDDFunctions.scala @@ -44,7 +44,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla } // TODO: use something like WritableConverter to avoid reflection } - c.asInstanceOf[Class[ _ <: Writable]] + c.asInstanceOf[Class[_ <: Writable]] } def saveAsSequenceFile(path: String) { diff --git a/core/src/main/scala/spark/Serializer.scala b/core/src/main/scala/spark/Serializer.scala index 2429bbfeb927445e887359465d54a8c8aafcade8..61a70beaf1fd73566443f8cf7e05c2317eceafd4 100644 --- a/core/src/main/scala/spark/Serializer.scala +++ b/core/src/main/scala/spark/Serializer.scala @@ -1,6 +1,12 @@ package spark -import java.io.{InputStream, OutputStream} +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.Channels + +import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream + +import spark.util.ByteBufferInputStream /** * A serializer. Because some serialization libraries are not thread safe, this class is used to @@ -14,11 +20,31 @@ trait Serializer { * An instance of the serializer, for use by one thread at a time. */ trait SerializerInstance { - def serialize[T](t: T): Array[Byte] - def deserialize[T](bytes: Array[Byte]): T - def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T - def outputStream(s: OutputStream): SerializationStream - def inputStream(s: InputStream): DeserializationStream + def serialize[T](t: T): ByteBuffer + + def deserialize[T](bytes: ByteBuffer): T + + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T + + def serializeStream(s: OutputStream): SerializationStream + + def deserializeStream(s: InputStream): DeserializationStream + + def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { + // Default implementation uses serializeStream + val stream = new FastByteArrayOutputStream() + serializeStream(stream).writeAll(iterator) + val buffer = ByteBuffer.allocate(stream.position.toInt) + buffer.put(stream.array, 0, stream.position.toInt) + buffer.flip() + buffer + } + + def deserializeMany(buffer: ByteBuffer): Iterator[Any] = { + // Default implementation uses deserializeStream + buffer.rewind() + deserializeStream(new ByteBufferInputStream(buffer)).toIterator + } } /** @@ -28,6 +54,13 @@ trait SerializationStream { def writeObject[T](t: T): Unit def flush(): Unit def close(): Unit + + def writeAll[T](iter: Iterator[T]): SerializationStream = { + while (iter.hasNext) { + writeObject(iter.next()) + } + this + } } /** @@ -36,4 +69,45 @@ trait SerializationStream { trait DeserializationStream { def readObject[T](): T def close(): Unit + + /** + * Read the elements of this stream through an iterator. This can only be called once, as + * reading each element will consume data from the input source. 
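
The new serializeMany and deserializeMany defaults above are layered on serializeStream/deserializeStream: write every element of an iterator into an in-memory stream, hand the bytes back as a ByteBuffer, and read them back using EOFException as the end-of-stream sentinel (the same trick DeserializationStream.toIterator uses just below). The following is an illustrative, self-contained equivalent built on plain JDK object streams rather than FastByteArrayOutputStream and ByteBufferInputStream; it is a sketch of the idea, not the trait's actual code.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, EOFException, ObjectInputStream, ObjectOutputStream}
    import java.nio.ByteBuffer

    // Iterator -> ByteBuffer, mirroring the default serializeMany/writeAll loop.
    def serializeMany(items: Iterator[AnyRef]): ByteBuffer = {
      val bytes = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(bytes)
      items.foreach(item => out.writeObject(item))   // writeAll() in SerializationStream is this loop
      out.flush()
      ByteBuffer.wrap(bytes.toByteArray)
    }

    // ByteBuffer -> Iterator, treating EOFException as "no more elements".
    def deserializeMany(buffer: ByteBuffer): Iterator[AnyRef] = {
      val bytes = new Array[Byte](buffer.remaining)
      buffer.get(bytes)
      val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
      new Iterator[AnyRef] {
        private var nextValue: AnyRef = null
        private var gotNext  = false
        private var finished = false
        private def advance() {
          if (!gotNext && !finished) {
            try {
              nextValue = in.readObject()
            } catch {
              case e: EOFException =>
                finished = true
                in.close()
            }
            gotNext = true
          }
        }
        def hasNext: Boolean = { advance(); !finished }
        def next(): AnyRef = {
          advance()
          if (finished) throw new NoSuchElementException("End of stream")
          gotNext = false
          nextValue
        }
      }
    }

    val buf  = serializeMany(Iterator("a", "b", "c"))
    val back = deserializeMany(buf).toList             // List(a, b, c)

The trait's real defaults differ only in plumbing: serializeMany copies the FastByteArrayOutputStream into a freshly allocated ByteBuffer and flips it, and deserializeMany rewinds the incoming buffer and wraps it in a ByteBufferInputStream before handing it to deserializeStream.
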
+ */ + def toIterator: Iterator[Any] = new Iterator[Any] { + var gotNext = false + var finished = false + var nextValue: Any = null + + private def getNext() { + try { + nextValue = readObject[Any]() + } catch { + case eof: EOFException => + finished = true + } + gotNext = true + } + + override def hasNext: Boolean = { + if (!gotNext) { + getNext() + } + if (finished) { + close() + } + !finished + } + + override def next(): Any = { + if (!gotNext) { + getNext() + } + if (finished) { + throw new NoSuchElementException("End of stream") + } + gotNext = false + nextValue + } + } } diff --git a/core/src/main/scala/spark/SerializingCache.scala b/core/src/main/scala/spark/SerializingCache.scala deleted file mode 100644 index 3d192f24034a0f5a59a7247bf2850ba29efbbc80..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/SerializingCache.scala +++ /dev/null @@ -1,26 +0,0 @@ -package spark - -import java.io._ - -/** - * Wrapper around a BoundedMemoryCache that stores serialized objects as byte arrays in order to - * reduce storage cost and GC overhead - */ -class SerializingCache extends Cache with Logging { - val bmc = new BoundedMemoryCache - - override def put(datasetId: Any, partition: Int, value: Any): CachePutResponse = { - val ser = SparkEnv.get.serializer.newInstance() - bmc.put(datasetId, partition, ser.serialize(value)) - } - - override def get(datasetId: Any, partition: Int): Any = { - val bytes = bmc.get(datasetId, partition) - if (bytes != null) { - val ser = SparkEnv.get.serializer.newInstance() - return ser.deserialize(bytes.asInstanceOf[Array[Byte]]) - } else { - return null - } - } -} diff --git a/core/src/main/scala/spark/ShuffleMapTask.scala b/core/src/main/scala/spark/ShuffleMapTask.scala deleted file mode 100644 index 5fc59af06c039f6d74638c63cea13ad824058e40..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/ShuffleMapTask.scala +++ /dev/null @@ -1,56 +0,0 @@ -package spark - -import java.io.BufferedOutputStream -import java.io.FileOutputStream -import java.io.ObjectOutputStream -import java.util.{HashMap => JHashMap} - -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -class ShuffleMapTask( - runId: Int, - stageId: Int, - rdd: RDD[_], - dep: ShuffleDependency[_,_,_], - val partition: Int, - locs: Seq[String]) - extends DAGTask[String](runId, stageId) - with Logging { - - val split = rdd.splits(partition) - - override def run (attemptId: Int): String = { - val numOutputSplits = dep.partitioner.numPartitions - val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]] - val partitioner = dep.partitioner.asInstanceOf[Partitioner] - val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any]) - for (elem <- rdd.iterator(split)) { - val (k, v) = elem.asInstanceOf[(Any, Any)] - var bucketId = partitioner.getPartition(k) - val bucket = buckets(bucketId) - var existing = bucket.get(k) - if (existing == null) { - bucket.put(k, aggregator.createCombiner(v)) - } else { - bucket.put(k, aggregator.mergeValue(existing, v)) - } - } - val ser = SparkEnv.get.serializer.newInstance() - for (i <- 0 until numOutputSplits) { - val file = SparkEnv.get.shuffleManager.getOutputFile(dep.shuffleId, partition, i) - val out = ser.outputStream(new FastBufferedOutputStream(new FileOutputStream(file))) - val iter = buckets(i).entrySet().iterator() - while (iter.hasNext()) { - val entry = iter.next() - out.writeObject((entry.getKey, entry.getValue)) - } - // TODO: have some kind of EOF marker - out.close() - } - return 
SparkEnv.get.shuffleManager.getServerUri - } - - override def preferredLocations: Seq[String] = locs - - override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) -} diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala index 5efc8cf50b8ef27154c59a2bf00bd7a3d2220114..5434197ecad3330fb000b6c5a3238453e16a3b19 100644 --- a/core/src/main/scala/spark/ShuffledRDD.scala +++ b/core/src/main/scala/spark/ShuffledRDD.scala @@ -8,7 +8,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split { } class ShuffledRDD[K, V, C]( - parent: RDD[(K, V)], + @transient parent: RDD[(K, V)], aggregator: Aggregator[K, V, C], part : Partitioner) extends RDD[(K, C)](parent.context) { diff --git a/core/src/main/scala/spark/SimpleShuffleFetcher.scala b/core/src/main/scala/spark/SimpleShuffleFetcher.scala deleted file mode 100644 index 196c64cf1fb76758c9d1251dc296ddcb58d863cd..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/SimpleShuffleFetcher.scala +++ /dev/null @@ -1,46 +0,0 @@ -package spark - -import java.io.EOFException -import java.net.URL - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap - -import it.unimi.dsi.fastutil.io.FastBufferedInputStream - -class SimpleShuffleFetcher extends ShuffleFetcher with Logging { - def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) { - logInfo("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val ser = SparkEnv.get.serializer.newInstance() - val splitsByUri = new HashMap[String, ArrayBuffer[Int]] - val serverUris = SparkEnv.get.mapOutputTracker.getServerUris(shuffleId) - for ((serverUri, index) <- serverUris.zipWithIndex) { - splitsByUri.getOrElseUpdate(serverUri, ArrayBuffer()) += index - } - for ((serverUri, inputIds) <- Utils.randomize(splitsByUri)) { - for (i <- inputIds) { - try { - val url = "%s/shuffle/%d/%d/%d".format(serverUri, shuffleId, i, reduceId) - // TODO: multithreaded fetch - // TODO: would be nice to retry multiple times - val inputStream = ser.inputStream( - new FastBufferedInputStream(new URL(url).openStream())) - try { - while (true) { - val pair = inputStream.readObject().asInstanceOf[(K, V)] - func(pair._1, pair._2) - } - } finally { - inputStream.close() - } - } catch { - case e: EOFException => {} // We currently assume EOF means we read the whole thing - case other: Exception => { - logError("Fetch failed", other) - throw new FetchFailedException(serverUri, shuffleId, i, reduceId, other) - } - } - } - } - } -} diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 6e019d6e7f10c345bb79a7452124384e46a8c12b..7a9a70fee0111475ab02993b921405b4bea63af9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,6 +3,9 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger +import akka.actor.Actor +import akka.actor.Actor._ + import scala.actors.remote.RemoteActor import scala.collection.mutable.ArrayBuffer @@ -32,6 +35,15 @@ import org.apache.mesos.MesosNativeLibrary import spark.broadcast._ +import spark.partial.ApproximateEvaluator +import spark.partial.PartialResult + +import spark.scheduler.DAGScheduler +import spark.scheduler.TaskScheduler +import spark.scheduler.local.LocalScheduler +import spark.scheduler.mesos.MesosScheduler +import spark.scheduler.mesos.CoarseMesosScheduler + class SparkContext( master: String, frameworkName: String, @@ -54,14 
+66,19 @@ class SparkContext( if (RemoteActor.classLoader == null) { RemoteActor.classLoader = getClass.getClassLoader } + + remote.start(System.getProperty("spark.master.host"), + System.getProperty("spark.master.port").toInt) + private val isLocal = master.startsWith("local") // TODO: better check for local + // Create the Spark execution environment (cache, map output tracker, etc) - val env = SparkEnv.createFromSystemProperties(true) + val env = SparkEnv.createFromSystemProperties(true, isLocal) SparkEnv.set(env) Broadcast.initialize(true) // Create and start the scheduler - private var scheduler: Scheduler = { + private var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks @@ -74,13 +91,17 @@ class SparkContext( case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => new LocalScheduler(threads.toInt, maxFailures.toInt) case _ => - MesosNativeLibrary.load() - new MesosScheduler(this, master, frameworkName) + System.loadLibrary("mesos") + if (System.getProperty("spark.mesos.coarse", "false") == "true") { + new CoarseMesosScheduler(this, master, frameworkName) + } else { + new MesosScheduler(this, master, frameworkName) + } } } - scheduler.start() + taskScheduler.start() - private val isLocal = scheduler.isInstanceOf[LocalScheduler] + private var dagScheduler = new DAGScheduler(taskScheduler) // Methods for creating RDDs @@ -237,19 +258,21 @@ class SparkContext( // Stop the SparkContext def stop() { - scheduler.stop() - scheduler = null + dagScheduler.stop() + dagScheduler = null + taskScheduler = null // TODO: Broadcast.stop(), Cache.stop()? env.mapOutputTracker.stop() env.cacheTracker.stop() env.shuffleFetcher.stop() env.shuffleManager.stop() + env.connectionManager.stop() SparkEnv.set(null) } - // Wait for the scheduler to be registered + // Wait for the scheduler to be registered with the cluster manager def waitForRegister() { - scheduler.waitForRegister() + taskScheduler.waitForRegister() } // Get Spark's home location from either a value set through the constructor, @@ -281,7 +304,7 @@ class SparkContext( ): Array[U] = { logInfo("Starting job...") val start = System.nanoTime - val result = scheduler.runJob(rdd, func, partitions, allowLocal) + val result = dagScheduler.runJob(rdd, func, partitions, allowLocal) logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") result } @@ -306,6 +329,22 @@ class SparkContext( runJob(rdd, func, 0 until rdd.splits.size, false) } + /** + * Run a job that can return approximate results. + */ + def runApproximateJob[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + timeout: Long + ): PartialResult[R] = { + logInfo("Starting job...") + val start = System.nanoTime + val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout) + logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s") + result + } + // Clean a closure to make it ready to serialized and send to tasks // (removes unreferenced variables in $outer's, updates REPL variables) private[spark] def clean[F <: AnyRef](f: F): F = { @@ -314,7 +353,7 @@ class SparkContext( } // Default level of parallelism to use when not given by user (e.g. 
for reduce tasks) - def defaultParallelism: Int = scheduler.defaultParallelism + def defaultParallelism: Int = taskScheduler.defaultParallelism // Default min number of splits for Hadoop RDDs when not given by user def defaultMinSplits: Int = math.min(defaultParallelism, 2) @@ -349,15 +388,23 @@ object SparkContext { } // TODO: Add AccumulatorParams for other types, e.g. lists and strings + implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = new PairRDDFunctions(rdd) - - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) = + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( + rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = + implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( + rdd: RDD[(K, V)]) = new OrderedRDDFunctions(rdd) + implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + + implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) + // Implicit conversions to common Writable types, for saveAsSequenceFile implicit def intToIntWritable(i: Int) = new IntWritable(i) diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index cd752f8b6597e6feb97a1d1e582070dae745f628..897a5ef82d0913cf3d263d0d7db4e6986c4387d9 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,14 +1,26 @@ package spark +import akka.actor.Actor + +import spark.storage.BlockManager +import spark.storage.BlockManagerMaster +import spark.network.ConnectionManager + class SparkEnv ( - val cache: Cache, - val serializer: Serializer, - val closureSerializer: Serializer, - val cacheTracker: CacheTracker, - val mapOutputTracker: MapOutputTracker, - val shuffleFetcher: ShuffleFetcher, - val shuffleManager: ShuffleManager -) + val cache: Cache, + val serializer: Serializer, + val closureSerializer: Serializer, + val cacheTracker: CacheTracker, + val mapOutputTracker: MapOutputTracker, + val shuffleFetcher: ShuffleFetcher, + val shuffleManager: ShuffleManager, + val blockManager: BlockManager, + val connectionManager: ConnectionManager + ) { + + /** No-parameter constructor for unit tests. 
*/ + def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null) +} object SparkEnv { private val env = new ThreadLocal[SparkEnv] @@ -21,36 +33,55 @@ object SparkEnv { env.get() } - def createFromSystemProperties(isMaster: Boolean): SparkEnv = { - val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") - val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] - - val serializerClass = System.getProperty("spark.serializer", "spark.JavaSerializer") + def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = { + val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer") val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] + + BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal) + + var blockManager = new BlockManager(serializer) + + val connectionManager = blockManager.connectionManager + + val shuffleManager = new ShuffleManager() val closureSerializerClass = System.getProperty("spark.closure.serializer", "spark.JavaSerializer") val closureSerializer = Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer] + val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") + val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] - val cacheTracker = new CacheTracker(isMaster, cache) + val cacheTracker = new CacheTracker(isMaster, blockManager) + blockManager.cacheTracker = cacheTracker val mapOutputTracker = new MapOutputTracker(isMaster) val shuffleFetcherClass = - System.getProperty("spark.shuffle.fetcher", "spark.SimpleShuffleFetcher") + System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") val shuffleFetcher = Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher] - val shuffleMgr = new ShuffleManager() + /* + if (System.getProperty("spark.stream.distributed", "false") == "true") { + val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]] + if (isLocal || !isMaster) { + (new Thread() { + override def run() { + println("Wait started") + Thread.sleep(60000) + println("Wait ended") + val receiverClass = Class.forName("spark.stream.TestStreamReceiver4") + val constructor = receiverClass.getConstructor(blockManagerClass) + val receiver = constructor.newInstance(blockManager) + receiver.asInstanceOf[Thread].start() + } + }).start() + } + } + */ - new SparkEnv( - cache, - serializer, - closureSerializer, - cacheTracker, - mapOutputTracker, - shuffleFetcher, - shuffleMgr) + new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher, + shuffleManager, blockManager, connectionManager) } } diff --git a/core/src/main/scala/spark/Stage.scala b/core/src/main/scala/spark/Stage.scala deleted file mode 100644 index 9452ea3a8e57db93c4cc31744a80bef8b3dfbd15..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/Stage.scala +++ /dev/null @@ -1,41 +0,0 @@ -package spark - -class Stage( - val id: Int, - val rdd: RDD[_], - val shuffleDep: Option[ShuffleDependency[_,_,_]], - val parents: List[Stage]) { - - val isShuffleMap = shuffleDep != None - val numPartitions = rdd.splits.size - val outputLocs = Array.fill[List[String]](numPartitions)(Nil) - var numAvailableOutputs = 0 - - def isAvailable: Boolean = { - if (parents.size == 0 && !isShuffleMap) { - true - } else { - numAvailableOutputs == numPartitions - } - } - - def addOutputLoc(partition: Int, 
host: String) { - val prevList = outputLocs(partition) - outputLocs(partition) = host :: prevList - if (prevList == Nil) - numAvailableOutputs += 1 - } - - def removeOutputLoc(partition: Int, host: String) { - val prevList = outputLocs(partition) - val newList = prevList.filterNot(_ == host) - outputLocs(partition) = newList - if (prevList != Nil && newList == Nil) { - numAvailableOutputs -= 1 - } - } - - override def toString = "Stage " + id - - override def hashCode(): Int = id -} diff --git a/core/src/main/scala/spark/Task.scala b/core/src/main/scala/spark/Task.scala deleted file mode 100644 index bc3b3743447bda9d887bbbe970beb2ef52dbf38e..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/Task.scala +++ /dev/null @@ -1,9 +0,0 @@ -package spark - -class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable - -abstract class Task[T] extends Serializable { - def run(id: Int): T - def preferredLocations: Seq[String] = Nil - def generation: Option[Long] = None -} diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala new file mode 100644 index 0000000000000000000000000000000000000000..7a6214aab6648f6e7f5670b9839f3582dbe628bb --- /dev/null +++ b/core/src/main/scala/spark/TaskContext.scala @@ -0,0 +1,3 @@ +package spark + +class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Int) extends Serializable diff --git a/core/src/main/scala/spark/TaskEndReason.scala b/core/src/main/scala/spark/TaskEndReason.scala new file mode 100644 index 0000000000000000000000000000000000000000..6e4eb25ed44ff07e94085ebaa0d01c736a2839ed --- /dev/null +++ b/core/src/main/scala/spark/TaskEndReason.scala @@ -0,0 +1,16 @@ +package spark + +import spark.storage.BlockManagerId + +/** + * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry + * tasks several times for "ephemeral" failures, and only report back failures that require some + * old stages to be resubmitted, such as shuffle map fetch failures. + */ +sealed trait TaskEndReason + +case object Success extends TaskEndReason +case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it +case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason +case class ExceptionFailure(exception: Throwable) extends TaskEndReason +case class OtherFailure(message: String) extends TaskEndReason diff --git a/core/src/main/scala/spark/TaskResult.scala b/core/src/main/scala/spark/TaskResult.scala deleted file mode 100644 index 2b7fd1a4b225e74dae4da46ad14d8b2cba0a87e9..0000000000000000000000000000000000000000 --- a/core/src/main/scala/spark/TaskResult.scala +++ /dev/null @@ -1,8 +0,0 @@ -package spark - -import scala.collection.mutable.Map - -// Task result. Also contains updates to accumulator variables. 
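
TaskEndReason (one of the new files above) is a sealed hierarchy, so scheduler code can handle task completion with an exhaustive pattern match. The handler below is purely hypothetical and only shows the shape of such a match; the patch's real handling lives in the new scheduler package, which is not part of this hunk.

    import spark._

    // Hypothetical status handler; the comments describe typical handling for each case,
    // not the patch's actual scheduler logic.
    def handleTaskEnd(reason: TaskEndReason) {
      reason match {
        case Success =>
          ()  // task produced a result; record it
        case Resubmitted =>
          ()  // finished earlier, but its output was lost and it will run again
        case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
          ()  // map output (shuffleId, mapId) on bmAddress is gone; resubmit that map stage
        case ExceptionFailure(exception) =>
          ()  // user code threw; log it and retry the task a bounded number of times
        case OtherFailure(message) =>
          ()  // some other ephemeral failure; log the message
      }
    }
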
-// TODO: Use of distributed cache to return result is a hack to get around -// what seems to be a bug with messages over 60KB in libprocess; fix it -private class TaskResult[T](val value: T, val accumUpdates: Map[Long, Any]) extends Serializable diff --git a/core/src/main/scala/spark/UnionRDD.scala b/core/src/main/scala/spark/UnionRDD.scala index 4c0f255e6bb767e61ed3864f3e3600f237692247..17522e2bbb6d1077d4d8caefc778753229d820d2 100644 --- a/core/src/main/scala/spark/UnionRDD.scala +++ b/core/src/main/scala/spark/UnionRDD.scala @@ -33,7 +33,8 @@ class UnionRDD[T: ClassManifest]( override def splits = splits_ - @transient override val dependencies = { + @transient + override val dependencies = { val deps = new ArrayBuffer[Dependency[_]] var pos = 0 for ((rdd, index) <- rdds.zipWithIndex) { diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala index cfd6dc8b2aa3550e0f47dfdfbcc85732a72cd050..742e60b176f3b8103491ee01dfdc5219810a0fba 100644 --- a/core/src/main/scala/spark/Utils.scala +++ b/core/src/main/scala/spark/Utils.scala @@ -118,6 +118,23 @@ object Utils { * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4). */ def localIpAddress(): String = InetAddress.getLocalHost.getHostAddress + + private var customHostname: Option[String] = None + + /** + * Allow setting a custom host name because when we run on Mesos we need to use the same + * hostname it reports to the master. + */ + def setCustomHostname(hostname: String) { + customHostname = Some(hostname) + } + + /** + * Get the local machine's hostname + */ + def localHostName(): String = { + customHostname.getOrElse(InetAddress.getLocalHost.getHostName) + } /** * Returns a standard ThreadFactory except all threads are daemons. @@ -142,6 +159,14 @@ object Utils { return threadPool } + + /** + * Return the string to tell how long has passed in seconds. The passing parameter should be in + * millisecond. + */ + def getUsedTimeMs(startTimeMs: Long): String = { + return " " + (System.currentTimeMillis - startTimeMs) + " ms " + } /** * Wrapper over newFixedThreadPool. @@ -154,16 +179,6 @@ object Utils { return threadPool } - /** - * Get the local machine's hostname. - */ - def localHostName(): String = InetAddress.getLocalHost.getHostName - - /** - * Get current host - */ - def getHost = System.getProperty("spark.hostname", localHostName()) - /** * Delete a file or directory and its contents recursively. 
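
Two small Utils additions above are easy to overlook: setCustomHostname lets a worker pin the hostname its cluster manager reported, so that localHostName() (and therefore the connection and block manager ids built from it) agrees with what the master sees, and getUsedTimeMs formats elapsed wall-clock time in milliseconds. A minimal illustration, with a made-up hostname and sleep standing in for real work:

    import spark.Utils

    // Pin the externally visible hostname before other components cache it.
    val reportedHostname = "worker-01.example.com"     // stand-in for what Mesos reported
    Utils.setCustomHostname(reportedHostname)
    println(Utils.localHostName())                     // prints worker-01.example.com

    // Time a block of work; getUsedTimeMs returns a string like " 250 ms ".
    val startTimeMs = System.currentTimeMillis
    Thread.sleep(250)                                  // stand-in for real work
    println("Slept for" + Utils.getUsedTimeMs(startTimeMs))
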
*/ diff --git a/core/src/main/scala/spark/network/Connection.scala b/core/src/main/scala/spark/network/Connection.scala new file mode 100644 index 0000000000000000000000000000000000000000..4546dfa0fac1b6c7f07d708a42abac2f4cedbdaa --- /dev/null +++ b/core/src/main/scala/spark/network/Connection.scala @@ -0,0 +1,364 @@ +package spark.network + +import spark._ + +import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} + +import java.io._ +import java.nio._ +import java.nio.channels._ +import java.nio.channels.spi._ +import java.net._ + + +abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging { + + channel.configureBlocking(false) + channel.socket.setTcpNoDelay(true) + channel.socket.setReuseAddress(true) + channel.socket.setKeepAlive(true) + /*channel.socket.setReceiveBufferSize(32768) */ + + var onCloseCallback: Connection => Unit = null + var onExceptionCallback: (Connection, Exception) => Unit = null + var onKeyInterestChangeCallback: (Connection, Int) => Unit = null + + lazy val remoteAddress = getRemoteAddress() + lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress) + + def key() = channel.keyFor(selector) + + def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress] + + def read() { + throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString) + } + + def write() { + throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString) + } + + def close() { + key.cancel() + channel.close() + callOnCloseCallback() + } + + def onClose(callback: Connection => Unit) {onCloseCallback = callback} + + def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback} + + def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback} + + def callOnExceptionCallback(e: Exception) { + if (onExceptionCallback != null) { + onExceptionCallback(this, e) + } else { + logError("Error in connection to " + remoteConnectionManagerId + + " and OnExceptionCallback not registered", e) + } + } + + def callOnCloseCallback() { + if (onCloseCallback != null) { + onCloseCallback(this) + } else { + logWarning("Connection to " + remoteConnectionManagerId + + " closed and OnExceptionCallback not registered") + } + + } + + def changeConnectionKeyInterest(ops: Int) { + if (onKeyInterestChangeCallback != null) { + onKeyInterestChangeCallback(this, ops) + } else { + throw new Exception("OnKeyInterestChangeCallback not registered") + } + } + + def printRemainingBuffer(buffer: ByteBuffer) { + val bytes = new Array[Byte](buffer.remaining) + val curPosition = buffer.position + buffer.get(bytes) + bytes.foreach(x => print(x + " ")) + buffer.position(curPosition) + print(" (" + bytes.size + ")") + } + + def printBuffer(buffer: ByteBuffer, position: Int, length: Int) { + val bytes = new Array[Byte](length) + val curPosition = buffer.position + buffer.position(position) + buffer.get(bytes) + bytes.foreach(x => print(x + " ")) + print(" (" + position + ", " + length + ")") + buffer.position(curPosition) + } + +} + + +class SendingConnection(val address: InetSocketAddress, selector_ : Selector) +extends Connection(SocketChannel.open, selector_) { + + class Outbox(fair: Int = 0) { + val messages = new Queue[Message]() + val defaultChunkSize = 65536 //32768 //16384 + var nextMessageToBeUsed = 0 + + def addMessage(message: Message): Unit = { + messages.synchronized{ + 
/*messages += message*/ + messages.enqueue(message) + logInfo("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]") + } + } + + def getChunk(): Option[MessageChunk] = { + fair match { + case 0 => getChunkFIFO() + case 1 => getChunkRR() + case _ => throw new Exception("Unexpected fairness policy in outbox") + } + } + + private def getChunkFIFO(): Option[MessageChunk] = { + /*logInfo("Using FIFO")*/ + messages.synchronized { + while (!messages.isEmpty) { + val message = messages(0) + val chunk = message.getChunkForSending(defaultChunkSize) + if (chunk.isDefined) { + messages += message // this is probably incorrect, it wont work as fifo + if (!message.started) logDebug("Starting to send [" + message + "]") + message.started = true + return chunk + } + /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/ + logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + } + } + None + } + + private def getChunkRR(): Option[MessageChunk] = { + messages.synchronized { + while (!messages.isEmpty) { + /*nextMessageToBeUsed = nextMessageToBeUsed % messages.size */ + /*val message = messages(nextMessageToBeUsed)*/ + val message = messages.dequeue + val chunk = message.getChunkForSending(defaultChunkSize) + if (chunk.isDefined) { + messages.enqueue(message) + nextMessageToBeUsed = nextMessageToBeUsed + 1 + if (!message.started) { + logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]") + message.started = true + message.startTime = System.currentTimeMillis + } + logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]") + return chunk + } + /*messages -= message*/ + message.finishTime = System.currentTimeMillis + logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken ) + } + } + None + } + } + + val outbox = new Outbox(1) + val currentBuffers = new ArrayBuffer[ByteBuffer]() + + /*channel.socket.setSendBufferSize(256 * 1024)*/ + + override def getRemoteAddress() = address + + def send(message: Message) { + outbox.synchronized { + outbox.addMessage(message) + if (channel.isConnected) { + changeConnectionKeyInterest(SelectionKey.OP_WRITE) + } + } + } + + def connect() { + try{ + channel.connect(address) + channel.register(selector, SelectionKey.OP_CONNECT) + logInfo("Initiating connection to [" + address + "]") + } catch { + case e: Exception => { + logError("Error connecting to " + address, e) + callOnExceptionCallback(e) + } + } + } + + def finishConnect() { + try { + channel.finishConnect + changeConnectionKeyInterest(SelectionKey.OP_WRITE) + logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending") + } catch { + case e: Exception => { + logWarning("Error finishing connection to " + address, e) + callOnExceptionCallback(e) + } + } + } + + override def write() { + try{ + while(true) { + if (currentBuffers.size == 0) { + outbox.synchronized { + outbox.getChunk match { + case Some(chunk) => { + currentBuffers ++= chunk.buffers + } + case None => { + changeConnectionKeyInterest(0) + /*key.interestOps(0)*/ + return + } + } + } + } + + if (currentBuffers.size > 0) { + val buffer = currentBuffers(0) + val remainingBytes = buffer.remaining + val writtenBytes = channel.write(buffer) + if (buffer.remaining == 0) { + currentBuffers -= buffer + } + if (writtenBytes < remainingBytes) { + return + } + } + } + } catch { + case e: Exception => { + 
logWarning("Error writing in connection to " + remoteConnectionManagerId, e) + callOnExceptionCallback(e) + close() + } + } + } +} + + +class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) +extends Connection(channel_, selector_) { + + class Inbox() { + val messages = new HashMap[Int, BufferMessage]() + + def getChunk(header: MessageChunkHeader): Option[MessageChunk] = { + + def createNewMessage: BufferMessage = { + val newMessage = Message.create(header).asInstanceOf[BufferMessage] + newMessage.started = true + newMessage.startTime = System.currentTimeMillis + logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]") + messages += ((newMessage.id, newMessage)) + newMessage + } + + val message = messages.getOrElseUpdate(header.id, createNewMessage) + logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]") + message.getChunkForReceiving(header.chunkSize) + } + + def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = { + messages.get(chunk.header.id) + } + + def removeMessage(message: Message) { + messages -= message.id + } + } + + val inbox = new Inbox() + val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) + var onReceiveCallback: (Connection , Message) => Unit = null + var currentChunk: MessageChunk = null + + channel.register(selector, SelectionKey.OP_READ) + + override def read() { + try { + while (true) { + if (currentChunk == null) { + val headerBytesRead = channel.read(headerBuffer) + if (headerBytesRead == -1) { + close() + return + } + if (headerBuffer.remaining > 0) { + return + } + headerBuffer.flip + if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) { + throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header") + } + val header = MessageChunkHeader.create(headerBuffer) + headerBuffer.clear() + header.typ match { + case Message.BUFFER_MESSAGE => { + if (header.totalSize == 0) { + if (onReceiveCallback != null) { + onReceiveCallback(this, Message.create(header)) + } + currentChunk = null + return + } else { + currentChunk = inbox.getChunk(header).orNull + } + } + case _ => throw new Exception("Message of unknown type received") + } + } + + if (currentChunk == null) throw new Exception("No message chunk to receive data") + + val bytesRead = channel.read(currentChunk.buffer) + if (bytesRead == 0) { + return + } else if (bytesRead == -1) { + close() + return + } + + /*logDebug("Read " + bytesRead + " bytes for the buffer")*/ + + if (currentChunk.buffer.remaining == 0) { + /*println("Filled buffer at " + System.currentTimeMillis)*/ + val bufferMessage = inbox.getMessageForChunk(currentChunk).get + if (bufferMessage.isCompletelyReceived) { + bufferMessage.flip + bufferMessage.finishTime = System.currentTimeMillis + logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken) + if (onReceiveCallback != null) { + onReceiveCallback(this, bufferMessage) + } + inbox.removeMessage(bufferMessage) + } + currentChunk = null + } + } + } catch { + case e: Exception => { + logWarning("Error reading from connection to " + remoteConnectionManagerId, e) + callOnExceptionCallback(e) + close() + } + } + } + + def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback} +} diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala new file mode 100644 index 
0000000000000000000000000000000000000000..e9f254d0f3b9624cedddfa8698f90eebadbc561d --- /dev/null +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -0,0 +1,467 @@ +package spark.network + +import spark._ + +import scala.actors.Future +import scala.actors.Futures.future +import scala.collection.mutable.HashMap +import scala.collection.mutable.SynchronizedMap +import scala.collection.mutable.SynchronizedQueue +import scala.collection.mutable.Queue +import scala.collection.mutable.ArrayBuffer + +import java.io._ +import java.nio._ +import java.nio.channels._ +import java.nio.channels.spi._ +import java.net._ +import java.util.concurrent.Executors + +case class ConnectionManagerId(val host: String, val port: Int) { + def toSocketAddress() = new InetSocketAddress(host, port) +} + +object ConnectionManagerId { + def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { + new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort()) + } +} + +class ConnectionManager(port: Int) extends Logging { + + case class MessageStatus(message: Message, connectionManagerId: ConnectionManagerId) { + var ackMessage: Option[Message] = None + var attempted = false + var acked = false + } + + val selector = SelectorProvider.provider.openSelector() + /*val handleMessageExecutor = new ThreadPoolExecutor(4, 4, 600, TimeUnit.SECONDS, new LinkedBlockingQueue()) */ + val handleMessageExecutor = Executors.newFixedThreadPool(4) + val serverChannel = ServerSocketChannel.open() + val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] + val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] + val messageStatuses = new HashMap[Int, MessageStatus] + val connectionRequests = new SynchronizedQueue[SendingConnection] + val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] + val sendMessageRequests = new Queue[(Message, SendingConnection)] + + var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + + serverChannel.configureBlocking(false) + serverChannel.socket.setReuseAddress(true) + serverChannel.socket.setReceiveBufferSize(256 * 1024) + + serverChannel.socket.bind(new InetSocketAddress(port)) + serverChannel.register(selector, SelectionKey.OP_ACCEPT) + + val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) + logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id) + + val thisInstance = this + var selectorThread = new Thread("connection-manager-thread") { + override def run() { + thisInstance.run() + } + } + selectorThread.setDaemon(true) + selectorThread.start() + + def run() { + try { + var interrupted = false + while(!interrupted) { + while(!connectionRequests.isEmpty) { + val sendingConnection = connectionRequests.dequeue + sendingConnection.connect() + addConnection(sendingConnection) + } + sendMessageRequests.synchronized { + while(!sendMessageRequests.isEmpty) { + val (message, connection) = sendMessageRequests.dequeue + connection.send(message) + } + } + + while(!keyInterestChangeRequests.isEmpty) { + val (key, ops) = keyInterestChangeRequests.dequeue + val connection = connectionsByKey(key) + val lastOps = key.interestOps() + key.interestOps(ops) + + def intToOpStr(op: Int): String = { + val opStrs = ArrayBuffer[String]() + if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ" + if ((op & SelectionKey.OP_WRITE) 
!= 0) opStrs += "WRITE" + if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT" + if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT" + if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " " + } + + logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId + + "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]") + + } + + val selectedKeysCount = selector.select() + if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") + + interrupted = selectorThread.isInterrupted + + val selectedKeys = selector.selectedKeys().iterator() + while (selectedKeys.hasNext()) { + val key = selectedKeys.next.asInstanceOf[SelectionKey] + selectedKeys.remove() + if (key.isValid) { + if (key.isAcceptable) { + acceptConnection(key) + } else + if (key.isConnectable) { + connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect() + } else + if (key.isReadable) { + connectionsByKey(key).read() + } else + if (key.isWritable) { + connectionsByKey(key).write() + } + } + } + } + } catch { + case e: Exception => logError("Error in select loop", e) + } + } + + def acceptConnection(key: SelectionKey) { + val serverChannel = key.channel.asInstanceOf[ServerSocketChannel] + val newChannel = serverChannel.accept() + val newConnection = new ReceivingConnection(newChannel, selector) + newConnection.onReceive(receiveMessage) + newConnection.onClose(removeConnection) + addConnection(newConnection) + logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]") + } + + def addConnection(connection: Connection) { + connectionsByKey += ((connection.key, connection)) + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection)) + } + connection.onKeyInterestChange(changeConnectionKeyInterest) + connection.onException(handleConnectionError) + connection.onClose(removeConnection) + } + + def removeConnection(connection: Connection) { + /*logInfo("Removing connection")*/ + connectionsByKey -= connection.key + if (connection.isInstanceOf[SendingConnection]) { + val sendingConnection = connection.asInstanceOf[SendingConnection] + val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId + logInfo("Removing SendingConnection to " + sendingConnectionManagerId) + + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.notifyAll + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } + } else if (connection.isInstanceOf[ReceivingConnection]) { + val receivingConnection = connection.asInstanceOf[ReceivingConnection] + val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId + logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId) + + val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull + if (sendingConnectionManagerId == null) { + logError("Corresponding SendingConnectionManagerId not found") + return + } + logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId) + + val sendingConnection = 
connectionsById(sendingConnectionManagerId) + sendingConnection.close() + connectionsById -= sendingConnectionManagerId + + messageStatuses.synchronized { + messageStatuses + .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => { + logInfo("Notifying " + status) + status.synchronized { + status.attempted = true + status.acked = false + status.notifyAll + } + }) + + messageStatuses.retain((i, status) => { + status.connectionManagerId != sendingConnectionManagerId + }) + } + } + } + + def handleConnectionError(connection: Connection, e: Exception) { + logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId) + removeConnection(connection) + } + + def changeConnectionKeyInterest(connection: Connection, ops: Int) { + keyInterestChangeRequests += ((connection.key, ops)) + } + + def receiveMessage(connection: Connection, message: Message) { + val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress) + logInfo("Received [" + message + "] from [" + connectionManagerId + "]") + val runnable = new Runnable() { + val creationTime = System.currentTimeMillis + def run() { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } + } + handleMessageExecutor.execute(runnable) + /*handleMessage(connection, message)*/ + } + + private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { + logInfo("Handling [" + message + "] from [" + connectionManagerId + "]") + message match { + case bufferMessage: BufferMessage => { + if (bufferMessage.hasAckId) { + val sentMessageStatus = messageStatuses.synchronized { + messageStatuses.get(bufferMessage.ackId) match { + case Some(status) => { + messageStatuses -= bufferMessage.ackId + status + } + case None => { + throw new Exception("Could not find reference for received ack message " + message.id) + null + } + } + } + sentMessageStatus.synchronized { + sentMessageStatus.ackMessage = Some(message) + sentMessageStatus.attempted = true + sentMessageStatus.acked = true + sentMessageStatus.notifyAll + } + } else { + val ackMessage = if (onReceiveCallback != null) { + logDebug("Calling back") + onReceiveCallback(bufferMessage, connectionManagerId) + } else { + logWarning("Not calling back as callback is null") + None + } + + if (ackMessage.isDefined) { + if (!ackMessage.get.isInstanceOf[BufferMessage]) { + logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass()) + } else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) { + logWarning("Response to " + bufferMessage + " does not have ack id set") + ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id + } + } + + sendMessage(connectionManagerId, ackMessage.getOrElse { + Message.createBufferMessage(bufferMessage.id) + }) + } + } + case _ => throw new Exception("Unknown type message received") + } + } + + private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { + def startNewConnection(): SendingConnection = { + val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) + val newConnection = new SendingConnection(inetSocketAddress, selector) + connectionRequests += newConnection + newConnection + } + val connection = connectionsById.getOrElse(connectionManagerId, startNewConnection) + 
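+    // Note: getOrElse takes its default by name, so startNewConnection() only runs when no
+    // SendingConnection to this ConnectionManagerId exists yet; the actual connect() happens
+    // later, when the selector thread drains connectionRequests in run().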
message.senderAddress = id.toSocketAddress() + logInfo("Sending [" + message + "] to [" + connectionManagerId + "]") + /*connection.send(message)*/ + sendMessageRequests.synchronized { + sendMessageRequests += ((message, connection)) + } + selector.wakeup() + } + + def sendMessageReliably(connectionManagerId: ConnectionManagerId, message: Message): Future[Option[Message]] = { + val messageStatus = new MessageStatus(message, connectionManagerId) + messageStatuses.synchronized { + messageStatuses += ((message.id, messageStatus)) + } + sendMessage(connectionManagerId, message) + future { + messageStatus.synchronized { + if (!messageStatus.attempted) { + logTrace("Waiting, " + messageStatuses.size + " statuses" ) + messageStatus.wait() + logTrace("Done waiting") + } + } + messageStatus.ackMessage + } + } + + def sendMessageReliablySync(connectionManagerId: ConnectionManagerId, message: Message): Option[Message] = { + sendMessageReliably(connectionManagerId, message)() + } + + def onReceiveMessage(callback: (Message, ConnectionManagerId) => Option[Message]) { + onReceiveCallback = callback + } + + def stop() { + selectorThread.interrupt() + selectorThread.join() + selector.close() + val connections = connectionsByKey.values + connections.foreach(_.close()) + if (connectionsByKey.size != 0) { + logWarning("All connections not cleaned up") + } + handleMessageExecutor.shutdown() + logInfo("ConnectionManager stopped") + } +} + + +object ConnectionManager { + + def main(args: Array[String]) { + + val manager = new ConnectionManager(9999) + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + }) + + /*testSequentialSending(manager)*/ + /*System.gc()*/ + + /*testParallelSending(manager)*/ + /*System.gc()*/ + + /*testParallelDecreasingSending(manager)*/ + /*System.gc()*/ + + testContinuousSending(manager) + System.gc() + } + + def testSequentialSending(manager: ConnectionManager) { + println("--------------------------") + println("Sequential Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(manager.id, bufferMessage) + }) + println("--------------------------") + println() + } + + def testParallelSending(manager: ConnectionManager) { + println("--------------------------") + println("Parallel Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(manager.id, bufferMessage) + }).foreach(f => {if (!f().isDefined) println("Failed")}) + val finishTime = System.currentTimeMillis + + val mb = size * count / 1024.0 / 1024.0 + val ms = finishTime - startTime + val tput = mb * 1000.0 / ms + println("--------------------------") + println("Started at " + startTime + ", finished at " + finishTime) + println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)") + println("--------------------------") + println() + } + + def testParallelDecreasingSending(manager: ConnectionManager) { + 
println("--------------------------") + println("Parallel Decreasing Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte))) + buffers.foreach(_.flip) + val mb = buffers.map(_.remaining).reduceLeft(_ + _) / 1024.0 / 1024.0 + + val startTime = System.currentTimeMillis + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffers(count - 1 - i).duplicate) + manager.sendMessageReliably(manager.id, bufferMessage) + }).foreach(f => {if (!f().isDefined) println("Failed")}) + val finishTime = System.currentTimeMillis + + val ms = finishTime - startTime + val tput = mb * 1000.0 / ms + println("--------------------------") + /*println("Started at " + startTime + ", finished at " + finishTime) */ + println("Sent " + mb + " MB in " + ms + " ms (" + tput + " MB/s)") + println("--------------------------") + println() + } + + def testContinuousSending(manager: ConnectionManager) { + println("--------------------------") + println("Continuous Sending") + println("--------------------------") + val size = 10 * 1024 * 1024 + val count = 10 + + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + while(true) { + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(manager.id, bufferMessage) + }).foreach(f => {if (!f().isDefined) println("Failed")}) + val finishTime = System.currentTimeMillis + Thread.sleep(1000) + val mb = size * count / 1024.0 / 1024.0 + val ms = finishTime - startTime + val tput = mb * 1000.0 / ms + println("--------------------------") + println() + } + } +} diff --git a/core/src/main/scala/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/spark/network/ConnectionManagerTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..5d21bb793f3dcefce2af736edeb602c47ff0c56f --- /dev/null +++ b/core/src/main/scala/spark/network/ConnectionManagerTest.scala @@ -0,0 +1,74 @@ +package spark.network + +import spark._ +import spark.SparkContext._ + +import scala.io.Source + +import java.nio.ByteBuffer +import java.net.InetAddress + +object ConnectionManagerTest extends Logging{ + def main(args: Array[String]) { + if (args.length < 2) { + println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>") + System.exit(1) + } + + if (args(0).startsWith("local")) { + println("This runs only on a mesos cluster") + } + + val sc = new SparkContext(args(0), "ConnectionManagerTest") + val slavesFile = Source.fromFile(args(1)) + val slaves = slavesFile.mkString.split("\n") + slavesFile.close() + + /*println("Slaves")*/ + /*slaves.foreach(println)*/ + + val slaveConnManagerIds = sc.parallelize(0 until slaves.length, slaves.length).map( + i => SparkEnv.get.connectionManager.id).collect() + println("\nSlave ConnectionManagerIds") + slaveConnManagerIds.foreach(println) + println + + val count = 10 + (0 until count).foreach(i => { + val resultStrs = sc.parallelize(0 until slaves.length, slaves.length).map(i => { + val connManager = SparkEnv.get.connectionManager + val thisConnManagerId = connManager.id + connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + logInfo("Received [" + msg + "] from [" + id + "]") + None + }) + + val size = 100 * 1024 * 1024 + val buffer = 
ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val startTime = System.currentTimeMillis + val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map(slaveConnManagerId => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") + connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) + }) + val results = futures.map(f => f()) + val finishTime = System.currentTimeMillis + Thread.sleep(5000) + + val mb = size * results.size / 1024.0 / 1024.0 + val ms = finishTime - startTime + val resultStr = "Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s" + logInfo(resultStr) + resultStr + }).collect() + + println("---------------------") + println("Run " + i) + resultStrs.foreach(println) + println("---------------------") + }) + } +} + diff --git a/core/src/main/scala/spark/network/Message.scala b/core/src/main/scala/spark/network/Message.scala new file mode 100644 index 0000000000000000000000000000000000000000..2e858036791d2e5e80020c7527d6cfecf6bd9f07 --- /dev/null +++ b/core/src/main/scala/spark/network/Message.scala @@ -0,0 +1,219 @@ +package spark.network + +import spark._ + +import scala.collection.mutable.ArrayBuffer + +import java.nio.ByteBuffer +import java.net.InetAddress +import java.net.InetSocketAddress + +class MessageChunkHeader( + val typ: Long, + val id: Int, + val totalSize: Int, + val chunkSize: Int, + val other: Int, + val address: InetSocketAddress) { + lazy val buffer = { + val ip = address.getAddress.getAddress() + val port = address.getPort() + ByteBuffer. + allocate(MessageChunkHeader.HEADER_SIZE). + putLong(typ). + putInt(id). + putInt(totalSize). + putInt(chunkSize). + putInt(other). + putInt(ip.size). + put(ip). + putInt(port). + position(MessageChunkHeader.HEADER_SIZE). 
+ flip.asInstanceOf[ByteBuffer] + } + + override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + + " and sizes " + totalSize + " / " + chunkSize + " bytes" +} + +class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { + val size = if (buffer == null) 0 else buffer.remaining + lazy val buffers = { + val ab = new ArrayBuffer[ByteBuffer]() + ab += header.buffer + if (buffer != null) { + ab += buffer + } + ab + } + + override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")" +} + +abstract class Message(val typ: Long, val id: Int) { + var senderAddress: InetSocketAddress = null + var started = false + var startTime = -1L + var finishTime = -1L + + def size: Int + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] + + def timeTaken(): String = (finishTime - startTime).toString + " ms" + + override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")" +} + +class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) +extends Message(Message.BUFFER_MESSAGE, id_) { + + val initialSize = currentSize() + var gotChunkForSendingOnce = false + + def size = initialSize + + def currentSize() = { + if (buffers == null || buffers.isEmpty) { + 0 + } else { + buffers.map(_.remaining).reduceLeft(_ + _) + } + } + + def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = { + if (maxChunkSize <= 0) { + throw new Exception("Max chunk size is " + maxChunkSize) + } + + if (size == 0 && gotChunkForSendingOnce == false) { + val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + gotChunkForSendingOnce = true + return Some(newChunk) + } + + while(!buffers.isEmpty) { + val buffer = buffers(0) + if (buffer.remaining == 0) { + buffers -= buffer + } else { + val newBuffer = if (buffer.remaining <= maxChunkSize) { + buffer.duplicate + } else { + buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer] + } + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + gotChunkForSendingOnce = true + return Some(newChunk) + } + } + None + } + + def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = { + // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer + if (buffers.size > 1) { + throw new Exception("Attempting to get chunk from message with multiple data buffers") + } + val buffer = buffers(0) + if (buffer.remaining > 0) { + if (buffer.remaining < chunkSize) { + throw new Exception("Not enough space in data buffer for receiving chunk") + } + val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] + buffer.position(buffer.position + newBuffer.remaining) + val newChunk = new MessageChunk(new MessageChunkHeader( + typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + return Some(newChunk) + } + None + } + + def flip() { + buffers.foreach(_.flip) + } + + def hasAckId() = (ackId != 0) + + def isCompletelyReceived() = !buffers(0).hasRemaining + + override def toString = { + if (hasAckId) { + "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")" + } else { + "BufferMessage(id = " + id + ", size = " + size + ")" + } + + } +} + +object MessageChunkHeader { + val HEADER_SIZE = 40 + + def create(buffer: 
ByteBuffer): MessageChunkHeader = { + if (buffer.remaining != HEADER_SIZE) { + throw new IllegalArgumentException("Cannot convert buffer data to Message") + } + val typ = buffer.getLong() + val id = buffer.getInt() + val totalSize = buffer.getInt() + val chunkSize = buffer.getInt() + val other = buffer.getInt() + val ipSize = buffer.getInt() + val ipBytes = new Array[Byte](ipSize) + buffer.get(ipBytes) + val ip = InetAddress.getByAddress(ipBytes) + val port = buffer.getInt() + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + } +} + +object Message { + val BUFFER_MESSAGE = 1111111111L + + var lastId = 1 + + def getNewId() = synchronized { + lastId += 1 + if (lastId == 0) lastId += 1 + lastId + } + + def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = { + if (dataBuffers == null) { + return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId) + } + if (dataBuffers.exists(_ == null)) { + throw new Exception("Attempting to create buffer message with null buffer") + } + return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer] ++= dataBuffers, ackId) + } + + def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage = + createBufferMessage(dataBuffers, 0) + + def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = { + if (dataBuffer == null) { + return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId) + } else { + return createBufferMessage(Array(dataBuffer), ackId) + } + } + + def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage = + createBufferMessage(dataBuffer, 0) + + def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId) + + def create(header: MessageChunkHeader): Message = { + val newMessage: Message = header.typ match { + case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other) + } + newMessage.senderAddress = header.address + newMessage + } +} diff --git a/core/src/main/scala/spark/network/ReceiverTest.scala b/core/src/main/scala/spark/network/ReceiverTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..e1ba7c06c04dfd615ef5f23ae710fc73faaf6e11 --- /dev/null +++ b/core/src/main/scala/spark/network/ReceiverTest.scala @@ -0,0 +1,20 @@ +package spark.network + +import java.nio.ByteBuffer +import java.net.InetAddress + +object ReceiverTest { + + def main(args: Array[String]) { + val manager = new ConnectionManager(9999) + println("Started connection manager with id = " + manager.id) + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ + val buffer = ByteBuffer.wrap("response".getBytes()) + Some(Message.createBufferMessage(buffer, msg.id)) + }) + Thread.currentThread.join() + } +} + diff --git a/core/src/main/scala/spark/network/SenderTest.scala b/core/src/main/scala/spark/network/SenderTest.scala new file mode 100644 index 0000000000000000000000000000000000000000..4ab6dd34140992fdc9d6b1642b5b4d6ae1e69e2c --- /dev/null +++ b/core/src/main/scala/spark/network/SenderTest.scala @@ -0,0 +1,53 @@ +package spark.network + +import java.nio.ByteBuffer +import java.net.InetAddress + +object SenderTest { + + def main(args: Array[String]) { + + if (args.length < 2) { + println("Usage: SenderTest <target host> <target port>") + System.exit(1) + } + + val targetHost = args(0) + val targetPort = args(1).toInt + val 
targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) + + val manager = new ConnectionManager(0) + println("Started connection manager with id = " + manager.id) + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + println("Received [" + msg + "] from [" + id + "]") + None + }) + + val size = 100 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val targetServer = args(0) + + val count = 100 + (0 until count).foreach(i => { + val dataMessage = Message.createBufferMessage(buffer.duplicate) + val startTime = System.currentTimeMillis + /*println("Started timer at " + startTime)*/ + val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { + case Some(response) => + val buffer = response.asInstanceOf[BufferMessage].buffers(0) + new String(buffer.array) + case None => "none" + } + val finishTime = System.currentTimeMillis + val mb = size / 1024.0 / 1024.0 + val ms = finishTime - startTime + /*val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms * 1000.0) + " MB/s"*/ + val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr + println(resultStr) + }) + } +} + diff --git a/core/src/main/scala/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/spark/partial/ApproximateActionListener.scala new file mode 100644 index 0000000000000000000000000000000000000000..260547902bb4a743e7a48ec1fb2d5a8b3b56da9c --- /dev/null +++ b/core/src/main/scala/spark/partial/ApproximateActionListener.scala @@ -0,0 +1,66 @@ +package spark.partial + +import spark._ +import spark.scheduler.JobListener + +/** + * A JobListener for an approximate single-result action, such as count() or non-parallel reduce(). + * This listener waits up to timeout milliseconds and will return a partial answer even if the + * complete answer is not available by then. + * + * This class assumes that the action is performed on an entire RDD[T] via a function that computes + * a result of type U for each partition, and that the action returns a partial or complete result + * of type R. Note that the type R must *include* any error bars on it (e.g. see BoundedInt). + */ +class ApproximateActionListener[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + timeout: Long) + extends JobListener { + + val startTime = System.currentTimeMillis() + val totalTasks = rdd.splits.size + var finishedTasks = 0 + var failure: Option[Exception] = None // Set if the job has failed (permanently) + var resultObject: Option[PartialResult[R]] = None // Set if we've already returned a PartialResult + + override def taskSucceeded(index: Int, result: Any): Unit = synchronized { + evaluator.merge(index, result.asInstanceOf[U]) + finishedTasks += 1 + if (finishedTasks == totalTasks) { + // If we had already returned a PartialResult, set its final value + resultObject.foreach(r => r.setFinalValue(evaluator.currentResult())) + // Notify any waiting thread that may have called getResult + this.notifyAll() + } + } + + override def jobFailed(exception: Exception): Unit = synchronized { + failure = Some(exception) + this.notifyAll() + } + + /** + * Waits for up to timeout milliseconds since the listener was created and then returns a + * PartialResult with the result so far. This may be complete if the whole job is done. 
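+   *
+   * This mirrors how DAGScheduler.runApproximateJob drives the listener (func2 is just func with
+   * its types erased, and eventQueue is the scheduler's event queue):
+   * {{{
+   *   val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
+   *   eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener))
+   *   listener.getResult()   // blocks for at most ~timeout ms, then returns a PartialResult
+   * }}}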
+ */ + def getResult(): PartialResult[R] = synchronized { + val finishTime = startTime + timeout + while (true) { + val time = System.currentTimeMillis() + if (failure != None) { + throw failure.get + } else if (finishedTasks == totalTasks) { + return new PartialResult(evaluator.currentResult(), true) + } else if (time >= finishTime) { + resultObject = Some(new PartialResult(evaluator.currentResult(), false)) + return resultObject.get + } else { + this.wait(finishTime - time) + } + } + // Should never be reached, but required to keep the compiler happy + return null + } +} diff --git a/core/src/main/scala/spark/partial/ApproximateEvaluator.scala b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..4772e43ef04118cc25a2555ca3c250268496264f --- /dev/null +++ b/core/src/main/scala/spark/partial/ApproximateEvaluator.scala @@ -0,0 +1,10 @@ +package spark.partial + +/** + * An object that computes a function incrementally by merging in results of type U from multiple + * tasks. Allows partial evaluation at any point by calling currentResult(). + */ +trait ApproximateEvaluator[U, R] { + def merge(outputId: Int, taskResult: U): Unit + def currentResult(): R +} diff --git a/core/src/main/scala/spark/partial/BoundedDouble.scala b/core/src/main/scala/spark/partial/BoundedDouble.scala new file mode 100644 index 0000000000000000000000000000000000000000..463c33d6e238ebc688390accd0b66e4b4ef10cf5 --- /dev/null +++ b/core/src/main/scala/spark/partial/BoundedDouble.scala @@ -0,0 +1,8 @@ +package spark.partial + +/** + * A Double with error bars on it. + */ +class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { + override def toString(): String = "[%.3f, %.3f]".format(low, high) +} diff --git a/core/src/main/scala/spark/partial/CountEvaluator.scala b/core/src/main/scala/spark/partial/CountEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..1bc90d6b3930aab7b870cbca4a2b0731723be1e8 --- /dev/null +++ b/core/src/main/scala/spark/partial/CountEvaluator.scala @@ -0,0 +1,38 @@ +package spark.partial + +import cern.jet.stat.Probability + +/** + * An ApproximateEvaluator for counts. + * + * TODO: There's currently a lot of shared code between this and GroupedCountEvaluator. It might + * be best to make this a special case of GroupedCountEvaluator with one group. 
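+ *
+ * A minimal usage sketch (the numbers are made up for illustration):
+ * {{{
+ *   val eval = new CountEvaluator(totalOutputs = 4, confidence = 0.95)
+ *   eval.merge(0, 1000L)
+ *   eval.merge(1, 980L)
+ *   // Only 2 of 4 outputs have been merged, so currentResult() extrapolates the total count
+ *   // and returns a BoundedDouble whose [low, high] interval reflects that uncertainty.
+ *   val estimate: BoundedDouble = eval.currentResult()
+ * }}}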
+ */ +class CountEvaluator(totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[Long, BoundedDouble] { + + var outputsMerged = 0 + var sum: Long = 0 + + override def merge(outputId: Int, taskResult: Long) { + outputsMerged += 1 + sum += taskResult + } + + override def currentResult(): BoundedDouble = { + if (outputsMerged == totalOutputs) { + new BoundedDouble(sum, 1.0, sum, sum) + } else if (outputsMerged == 0) { + new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val p = outputsMerged.toDouble / totalOutputs + val mean = (sum + 1 - p) / p + val variance = (sum + 1) * (1 - p) / (p * p) + val stdev = math.sqrt(variance) + val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + new BoundedDouble(mean, confidence, low, high) + } + } +} diff --git a/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..3e631c0efc5517c184126ff4602988d1e79297e6 --- /dev/null +++ b/core/src/main/scala/spark/partial/GroupedCountEvaluator.scala @@ -0,0 +1,62 @@ +package spark.partial + +import java.util.{HashMap => JHashMap} +import java.util.{Map => JMap} + +import scala.collection.Map +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions.mapAsScalaMap + +import cern.jet.stat.Probability + +import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +/** + * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. + */ +class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { + + var outputsMerged = 0 + var sums = new OLMap[T] // Sum of counts for each key + + override def merge(outputId: Int, taskResult: OLMap[T]) { + outputsMerged += 1 + val iter = taskResult.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue) + } + } + + override def currentResult(): Map[T, BoundedDouble] = { + if (outputsMerged == totalOutputs) { + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + val sum = entry.getLongValue() + result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + } + result + } else if (outputsMerged == 0) { + new HashMap[T, BoundedDouble] + } else { + val p = outputsMerged.toDouble / totalOutputs + val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.object2LongEntrySet.fastIterator() + while (iter.hasNext) { + val entry = iter.next() + val sum = entry.getLongValue + val mean = (sum + 1 - p) / p + val variance = (sum + 1) * (1 - p) / (p * p) + val stdev = math.sqrt(variance) + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + } + result + } + } +} diff --git a/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..2a9ccba2055efc5121de8789b225a9808bb475b9 --- /dev/null +++ b/core/src/main/scala/spark/partial/GroupedMeanEvaluator.scala @@ -0,0 
+1,65 @@ +package spark.partial + +import java.util.{HashMap => JHashMap} +import java.util.{Map => JMap} + +import scala.collection.mutable.HashMap +import scala.collection.Map +import scala.collection.JavaConversions.mapAsScalaMap + +import spark.util.StatCounter + +/** + * An ApproximateEvaluator for means by key. Returns a map of key to confidence interval. + */ +class GroupedMeanEvaluator[T](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { + + var outputsMerged = 0 + var sums = new JHashMap[T, StatCounter] // Sum of counts for each key + + override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { + outputsMerged += 1 + val iter = taskResult.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val old = sums.get(entry.getKey) + if (old != null) { + old.merge(entry.getValue) + } else { + sums.put(entry.getKey, entry.getValue) + } + } + } + + override def currentResult(): Map[T, BoundedDouble] = { + if (outputsMerged == totalOutputs) { + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val mean = entry.getValue.mean + result(entry.getKey) = new BoundedDouble(mean, 1.0, mean, mean) + } + result + } else if (outputsMerged == 0) { + new HashMap[T, BoundedDouble] + } else { + val p = outputsMerged.toDouble / totalOutputs + val studentTCacher = new StudentTCacher(confidence) + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val counter = entry.getValue + val mean = counter.mean + val stdev = math.sqrt(counter.sampleVariance / counter.count) + val confFactor = studentTCacher.get(counter.count) + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + } + result + } + } +} diff --git a/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..6a2ec7a7bd30e53bf4844ff1f4382f3118bbc635 --- /dev/null +++ b/core/src/main/scala/spark/partial/GroupedSumEvaluator.scala @@ -0,0 +1,72 @@ +package spark.partial + +import java.util.{HashMap => JHashMap} +import java.util.{Map => JMap} + +import scala.collection.mutable.HashMap +import scala.collection.Map +import scala.collection.JavaConversions.mapAsScalaMap + +import spark.util.StatCounter + +/** + * An ApproximateEvaluator for sums by key. Returns a map of key to confidence interval. 
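+ *
+ * A rough sketch (it assumes a spark.util.StatCounter can be built from a Seq[Double]):
+ * {{{
+ *   val eval = new GroupedSumEvaluator[String](totalOutputs = 8, confidence = 0.95)
+ *   val taskResult = new JHashMap[String, StatCounter]()
+ *   taskResult.put("key-a", new StatCounter(Seq(2.0, 4.0, 6.0)))
+ *   eval.merge(0, taskResult)
+ *   val sums: Map[String, BoundedDouble] = eval.currentResult()   // extrapolated per-key sums
+ * }}}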
+ */ +class GroupedSumEvaluator[T](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[JHashMap[T, StatCounter], Map[T, BoundedDouble]] { + + var outputsMerged = 0 + var sums = new JHashMap[T, StatCounter] // Sum of counts for each key + + override def merge(outputId: Int, taskResult: JHashMap[T, StatCounter]) { + outputsMerged += 1 + val iter = taskResult.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val old = sums.get(entry.getKey) + if (old != null) { + old.merge(entry.getValue) + } else { + sums.put(entry.getKey, entry.getValue) + } + } + } + + override def currentResult(): Map[T, BoundedDouble] = { + if (outputsMerged == totalOutputs) { + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val sum = entry.getValue.sum + result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + } + result + } else if (outputsMerged == 0) { + new HashMap[T, BoundedDouble] + } else { + val p = outputsMerged.toDouble / totalOutputs + val studentTCacher = new StudentTCacher(confidence) + val result = new JHashMap[T, BoundedDouble](sums.size) + val iter = sums.entrySet.iterator() + while (iter.hasNext) { + val entry = iter.next() + val counter = entry.getValue + val meanEstimate = counter.mean + val meanVar = counter.sampleVariance / counter.count + val countEstimate = (counter.count + 1 - p) / p + val countVar = (counter.count + 1) * (1 - p) / (p * p) + val sumEstimate = meanEstimate * countEstimate + val sumVar = (meanEstimate * meanEstimate * countVar) + + (countEstimate * countEstimate * meanVar) + + (meanVar * countVar) + val sumStdev = math.sqrt(sumVar) + val confFactor = studentTCacher.get(counter.count) + val low = sumEstimate - confFactor * sumStdev + val high = sumEstimate + confFactor * sumStdev + result(entry.getKey) = new BoundedDouble(sumEstimate, confidence, low, high) + } + result + } + } +} diff --git a/core/src/main/scala/spark/partial/MeanEvaluator.scala b/core/src/main/scala/spark/partial/MeanEvaluator.scala new file mode 100644 index 0000000000000000000000000000000000000000..b8c7cb8863539096ec9577e1c43ec1831c545423 --- /dev/null +++ b/core/src/main/scala/spark/partial/MeanEvaluator.scala @@ -0,0 +1,41 @@ +package spark.partial + +import cern.jet.stat.Probability + +import spark.util.StatCounter + +/** + * An ApproximateEvaluator for means. 
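+ *
+ * A minimal sketch (it assumes a spark.util.StatCounter can be built from a Seq[Double]):
+ * {{{
+ *   val eval = new MeanEvaluator(totalOutputs = 10, confidence = 0.90)
+ *   eval.merge(0, new StatCounter(Seq(1.0, 2.0, 3.0)))
+ *   // Only 3 values merged so far, so the interval uses a Student's t factor (2 degrees of
+ *   // freedom) rather than the normal approximation applied once the count exceeds 100.
+ *   val mean: BoundedDouble = eval.currentResult()
+ * }}}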
+ */ +class MeanEvaluator(totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[StatCounter, BoundedDouble] { + + var outputsMerged = 0 + var counter = new StatCounter + + override def merge(outputId: Int, taskResult: StatCounter) { + outputsMerged += 1 + counter.merge(taskResult) + } + + override def currentResult(): BoundedDouble = { + if (outputsMerged == totalOutputs) { + new BoundedDouble(counter.mean, 1.0, counter.mean, counter.mean) + } else if (outputsMerged == 0) { + new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val mean = counter.mean + val stdev = math.sqrt(counter.sampleVariance / counter.count) + val confFactor = { + if (counter.count > 100) { + Probability.normalInverse(1 - (1 - confidence) / 2) + } else { + Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + } + } + val low = mean - confFactor * stdev + val high = mean + confFactor * stdev + new BoundedDouble(mean, confidence, low, high) + } + } +} diff --git a/core/src/main/scala/spark/partial/PartialResult.scala b/core/src/main/scala/spark/partial/PartialResult.scala new file mode 100644 index 0000000000000000000000000000000000000000..7095bc8ca1bbf4d134a3ce01b3cd1826e3a93722 --- /dev/null +++ b/core/src/main/scala/spark/partial/PartialResult.scala @@ -0,0 +1,86 @@ +package spark.partial + +class PartialResult[R](initialVal: R, isFinal: Boolean) { + private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None + private var failure: Option[Exception] = None + private var completionHandler: Option[R => Unit] = None + private var failureHandler: Option[Exception => Unit] = None + + def initialValue: R = initialVal + + def isInitialValueFinal: Boolean = isFinal + + /** + * Blocking method to wait for and return the final value. + */ + def getFinalValue(): R = synchronized { + while (finalValue == None && failure == None) { + this.wait() + } + if (finalValue != None) { + return finalValue.get + } else { + throw failure.get + } + } + + /** + * Set a handler to be called when this PartialResult completes. Only one completion handler + * is supported per PartialResult. + */ + def onComplete(handler: R => Unit): PartialResult[R] = synchronized { + if (completionHandler != None) { + throw new UnsupportedOperationException("onComplete cannot be called twice") + } + completionHandler = Some(handler) + if (finalValue != None) { + // We already have a final value, so let's call the handler + handler(finalValue.get) + } + return this + } + + /** + * Set a handler to be called if this PartialResult's job fails. Only one failure handler + * is supported per PartialResult. 
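+   *
+   * For example (a sketch; partialResult is any PartialResult[R] instance):
+   * {{{
+   *   partialResult
+   *     .onComplete(r => println("final value: " + r))
+   *     .onFail(e => println("job failed: " + e.getMessage))
+   * }}}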
+   */
+  def onFail(handler: Exception => Unit): Unit = synchronized {
+    if (failureHandler != None) {
+      throw new UnsupportedOperationException("onFail cannot be called twice")
+    }
+    failureHandler = Some(handler)
+    if (failure != None) {
+      // We already have a failure, so let's call the handler
+      handler(failure.get)
+    }
+  }
+
+  private[spark] def setFinalValue(value: R): Unit = synchronized {
+    if (finalValue != None) {
+      throw new UnsupportedOperationException("setFinalValue called twice on a PartialResult")
+    }
+    finalValue = Some(value)
+    // Call the completion handler if it was set
+    completionHandler.foreach(h => h(value))
+    // Notify any threads that may be calling getFinalValue()
+    this.notifyAll()
+  }
+
+  private[spark] def setFailure(exception: Exception): Unit = synchronized {
+    if (failure != None) {
+      throw new UnsupportedOperationException("setFailure called twice on a PartialResult")
+    }
+    failure = Some(exception)
+    // Call the failure handler if it was set
+    failureHandler.foreach(h => h(exception))
+    // Notify any threads that may be calling getFinalValue()
+    this.notifyAll()
+  }
+
+  override def toString: String = synchronized {
+    finalValue match {
+      case Some(value) => "(final: " + value + ")"
+      case None => "(partial: " + initialValue + ")"
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/StudentTCacher.scala b/core/src/main/scala/spark/partial/StudentTCacher.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6263ee3518d8c21beb081d4c26dd0aa837f683d5
--- /dev/null
+++ b/core/src/main/scala/spark/partial/StudentTCacher.scala
@@ -0,0 +1,26 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+/**
+ * A utility class for caching Student's T distribution values for a given confidence level
+ * and various sample sizes. This is used by the GroupedMeanEvaluator and GroupedSumEvaluator
+ * to efficiently calculate confidence intervals for many keys.
+ */
+class StudentTCacher(confidence: Double) {
+  val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
+  val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+  val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
+
+  def get(sampleSize: Long): Double = {
+    if (sampleSize >= NORMAL_APPROX_SAMPLE_SIZE) {
+      normalApprox
+    } else {
+      val size = sampleSize.toInt
+      if (cache(size) < 0) {
+        cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+      }
+      cache(size)
+    }
+  }
+}
diff --git a/core/src/main/scala/spark/partial/SumEvaluator.scala b/core/src/main/scala/spark/partial/SumEvaluator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0357a6bff860a78729f759d44ff63feae76236fa
--- /dev/null
+++ b/core/src/main/scala/spark/partial/SumEvaluator.scala
@@ -0,0 +1,51 @@
+package spark.partial
+
+import cern.jet.stat.Probability
+
+import spark.util.StatCounter
+
+/**
+ * An ApproximateEvaluator for sums. It estimates the mean and the count and multiplies them
+ * together, then uses the formula for the variance of two independent random variables to get
+ * a variance for the result and compute a confidence interval.
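+ *
+ * Concretely, if m and c are the estimated mean and count with variances Vm and Vc, the sum is
+ * estimated as m * c with variance m^2 * Vc + c^2 * Vm + Vm * Vc, matching the code below.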
+ */ +class SumEvaluator(totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[StatCounter, BoundedDouble] { + + var outputsMerged = 0 + var counter = new StatCounter + + override def merge(outputId: Int, taskResult: StatCounter) { + outputsMerged += 1 + counter.merge(taskResult) + } + + override def currentResult(): BoundedDouble = { + if (outputsMerged == totalOutputs) { + new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) + } else if (outputsMerged == 0) { + new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val p = outputsMerged.toDouble / totalOutputs + val meanEstimate = counter.mean + val meanVar = counter.sampleVariance / counter.count + val countEstimate = (counter.count + 1 - p) / p + val countVar = (counter.count + 1) * (1 - p) / (p * p) + val sumEstimate = meanEstimate * countEstimate + val sumVar = (meanEstimate * meanEstimate * countVar) + + (countEstimate * countEstimate * meanVar) + + (meanVar * countVar) + val sumStdev = math.sqrt(sumVar) + val confFactor = { + if (counter.count > 100) { + Probability.normalInverse(1 - (1 - confidence) / 2) + } else { + Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + } + } + val low = sumEstimate - confFactor * sumStdev + val high = sumEstimate + confFactor * sumStdev + new BoundedDouble(sumEstimate, confidence, low, high) + } + } +} diff --git a/core/src/main/scala/spark/scheduler/ActiveJob.scala b/core/src/main/scala/spark/scheduler/ActiveJob.scala new file mode 100644 index 0000000000000000000000000000000000000000..0ecff9ce77ea773c30d9947a342327d2bf88fa29 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/ActiveJob.scala @@ -0,0 +1,18 @@ +package spark.scheduler + +import spark.TaskContext + +/** + * Tracks information about an active job in the DAGScheduler. + */ +class ActiveJob( + val runId: Int, + val finalStage: Stage, + val func: (TaskContext, Iterator[_]) => _, + val partitions: Array[Int], + val listener: JobListener) { + + val numPartitions = partitions.length + val finished = Array.fill[Boolean](numPartitions)(false) + var numFinished = 0 +} diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..f31e2c65a050d59302810be769a28a6c9bed67aa --- /dev/null +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -0,0 +1,532 @@ +package spark.scheduler + +import java.net.URI +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.Future +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map} + +import spark._ +import spark.partial.ApproximateActionListener +import spark.partial.ApproximateEvaluator +import spark.partial.PartialResult +import spark.storage.BlockManagerMaster +import spark.storage.BlockManagerId + +/** + * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for + * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal + * schedule to run the job. Subclasses only need to implement the code to send a task to the cluster + * and to report fetch failures (the submitTasks method, and code to add CompletionEvents). 
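+ *
+ * Rough usage sketch (someRdd and taskScheduler are placeholders; in the real code path the
+ * SparkContext wires up the TaskScheduler and DAGScheduler):
+ * {{{
+ *   val scheduler = new DAGScheduler(taskScheduler)
+ *   val sizes = scheduler.runJob(someRdd, (ctx: TaskContext, it: Iterator[String]) => it.size,
+ *     0 until someRdd.splits.size, false)
+ * }}}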
+ */ +class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with Logging { + taskSched.setListener(this) + + // Called by TaskScheduler to report task completions or failures. + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Map[Long, Any]) { + eventQueue.put(CompletionEvent(task, reason, result, accumUpdates)) + } + + // Called by TaskScheduler when a host fails. + override def hostLost(host: String) { + eventQueue.put(HostLost(host)) + } + + // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; + // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one + // as more failure events come in + val RESUBMIT_TIMEOUT = 50L + + // The time, in millis, to wake up between polls of the completion queue in order to potentially + // resubmit failed stages + val POLL_TIMEOUT = 10L + + private val lock = new Object // Used for access to the entire DAGScheduler + + private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent] + + val nextRunId = new AtomicInteger(0) + + val nextStageId = new AtomicInteger(0) + + val idToStage = new HashMap[Int, Stage] + + val shuffleToMapStage = new HashMap[Int, Stage] + + var cacheLocs = new HashMap[Int, Array[List[String]]] + + val env = SparkEnv.get + val cacheTracker = env.cacheTracker + val mapOutputTracker = env.mapOutputTracker + + val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back; + // that's not going to be a realistic assumption in general + + val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done + val running = new HashSet[Stage] // Stages we are running right now + val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures + val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] // Missing tasks from each stage + var lastFetchFailureTime: Long = 0 // Used to wait a bit to avoid repeated resubmits + + val activeJobs = new HashSet[ActiveJob] + val resultStageToJob = new HashMap[Stage, ActiveJob] + + // Start a thread to run the DAGScheduler event loop + new Thread("DAGScheduler") { + setDaemon(true) + override def run() { + DAGScheduler.this.run() + } + }.start() + + def getCacheLocs(rdd: RDD[_]): Array[List[String]] = { + cacheLocs(rdd.id) + } + + def updateCacheLocs() { + cacheLocs = cacheTracker.getLocationsSnapshot() + } + + /** + * Get or create a shuffle map stage for the given shuffle dependency's map side. + * The priority value passed in will be used if the stage doesn't already exist with + * a lower priority (we assume that priorities always increase across jobs for now). + */ + def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = { + shuffleToMapStage.get(shuffleDep.shuffleId) match { + case Some(stage) => stage + case None => + val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority) + shuffleToMapStage(shuffleDep.shuffleId) = stage + stage + } + } + + /** + * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or + * as a result stage for the final RDD used directly in an action. The stage will also be given + * the provided priority. 
+ */ + def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of splits is unknown + logInfo("Registering RDD " + rdd.id + ": " + rdd) + cacheTracker.registerRDD(rdd.id, rdd.splits.size) + if (shuffleDep != None) { + mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.splits.size) + } + val id = nextStageId.getAndIncrement() + val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority) + idToStage(id) = stage + stage + } + + /** + * Get or create the list of parent stages for a given RDD. The stages will be assigned the + * provided priority if they haven't already been created with a lower priority. + */ + def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { + val parents = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + def visit(r: RDD[_]) { + if (!visited(r)) { + visited += r + // Kind of ugly: need to register RDDs with the cache here since + // we can't do it in its constructor because # of splits is unknown + logInfo("Registering parent RDD " + r.id + ": " + r) + cacheTracker.registerRDD(r.id, r.splits.size) + for (dep <- r.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_,_] => + parents += getShuffleMapStage(shufDep, priority) + case _ => + visit(dep.rdd) + } + } + } + } + visit(rdd) + parents.toList + } + + def getMissingParentStages(stage: Stage): List[Stage] = { + val missing = new HashSet[Stage] + val visited = new HashSet[RDD[_]] + def visit(rdd: RDD[_]) { + if (!visited(rdd)) { + visited += rdd + val locs = getCacheLocs(rdd) + for (p <- 0 until rdd.splits.size) { + if (locs(p) == Nil) { + for (dep <- rdd.dependencies) { + dep match { + case shufDep: ShuffleDependency[_,_,_] => + val mapStage = getShuffleMapStage(shufDep, stage.priority) + if (!mapStage.isAvailable) { + missing += mapStage + } + case narrowDep: NarrowDependency[_] => + visit(narrowDep.rdd) + } + } + } + } + } + } + visit(stage.rdd) + missing.toList + } + + def runJob[T, U]( + finalRdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean) + (implicit m: ClassManifest[U]): Array[U] = + { + val waiter = new JobWaiter(partitions.size) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + eventQueue.put(JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, waiter)) + waiter.getResult() match { + case JobSucceeded(results: Seq[_]) => + return results.asInstanceOf[Seq[U]].toArray + case JobFailed(exception: Exception) => + throw exception + } + } + + def runApproximateJob[T, U, R]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + evaluator: ApproximateEvaluator[U, R], + timeout: Long + ): PartialResult[R] = + { + val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) + val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] + val partitions = (0 until rdd.splits.size).toArray + eventQueue.put(JobSubmitted(rdd, func2, partitions, false, listener)) + return listener.getResult() // Will throw an exception if the job fails + } + + /** + * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure + * events and responds by launching tasks. This runs in a dedicated thread and receives events + * via the eventQueue. 
+ */ + def run() = { + SparkEnv.set(env) + + while (true) { + val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS) + val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability + if (event != null) { + logDebug("Got event of type " + event.getClass.getName) + } + + event match { + case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) => + val runId = nextRunId.getAndIncrement() + val finalStage = newStage(finalRDD, None, runId) + val job = new ActiveJob(runId, finalStage, func, partitions, listener) + updateCacheLocs() + logInfo("Got job " + job.runId + " with " + partitions.length + " output partitions") + logInfo("Final stage: " + finalStage) + logInfo("Parents of final stage: " + finalStage.parents) + logInfo("Missing parents: " + getMissingParentStages(finalStage)) + if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { + // Compute very short actions like first() or take() with no parent stages locally. + runLocally(job) + } else { + activeJobs += job + resultStageToJob(finalStage) = job + submitStage(finalStage) + } + + case HostLost(host) => + handleHostLost(host) + + case completion: CompletionEvent => + handleTaskCompletion(completion) + + case null => + // queue.poll() timed out, ignore it + } + + // Periodically resubmit failed stages if some map output fetches have failed and we have + // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails, + // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at + // the same time, so we want to make sure we've identified all the reduce tasks that depend + // on the failed node. + if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) { + logInfo("Resubmitting failed stages") + updateCacheLocs() + val failed2 = failed.toArray + failed.clear() + for (stage <- failed2.sortBy(_.priority)) { + submitStage(stage) + } + } else { + // TODO: We might want to run this less often, when we are sure that something has become + // runnable that wasn't before. + logDebug("Checking for newly runnable parent stages") + logDebug("running: " + running) + logDebug("waiting: " + waiting) + logDebug("failed: " + failed) + val waiting2 = waiting.toArray + waiting.clear() + for (stage <- waiting2.sortBy(_.priority)) { + submitStage(stage) + } + } + } + } + + /** + * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. + * We run the operation in a separate thread just in case it takes a bunch of time, so that we + * don't block the DAGScheduler event loop or other concurrent jobs. 
+ */ + def runLocally(job: ActiveJob) { + logInfo("Computing the requested partition locally") + new Thread("Local computation of job " + job.runId) { + override def run() { + try { + SparkEnv.set(env) + val rdd = job.finalStage.rdd + val split = rdd.splits(job.partitions(0)) + val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0) + val result = job.func(taskContext, rdd.iterator(split)) + job.listener.taskSucceeded(0, result) + } catch { + case e: Exception => + job.listener.jobFailed(e) + } + } + }.start() + } + + def submitStage(stage: Stage) { + logDebug("submitStage(" + stage + ")") + if (!waiting(stage) && !running(stage) && !failed(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing == Nil) { + logInfo("Submitting " + stage + ", which has no missing parents") + submitMissingTasks(stage) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage + } + } + } + + def submitMissingTasks(stage: Stage) { + logDebug("submitMissingTasks(" + stage + ")") + // Get our pending tasks and remember them in our pendingTasks entry + val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) + myPending.clear() + var tasks = ArrayBuffer[Task[_]]() + if (stage.isShuffleMap) { + for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + } + } else { + // This is a final stage; figure out its job's missing partitions + val job = resultStageToJob(stage) + for (id <- 0 until job.numPartitions if (!job.finished(id))) { + val partition = job.partitions(id) + val locs = getPreferredLocs(stage.rdd, partition) + tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + } + } + if (tasks.size > 0) { + logInfo("Submitting " + tasks.size + " missing tasks from " + stage) + myPending ++= tasks + logDebug("New pending tasks: " + myPending) + taskSched.submitTasks( + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + } else { + logDebug("Stage " + stage + " is actually done; %b %d %d".format( + stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) + running -= stage + } + } + + /** + * Responds to a task finishing. This is called inside the event loop so it assumes that it can + * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. 
+   */
+  def handleTaskCompletion(event: CompletionEvent) {
+    val task = event.task
+    val stage = idToStage(task.stageId)
+    event.reason match {
+      case Success =>
+        logInfo("Completed " + task)
+        if (event.accumUpdates != null) {
+          Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
+        }
+        pendingTasks(stage) -= task
+        task match {
+          case rt: ResultTask[_, _] =>
+            resultStageToJob.get(stage) match {
+              case Some(job) =>
+                if (!job.finished(rt.outputId)) {
+                  job.finished(rt.outputId) = true
+                  job.numFinished += 1
+                  job.listener.taskSucceeded(rt.outputId, event.result)
+                  // If the whole job has finished, remove it
+                  if (job.numFinished == job.numPartitions) {
+                    activeJobs -= job
+                    resultStageToJob -= stage
+                    running -= stage
+                  }
+                }
+              case None =>
+                logInfo("Ignoring result from " + rt + " because its job has finished")
+            }
+
+          case smt: ShuffleMapTask =>
+            val stage = idToStage(smt.stageId)
+            val bmAddress = event.result.asInstanceOf[BlockManagerId]
+            val host = bmAddress.ip
+            logInfo("ShuffleMapTask finished with host " + host)
+            if (!deadHosts.contains(host)) { // TODO: Make sure hostnames are consistent with Mesos
+              stage.addOutputLoc(smt.partition, bmAddress)
+            }
+            if (running.contains(stage) && pendingTasks(stage).isEmpty) {
+              logInfo(stage + " finished; looking for newly runnable stages")
+              running -= stage
+              logInfo("running: " + running)
+              logInfo("waiting: " + waiting)
+              logInfo("failed: " + failed)
+              if (stage.shuffleDep != None) {
+                mapOutputTracker.registerMapOutputs(
+                  stage.shuffleDep.get.shuffleId,
+                  stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray)
+              }
+              updateCacheLocs()
+              if (stage.outputLocs.count(_ == Nil) != 0) {
+                // Some tasks had failed; let's resubmit this stage
+                // TODO: Lower-level scheduler should also deal with this
+                logInfo("Resubmitting " + stage + " because some of its tasks had failed: " +
+                  stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
+                submitStage(stage)
+              } else {
+                val newlyRunnable = new ArrayBuffer[Stage]
+                for (stage <- waiting) {
+                  logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
+                }
+                for (stage <- waiting if getMissingParentStages(stage) == Nil) {
+                  newlyRunnable += stage
+                }
+                waiting --= newlyRunnable
+                running ++= newlyRunnable
+                for (stage <- newlyRunnable.sortBy(_.id)) {
+                  submitMissingTasks(stage)
+                }
+              }
+            }
+        }
+
+      case Resubmitted =>
+        logInfo("Resubmitted " + task + ", so marking it as still running")
+        pendingTasks(stage) += task
+
+      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+        // Mark the stage that the reducer was in as unrunnable
+        val failedStage = idToStage(task.stageId)
+        running -= failedStage
+        failed += failedStage
+        // TODO: Cancel running tasks in the stage
+        logInfo("Marking " + failedStage + " for resubmission due to a fetch failure")
+        // Mark the map whose fetch failed as broken in the map stage
+        val mapStage = shuffleToMapStage(shuffleId)
+        mapStage.removeOutputLoc(mapId, bmAddress)
+        mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+        logInfo("The failed fetch was from " + mapStage + "; marking it for resubmission")
+        failed += mapStage
+        // Remember that a fetch failed now; this is used to resubmit the broken
+        // stages later, after a small wait (to give other tasks the chance to fail)
+        lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
+        // TODO: mark the host as failed only if there were lots of fetch failures on it
+        if (bmAddress !=
null) { + handleHostLost(bmAddress.ip) + } + + case _ => + // Non-fetch failure -- probably a bug in the job, so bail out + // TODO: Cancel all tasks that are still running + resultStageToJob.get(stage) match { + case Some(job) => + val error = new SparkException("Task failed: " + task + ", reason: " + event.reason) + job.listener.jobFailed(error) + activeJobs -= job + resultStageToJob -= stage + case None => + logInfo("Ignoring result from " + task + " because its job has finished") + } + } + } + + /** + * Responds to a host being lost. This is called inside the event loop so it assumes that it can + * modify the scheduler's internal state. Use hostLost() to post a host lost event from outside. + */ + def handleHostLost(host: String) { + if (!deadHosts.contains(host)) { + logInfo("Host lost: " + host) + deadHosts += host + BlockManagerMaster.notifyADeadHost(host) + // TODO: This will be really slow if we keep accumulating shuffle map stages + for ((shuffleId, stage) <- shuffleToMapStage) { + stage.removeOutputsOnHost(host) + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + mapOutputTracker.registerMapOutputs(shuffleId, locs, true) + } + cacheTracker.cacheLost(host) + updateCacheLocs() + } + } + + def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = { + // If the partition is cached, return the cache locations + val cached = getCacheLocs(rdd)(partition) + if (cached != Nil) { + return cached + } + // If the RDD has some placement preferences (as is the case for input RDDs), get those + val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList + if (rddPrefs != Nil) { + return rddPrefs + } + // If the RDD has narrow dependencies, pick the first partition of the first narrow dep + // that has any placement preferences. Ideally we would choose based on transfer sizes, + // but this will do for now. + rdd.dependencies.foreach(_ match { + case n: NarrowDependency[_] => + for (inPart <- n.getParents(partition)) { + val locs = getPreferredLocs(n.rdd, inPart) + if (locs != Nil) + return locs; + } + case _ => + }) + return Nil + } + + def stop() { + // TODO: Put a stop event on our queue and break the event loop + taskSched.stop() + } +} diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala new file mode 100644 index 0000000000000000000000000000000000000000..c10abc92028993d9200676d60139493ee5df5f62 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -0,0 +1,30 @@ +package spark.scheduler + +import scala.collection.mutable.Map + +import spark._ + +/** + * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue + * architecture where any thread can post an event (e.g. a task finishing or a new job being + * submitted) but there is a single "logic" thread that reads these events and takes decisions. + * This greatly simplifies synchronization. 
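+ *
+ * A rough sketch of the intended flow (the queue is presumably held by the DAGScheduler and is
+ * not shown here; the name below is illustrative):
+ * {{{
+ *   // any thread may post an event:
+ *   eventQueue.put(CompletionEvent(task, Success, result, accumUpdates))
+ *   // only the single logic thread consumes events and mutates scheduler state:
+ *   eventQueue.take() match {
+ *     case completion: CompletionEvent => handleTaskCompletion(completion)
+ *     case JobSubmitted(finalRDD, func, partitions, allowLocal, listener) => // build and submit stages
+ *     case HostLost(host) => handleHostLost(host)
+ *   }
+ * }}}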
+ */ +sealed trait DAGSchedulerEvent + +case class JobSubmitted( + finalRDD: RDD[_], + func: (TaskContext, Iterator[_]) => _, + partitions: Array[Int], + allowLocal: Boolean, + listener: JobListener) + extends DAGSchedulerEvent + +case class CompletionEvent( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Map[Long, Any]) + extends DAGSchedulerEvent + +case class HostLost(host: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/JobListener.scala b/core/src/main/scala/spark/scheduler/JobListener.scala new file mode 100644 index 0000000000000000000000000000000000000000..d4dd536a7de553f92d3c8a506df39805bb89d77f --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobListener.scala @@ -0,0 +1,11 @@ +package spark.scheduler + +/** + * Interface used to listen for job completion or failure events after submitting a job to the + * DAGScheduler. The listener is notified each time a task succeeds, as well as if the whole + * job fails (and no further taskSucceeded events will happen). + */ +trait JobListener { + def taskSucceeded(index: Int, result: Any) + def jobFailed(exception: Exception) +} diff --git a/core/src/main/scala/spark/scheduler/JobResult.scala b/core/src/main/scala/spark/scheduler/JobResult.scala new file mode 100644 index 0000000000000000000000000000000000000000..62b458eccbd22822592b236ba2c67ad15c4a2b4b --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobResult.scala @@ -0,0 +1,9 @@ +package spark.scheduler + +/** + * A result of a job in the DAGScheduler. + */ +sealed trait JobResult + +case class JobSucceeded(results: Seq[_]) extends JobResult +case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/spark/scheduler/JobWaiter.scala b/core/src/main/scala/spark/scheduler/JobWaiter.scala new file mode 100644 index 0000000000000000000000000000000000000000..be8ec9bd7b07e9d8ac8e986ae9a20b575b9bbd0c --- /dev/null +++ b/core/src/main/scala/spark/scheduler/JobWaiter.scala @@ -0,0 +1,43 @@ +package spark.scheduler + +import scala.collection.mutable.ArrayBuffer + +/** + * An object that waits for a DAGScheduler job to complete. + */ +class JobWaiter(totalTasks: Int) extends JobListener { + private val taskResults = ArrayBuffer.fill[Any](totalTasks)(null) + private var finishedTasks = 0 + + private var jobFinished = false // Is the job as a whole finished (succeeded or failed)? 
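+
+  // A rough usage sketch (the task count and results here are made up): the DAGScheduler is
+  // expected to register a JobWaiter as the JobListener for a submitted job and call
+  // taskSucceeded() once per finished ResultTask, while the submitting thread blocks in
+  // getResult():
+  //
+  //   val waiter = new JobWaiter(2)
+  //   waiter.taskSucceeded(0, "a")
+  //   waiter.taskSucceeded(1, "b")
+  //   waiter.getResult() match {
+  //     case JobSucceeded(results) => println(results)  // ArrayBuffer(a, b)
+  //     case JobFailed(exception)  => throw exception
+  //   }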
+ private var jobResult: JobResult = null // If the job is finished, this will be its result + + override def taskSucceeded(index: Int, result: Any) = synchronized { + if (jobFinished) { + throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter") + } + taskResults(index) = result + finishedTasks += 1 + if (finishedTasks == totalTasks) { + jobFinished = true + jobResult = JobSucceeded(taskResults) + this.notifyAll() + } + } + + override def jobFailed(exception: Exception) = synchronized { + if (jobFinished) { + throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter") + } + jobFinished = true + jobResult = JobFailed(exception) + this.notifyAll() + } + + def getResult(): JobResult = synchronized { + while (!jobFinished) { + this.wait() + } + return jobResult + } +} diff --git a/core/src/main/scala/spark/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala similarity index 71% rename from core/src/main/scala/spark/ResultTask.scala rename to core/src/main/scala/spark/scheduler/ResultTask.scala index 3952bf85b2cdb89f83aaed4bbca8c73086e08f5d..d2fab55b5e8a1aa3af9d0ea4f1f9607449dc5b2a 100644 --- a/core/src/main/scala/spark/ResultTask.scala +++ b/core/src/main/scala/spark/scheduler/ResultTask.scala @@ -1,14 +1,15 @@ -package spark +package spark.scheduler + +import spark._ class ResultTask[T, U]( - runId: Int, - stageId: Int, - rdd: RDD[T], + stageId: Int, + rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, - val partition: Int, - locs: Seq[String], + val partition: Int, + @transient locs: Seq[String], val outputId: Int) - extends DAGTask[U](runId, stageId) { + extends Task[U](stageId) { val split = rdd.splits(partition) diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala new file mode 100644 index 0000000000000000000000000000000000000000..317faa08510c9d9969f60d13978165080d761715 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala @@ -0,0 +1,135 @@ +package spark.scheduler + +import java.io._ +import java.util.HashMap +import java.util.zip.{GZIPInputStream, GZIPOutputStream} + +import scala.collection.mutable.ArrayBuffer + +import it.unimi.dsi.fastutil.io.FastBufferedOutputStream + +import com.ning.compress.lzf.LZFInputStream +import com.ning.compress.lzf.LZFOutputStream + +import spark._ +import spark.storage._ + +object ShuffleMapTask { + val serializedInfoCache = new HashMap[Int, Array[Byte]] + val deserializedInfoCache = new HashMap[Int, (RDD[_], ShuffleDependency[_,_,_])] + + def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_]): Array[Byte] = { + synchronized { + val old = serializedInfoCache.get(stageId) + if (old != null) { + return old + } else { + val out = new ByteArrayOutputStream + val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) + objOut.writeObject(rdd) + objOut.writeObject(dep) + objOut.close() + val bytes = out.toByteArray + serializedInfoCache.put(stageId, bytes) + return bytes + } + } + } + + def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_,_]) = { + synchronized { + val old = deserializedInfoCache.get(stageId) + if (old != null) { + return old + } else { + val loader = currentThread.getContextClassLoader + val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) + val objIn = new ObjectInputStream(in) { + override def resolveClass(desc: ObjectStreamClass) = + Class.forName(desc.getName, false, loader) + } + val rdd = 
objIn.readObject().asInstanceOf[RDD[_]] + val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_,_]] + val tuple = (rdd, dep) + deserializedInfoCache.put(stageId, tuple) + return tuple + } + } + } +} + +class ShuffleMapTask( + stageId: Int, + var rdd: RDD[_], + var dep: ShuffleDependency[_,_,_], + var partition: Int, + @transient var locs: Seq[String]) + extends Task[BlockManagerId](stageId) + with Externalizable + with Logging { + + def this() = this(0, null, null, 0, null) + + var split = if (rdd == null) { + null + } else { + rdd.splits(partition) + } + + override def writeExternal(out: ObjectOutput) { + out.writeInt(stageId) + val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) + out.writeInt(bytes.length) + out.write(bytes) + out.writeInt(partition) + out.writeObject(split) + } + + override def readExternal(in: ObjectInput) { + val stageId = in.readInt() + val numBytes = in.readInt() + val bytes = new Array[Byte](numBytes) + in.readFully(bytes) + val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) + rdd = rdd_ + dep = dep_ + partition = in.readInt() + split = in.readObject().asInstanceOf[Split] + } + + override def run(attemptId: Int): BlockManagerId = { + val numOutputSplits = dep.partitioner.numPartitions + val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]] + val partitioner = dep.partitioner.asInstanceOf[Partitioner] + val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any]) + for (elem <- rdd.iterator(split)) { + val (k, v) = elem.asInstanceOf[(Any, Any)] + var bucketId = partitioner.getPartition(k) + val bucket = buckets(bucketId) + var existing = bucket.get(k) + if (existing == null) { + bucket.put(k, aggregator.createCombiner(v)) + } else { + bucket.put(k, aggregator.mergeValue(existing, v)) + } + } + val ser = SparkEnv.get.serializer.newInstance() + val blockManager = SparkEnv.get.blockManager + for (i <- 0 until numOutputSplits) { + val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i + val arr = new ArrayBuffer[Any] + val iter = buckets(i).entrySet().iterator() + while (iter.hasNext()) { + val entry = iter.next() + arr += ((entry.getKey(), entry.getValue())) + } + // TODO: This should probably be DISK_ONLY + blockManager.put(blockId, arr.iterator, StorageLevel.MEMORY_ONLY, false) + } + return SparkEnv.get.blockManager.blockManagerId + } + + override def preferredLocations: Seq[String] = locs + + override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition) +} diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala new file mode 100644 index 0000000000000000000000000000000000000000..cd660c9085a751193bcc99cc93c3499276b7b72a --- /dev/null +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -0,0 +1,86 @@ +package spark.scheduler + +import java.net.URI + +import spark._ +import spark.storage.BlockManagerId + +/** + * A stage is a set of independent tasks all computing the same function that need to run as part + * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run + * by the scheduler is split up into stages at the boundaries where shuffle occurs, and then the + * DAGScheduler runs these stages in topological order. + * + * Each Stage can either be a shuffle map stage, in which case its tasks' results are input for + * another stage, or a result stage, in which case its tasks directly compute the action that + * initiated a job (e.g. count(), save(), etc). 
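+ *
+ * For example (illustrative), a job such as
+ * {{{
+ *   sc.textFile("hdfs://...").flatMap(_.split(" ")).map(w => (w, 1)).reduceByKey(_ + _).count()
+ * }}}
+ * is split into a shuffle map stage covering everything up to the reduceByKey shuffle, and a
+ * result stage whose tasks compute the count() action on the shuffled data.
+ *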
For shuffle map stages, we also track the nodes + * that each output partition is on. + * + * Each Stage also has a priority, which is (by default) based on the job it was submitted in. + * This allows Stages from earlier jobs to be computed first or recovered faster on failure. + */ +class Stage( + val id: Int, + val rdd: RDD[_], + val shuffleDep: Option[ShuffleDependency[_,_,_]], // Output shuffle if stage is a map stage + val parents: List[Stage], + val priority: Int) + extends Logging { + + val isShuffleMap = shuffleDep != None + val numPartitions = rdd.splits.size + val outputLocs = Array.fill[List[BlockManagerId]](numPartitions)(Nil) + var numAvailableOutputs = 0 + + private var nextAttemptId = 0 + + def isAvailable: Boolean = { + if (/*parents.size == 0 &&*/ !isShuffleMap) { + true + } else { + numAvailableOutputs == numPartitions + } + } + + def addOutputLoc(partition: Int, bmAddress: BlockManagerId) { + val prevList = outputLocs(partition) + outputLocs(partition) = bmAddress :: prevList + if (prevList == Nil) + numAvailableOutputs += 1 + } + + def removeOutputLoc(partition: Int, bmAddress: BlockManagerId) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_ == bmAddress) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + numAvailableOutputs -= 1 + } + } + + def removeOutputsOnHost(host: String) { + var becameUnavailable = false + for (partition <- 0 until numPartitions) { + val prevList = outputLocs(partition) + val newList = prevList.filterNot(_.ip == host) + outputLocs(partition) = newList + if (prevList != Nil && newList == Nil) { + becameUnavailable = true + numAvailableOutputs -= 1 + } + } + if (becameUnavailable) { + logInfo("%s is now unavailable on %s (%d/%d, %s)".format(this, host, numAvailableOutputs, numPartitions, isAvailable)) + } + } + + def newAttemptId(): Int = { + val id = nextAttemptId + nextAttemptId += 1 + return id + } + + override def toString = "Stage " + id // + ": [RDD = " + rdd.id + ", isShuffle = " + isShuffleMap + "]" + + override def hashCode(): Int = id +} diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala new file mode 100644 index 0000000000000000000000000000000000000000..42325956baa51cf1681799ad9a2b82531a7ef4ce --- /dev/null +++ b/core/src/main/scala/spark/scheduler/Task.scala @@ -0,0 +1,11 @@ +package spark.scheduler + +/** + * A task to execute on a worker node. + */ +abstract class Task[T](val stageId: Int) extends Serializable { + def run(attemptId: Int): T + def preferredLocations: Seq[String] = Nil + + var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler. +} diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala new file mode 100644 index 0000000000000000000000000000000000000000..868ddb237c0a23ca8f55d443df8a2473f1604ddd --- /dev/null +++ b/core/src/main/scala/spark/scheduler/TaskResult.scala @@ -0,0 +1,34 @@ +package spark.scheduler + +import java.io._ + +import scala.collection.mutable.Map + +// Task result. Also contains updates to accumulator variables. 
+// TODO: Use of distributed cache to return result is a hack to get around +// what seems to be a bug with messages over 60KB in libprocess; fix it +class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable { + def this() = this(null.asInstanceOf[T], null) + + override def writeExternal(out: ObjectOutput) { + out.writeObject(value) + out.writeInt(accumUpdates.size) + for ((key, value) <- accumUpdates) { + out.writeLong(key) + out.writeObject(value) + } + } + + override def readExternal(in: ObjectInput) { + value = in.readObject().asInstanceOf[T] + val numUpdates = in.readInt + if (numUpdates == 0) { + accumUpdates = null + } else { + accumUpdates = Map() + for (i <- 0 until numUpdates) { + accumUpdates(in.readLong()) = in.readObject() + } + } + } +} diff --git a/core/src/main/scala/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/spark/scheduler/TaskScheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..cb7c375d97e09e07c022fc3dcca238971efbf425 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/TaskScheduler.scala @@ -0,0 +1,27 @@ +package spark.scheduler + +/** + * Low-level task scheduler interface, implemented by both MesosScheduler and LocalScheduler. + * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, + * and are responsible for sending the tasks to the cluster, running them, retrying if there + * are failures, and mitigating stragglers. They return events to the DAGScheduler through + * the TaskSchedulerListener interface. + */ +trait TaskScheduler { + def start(): Unit + + // Wait for registration with Mesos. + def waitForRegister(): Unit + + // Disconnect from the cluster. + def stop(): Unit + + // Submit a sequence of tasks to run. + def submitTasks(taskSet: TaskSet): Unit + + // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. + def setListener(listener: TaskSchedulerListener): Unit + + // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. + def defaultParallelism(): Int +} diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala new file mode 100644 index 0000000000000000000000000000000000000000..a647eec9e477831f5c77b84f05344efaaa7ec2d5 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala @@ -0,0 +1,16 @@ +package spark.scheduler + +import scala.collection.mutable.Map + +import spark.TaskEndReason + +/** + * Interface for getting events back from the TaskScheduler. + */ +trait TaskSchedulerListener { + // A task has finished or failed. + def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit + + // A node was lost from the cluster. + def hostLost(host: String): Unit +} diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..6f29dd2e9d6dd0688c3a9ac4a38f3fae4fcddb4e --- /dev/null +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -0,0 +1,9 @@ +package spark.scheduler + +/** + * A set of tasks submitted together to the low-level TaskScheduler, usually representing + * missing partitions of a particular stage. + */ +class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { + val id: String = stageId + "." 
+ attempt +} diff --git a/core/src/main/scala/spark/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala similarity index 57% rename from core/src/main/scala/spark/LocalScheduler.scala rename to core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 3910c7b09e915c173c41c8d6b96bc427d2b6aea1..8339c0ae9025aab942f26f97a078d31235f99613 100644 --- a/core/src/main/scala/spark/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -1,16 +1,21 @@ -package spark +package spark.scheduler.local import java.util.concurrent.Executors import java.util.concurrent.atomic.AtomicInteger +import spark._ +import spark.scheduler._ + /** - * A simple Scheduler implementation that runs tasks locally in a thread pool. Optionally the - * scheduler also allows each task to fail up to maxFailures times, which is useful for testing - * fault recovery. + * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally + * the scheduler also allows each task to fail up to maxFailures times, which is useful for + * testing fault recovery. */ -private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGScheduler with Logging { +class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with Logging { var attemptId = new AtomicInteger(0) var threadPool = Executors.newFixedThreadPool(threads, DaemonThreadFactory) + val env = SparkEnv.get + var listener: TaskSchedulerListener = null // TODO: Need to take into account stage priority in scheduling @@ -18,7 +23,12 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule override def waitForRegister() {} - override def submitTasks(tasks: Seq[Task[_]], runId: Int) { + override def setListener(listener: TaskSchedulerListener) { + this.listener = listener + } + + override def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks val failCount = new Array[Int](tasks.size) def submitTask(task: Task[_], idInJob: Int) { @@ -38,23 +48,14 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. Accumulators.clear - val ser = SparkEnv.get.closureSerializer.newInstance() - val startTime = System.currentTimeMillis - val bytes = ser.serialize(task) - val timeTaken = System.currentTimeMillis - startTime - logInfo("Size of task %d is %d bytes and took %d ms to serialize".format( - idInJob, bytes.size, timeTaken)) - val deserializedTask = ser.deserialize[Task[_]](bytes, currentThread.getContextClassLoader) + val bytes = Utils.serialize(task) + logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes") + val deserializedTask = Utils.deserialize[Task[_]]( + bytes, Thread.currentThread.getContextClassLoader) val result: Any = deserializedTask.run(attemptId) - - // Serialize and deserialize the result to emulate what the mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). 
- val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = Accumulators.values logInfo("Finished task " + idInJob) - taskEnded(task, Success, resultToReturn, accumUpdates) + listener.taskEnded(task, Success, result, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -64,7 +65,7 @@ private class LocalScheduler(threads: Int, maxFailures: Int) extends DAGSchedule submitTask(task, idInJob) } else { // TODO: Do something nicer here to return all the way to the user - taskEnded(task, new ExceptionFailure(t), null, null) + listener.taskEnded(task, new ExceptionFailure(t), null, null) } } } diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala new file mode 100644 index 0000000000000000000000000000000000000000..8182901ce3abb6d80b5f8bbcf1008098fd44b304 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala @@ -0,0 +1,364 @@ +package spark.scheduler.mesos + +import java.io.{File, FileInputStream, FileOutputStream} +import java.util.{ArrayList => JArrayList} +import java.util.{List => JList} +import java.util.{HashMap => JHashMap} +import java.util.concurrent._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.collection.mutable.Map +import scala.collection.mutable.PriorityQueue +import scala.collection.JavaConversions._ +import scala.math.Ordering + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ +import akka.actor.Channel +import akka.serialization.RemoteActorSerialization._ + +import com.google.protobuf.ByteString + +import org.apache.mesos.{Scheduler => MScheduler} +import org.apache.mesos._ +import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} + +import spark._ +import spark.scheduler._ + +sealed trait CoarseMesosSchedulerMessage +case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage +case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage +case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage +case class ReviveOffers() extends CoarseMesosSchedulerMessage + +case class FakeOffer(slaveId: String, host: String, cores: Int) + +/** + * Mesos scheduler that uses coarse-grained tasks and does its own fine-grained scheduling inside + * them using Akka actors for messaging. Clients should first call start(), then submit task sets + * through the runTasks method. + * + * TODO: This is a pretty big hack for now. 
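+ *
+ * A rough sketch of the message flow, using the case classes defined above (the arrows are
+ * informal, not code):
+ * {{{
+ *   resourceOffers()            launches one long-lived WorkerTask per Mesos slave
+ *   WorkerActor -> MasterActor  RegisterSlave(slaveId, host, port)   on startup
+ *   MasterActor -> WorkerActor  LaunchTask(slaveId, task)            a fine-grained task
+ *   WorkerActor -> MasterActor  StatusUpdate(slaveId, status)        fed into fakeStatusUpdate()
+ * }}}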
+ */
+class CoarseMesosScheduler(
+    sc: SparkContext,
+    master: String,
+    frameworkName: String)
+  extends MesosScheduler(sc, master, frameworkName) {
+
+  val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt
+
+  class MasterActor extends Actor {
+    val slaveActor = new HashMap[String, ActorRef]
+    val slaveHost = new HashMap[String, String]
+    val freeCores = new HashMap[String, Int]
+
+    def receive = {
+      case RegisterSlave(slaveId, host, port) =>
+        slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port)
+        logInfo("Slave actor: " + slaveActor(slaveId))
+        slaveHost(slaveId) = host
+        freeCores(slaveId) = CORES_PER_SLAVE
+        makeFakeOffers()
+
+      case StatusUpdate(slaveId, status) =>
+        fakeStatusUpdate(status)
+        if (isFinished(status.getState)) {
+          freeCores(slaveId) += 1
+          makeFakeOffers(slaveId)
+        }
+
+      case LaunchTask(slaveId, task) =>
+        freeCores(slaveId) -= 1
+        slaveActor(slaveId) ! LaunchTask(slaveId, task)
+
+      case ReviveOffers() =>
+        logInfo("Reviving offers")
+        makeFakeOffers()
+    }
+
+    // Make fake resource offers for all slaves
+    def makeFakeOffers() {
+      fakeResourceOffers(slaveHost.toSeq.map{case (id, host) => FakeOffer(id, host, freeCores(id))})
+    }
+
+    // Make a fake resource offer for a single slave
+    def makeFakeOffers(slaveId: String) {
+      fakeResourceOffers(Seq(FakeOffer(slaveId, slaveHost(slaveId), freeCores(slaveId))))
+    }
+  }
+
+  val masterActor: ActorRef = actorOf(new MasterActor)
+  remote.register("MasterActor", masterActor)
+  masterActor.start()
+
+  val taskIdsOnSlave = new HashMap[String, HashSet[String]]
+
+  /**
+   * Method called by Mesos to offer resources on slaves. We respond by asking our active task sets
+   * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that
+   * tasks are balanced across the cluster.
+ */ + override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { + synchronized { + val tasks = offers.map(o => new JArrayList[MTaskInfo]) + for (i <- 0 until offers.size) { + val o = offers.get(i) + val slaveId = o.getSlaveId.getValue + if (!slaveIdToHost.contains(slaveId)) { + slaveIdToHost(slaveId) = o.getHostname + hostsAlive += o.getHostname + taskIdsOnSlave(slaveId) = new HashSet[String] + // Launch an infinite task on the node that will talk to the MasterActor to get fake tasks + val cpuRes = Resource.newBuilder() + .setName("cpus") + .setType(Value.Type.SCALAR) + .setScalar(Value.Scalar.newBuilder().setValue(1).build()) + .build() + val task = new WorkerTask(slaveId, o.getHostname) + val serializedTask = Utils.serialize(task) + tasks(i).add(MTaskInfo.newBuilder() + .setTaskId(newTaskId()) + .setSlaveId(o.getSlaveId) + .setExecutor(executorInfo) + .setName("worker task") + .addResources(cpuRes) + .setData(ByteString.copyFrom(serializedTask)) + .build()) + } + } + val filters = Filters.newBuilder().setRefuseSeconds(10).build() + for (i <- 0 until offers.size) { + d.launchTasks(offers(i).getId(), tasks(i), filters) + } + } + } + + override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { + val tid = status.getTaskId.getValue + var taskSetToUpdate: Option[TaskSetManager] = None + var taskFailed = false + synchronized { + try { + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (activeTaskSets.contains(taskSetId)) { + //activeTaskSets(taskSetId).statusUpdate(status) + taskSetToUpdate = Some(activeTaskSets(taskSetId)) + } + if (isFinished(status.getState)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + val slaveId = taskIdToSlaveId(tid) + taskIdToSlaveId -= tid + taskIdsOnSlave(slaveId) -= tid + } + if (status.getState == TaskState.TASK_FAILED) { + taskFailed = true + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the task set and DAGScheduler without holding a lock on this, because that can deadlock + if (taskSetToUpdate != None) { + taskSetToUpdate.get.statusUpdate(status) + } + if (taskFailed) { + // Revive offers if a task had failed for some reason other than host lost + reviveOffers() + } + } + + override def slaveLost(d: SchedulerDriver, s: SlaveID) { + logInfo("Slave lost: " + s.getValue) + var failedHost: Option[String] = None + var lostTids: Option[HashSet[String]] = None + synchronized { + val slaveId = s.getValue + val host = slaveIdToHost(slaveId) + if (hostsAlive.contains(host)) { + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + failedHost = Some(host) + lostTids = Some(taskIdsOnSlave(slaveId)) + logInfo("failedHost: " + host) + logInfo("lostTids: " + lostTids) + taskIdsOnSlave -= slaveId + activeTaskSetsQueue.foreach(_.hostLost(host)) + } + } + if (failedHost != None) { + // Report all the tasks on the failed host as lost, without holding a lock on this + for (tid <- lostTids.get; taskSetId <- taskIdToTaskSetId.get(tid)) { + // TODO: Maybe call our statusUpdate() instead to clean our internal data structures + activeTaskSets(taskSetId).statusUpdate(TaskStatus.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(tid).build()) + .setState(TaskState.TASK_LOST) + .build()) + } + // Also report the loss to the DAGScheduler + listener.hostLost(failedHost.get) + reviveOffers(); + } + } + + override def 
offerRescinded(d: SchedulerDriver, o: OfferID) {} + + // Check for speculatable tasks in all our active jobs. + override def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + for (ts <- activeTaskSetsQueue) { + shouldRevive |= ts.checkSpeculatableTasks() + } + } + if (shouldRevive) { + reviveOffers() + } + } + + + val lock2 = new Object + var firstWait = true + + override def waitForRegister() { + lock2.synchronized { + if (firstWait) { + super.waitForRegister() + Thread.sleep(5000) + firstWait = false + } + } + } + + def fakeStatusUpdate(status: TaskStatus) { + statusUpdate(driver, status) + } + + def fakeResourceOffers(offers: Seq[FakeOffer]) { + logDebug("fakeResourceOffers: " + offers) + val availableCpus = offers.map(_.cores.toDouble).toArray + var launchedTask = false + for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { + do { + launchedTask = false + for (i <- 0 until offers.size if hostsAlive.contains(offers(i).host)) { + manager.slaveOffer(offers(i).slaveId, offers(i).host, availableCpus(i)) match { + case Some(task) => + val tid = task.getTaskId.getValue + val sid = offers(i).slaveId + taskIdToTaskSetId(tid) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += tid + taskIdToSlaveId(tid) = sid + taskIdsOnSlave(sid) += tid + slaveIdsWithExecutors += sid + availableCpus(i) -= getResource(task.getResourcesList(), "cpus") + launchedTask = true + masterActor ! LaunchTask(sid, task) + + case None => {} + } + } + } while (launchedTask) + } + } + + override def reviveOffers() { + masterActor ! ReviveOffers() + } +} + +class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) { + generation = 0 + + def run(id: Int): Unit = { + val actor = actorOf(new WorkerActor(slaveId, host)) + if (!remote.isRunning) { + remote.start(Utils.localIpAddress, 7078) + } + remote.register("WorkerActor", actor) + actor.start() + while (true) { + Thread.sleep(10000) + } + } +} + +class WorkerActor(slaveId: String, host: String) extends Actor with Logging { + val env = SparkEnv.get + val classLoader = currentThread.getContextClassLoader + val threadPool = new ThreadPoolExecutor( + 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable]) + + val masterIp: String = System.getProperty("spark.master.host", "localhost") + val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt + val masterActor = remote.actorFor("MasterActor", masterIp, masterPort) + + class TaskRunner(desc: MTaskInfo) + extends Runnable { + override def run() = { + val tid = desc.getTaskId.getValue + logInfo("Running task ID " + tid) + try { + SparkEnv.set(env) + Thread.currentThread.setContextClassLoader(classLoader) + Accumulators.clear + val task = Utils.deserialize[Task[Any]](desc.getData.toByteArray, classLoader) + env.mapOutputTracker.updateGeneration(task.generation) + val value = task.run(tid.toInt) + val accumUpdates = Accumulators.values + val result = new TaskResult(value, accumUpdates) + masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder() + .setTaskId(desc.getTaskId) + .setState(TaskState.TASK_FINISHED) + .setData(ByteString.copyFrom(Utils.serialize(result))) + .build()) + logInfo("Finished task ID " + tid) + } catch { + case ffe: FetchFailedException => { + val reason = ffe.toTaskEndReason + masterActor ! 
StatusUpdate(slaveId, TaskStatus.newBuilder() + .setTaskId(desc.getTaskId) + .setState(TaskState.TASK_FAILED) + .setData(ByteString.copyFrom(Utils.serialize(reason))) + .build()) + } + case t: Throwable => { + val reason = ExceptionFailure(t) + masterActor ! StatusUpdate(slaveId, TaskStatus.newBuilder() + .setTaskId(desc.getTaskId) + .setState(TaskState.TASK_FAILED) + .setData(ByteString.copyFrom(Utils.serialize(reason))) + .build()) + + // TODO: Should we exit the whole executor here? On the one hand, the failed task may + // have left some weird state around depending on when the exception was thrown, but on + // the other hand, maybe we could detect that when future tasks fail and exit then. + logError("Exception in task ID " + tid, t) + //System.exit(1) + } + } + } + } + + override def preStart { + val ref = toRemoteActorRefProtocol(self).toByteArray + logInfo("Registering with master") + masterActor ! RegisterSlave(slaveId, host, remote.address.getPort) + } + + override def receive = { + case LaunchTask(slaveId, task) => + threadPool.execute(new TaskRunner(task)) + } +} diff --git a/core/src/main/scala/spark/MesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala similarity index 58% rename from core/src/main/scala/spark/MesosScheduler.scala rename to core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala index a7711e0d352f04c004aa3030413f1593f4a76849..f72618c03fc8a1b996f32c86678b19de6ecf31cd 100644 --- a/core/src/main/scala/spark/MesosScheduler.scala +++ b/core/src/main/scala/spark/scheduler/mesos/MesosScheduler.scala @@ -1,4 +1,4 @@ -package spark +package spark.scheduler.mesos import java.io.{File, FileInputStream, FileOutputStream} import java.util.{ArrayList => JArrayList} @@ -17,20 +17,23 @@ import com.google.protobuf.ByteString import org.apache.mesos.{Scheduler => MScheduler} import org.apache.mesos._ -import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} + +import spark._ +import spark.scheduler._ /** - * The main Scheduler implementation, which runs jobs on Mesos. Clients should first call start(), - * then submit tasks through the runTasks method. + * The main TaskScheduler implementation, which runs tasks on Mesos. Clients should first call + * start(), then submit task sets through the runTasks method. 
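+ *
+ * A minimal driver-side sketch (the master URL and framework name are placeholders, and the
+ * listener is assumed to be a DAGScheduler acting as the TaskSchedulerListener):
+ * {{{
+ *   val sched: TaskScheduler = new MesosScheduler(sc, "mesos://host:5050", "MyApp")
+ *   sched.setListener(dagScheduler)   // must happen before submitTasks()
+ *   sched.start()
+ *   sched.waitForRegister()
+ *   sched.submitTasks(taskSet)
+ * }}}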
*/ -private class MesosScheduler( +class MesosScheduler( sc: SparkContext, master: String, frameworkName: String) - extends MScheduler - with DAGScheduler + extends TaskScheduler + with MScheduler with Logging { - + // Environment variables to pass to our executors val ENV_VARS_TO_SEND_TO_EXECUTORS = Array( "SPARK_MEM", @@ -49,55 +52,60 @@ private class MesosScheduler( } } + // How often to check for speculative tasks + val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong + // Lock used to wait for scheduler to be registered - private var isRegistered = false - private val registeredLock = new Object() + var isRegistered = false + val registeredLock = new Object() - private val activeJobs = new HashMap[Int, Job] - private var activeJobsQueue = new ArrayBuffer[Job] + val activeTaskSets = new HashMap[String, TaskSetManager] + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] - private val taskIdToJobId = new HashMap[String, Int] - private val taskIdToSlaveId = new HashMap[String, String] - private val jobTasks = new HashMap[Int, HashSet[String]] + val taskIdToTaskSetId = new HashMap[String, String] + val taskIdToSlaveId = new HashMap[String, String] + val taskSetTaskIds = new HashMap[String, HashSet[String]] - // Incrementing job and task IDs - private var nextJobId = 0 - private var nextTaskId = 0 + // Incrementing Mesos task IDs + var nextTaskId = 0 // Driver for talking to Mesos var driver: SchedulerDriver = null - // Which nodes we have executors on - private val slavesWithExecutors = new HashSet[String] + // Which hosts in the cluster are alive (contains hostnames) + val hostsAlive = new HashSet[String] + + // Which slave IDs we have executors on + val slaveIdsWithExecutors = new HashSet[String] + + val slaveIdToHost = new HashMap[String, String] // JAR server, if any JARs were added by the user to the SparkContext var jarServer: HttpServer = null // URIs of JARs to pass to executor var jarUris: String = "" - + // Create an ExecutorInfo for our tasks val executorInfo = createExecutorInfo() - // Sorts jobs in reverse order of run ID for use in our priority queue (so lower IDs run first) - private val jobOrdering = new Ordering[Job] { - override def compare(j1: Job, j2: Job): Int = j2.runId - j1.runId - } - - def newJobId(): Int = this.synchronized { - val id = nextJobId - nextJobId += 1 - return id + // Listener object to pass upcalls into + var listener: TaskSchedulerListener = null + + val mapOutputTracker = SparkEnv.get.mapOutputTracker + + override def setListener(listener: TaskSchedulerListener) { + this.listener = listener } def newTaskId(): TaskID = { - val id = "" + nextTaskId; - nextTaskId += 1; - return TaskID.newBuilder().setValue(id).build() + val id = TaskID.newBuilder().setValue("" + nextTaskId).build() + nextTaskId += 1 + return id } override def start() { - new Thread("Spark scheduler") { + new Thread("MesosScheduler driver") { setDaemon(true) override def run { val sched = MesosScheduler.this @@ -110,12 +118,27 @@ private class MesosScheduler( case e: Exception => logError("driver.run() failed", e) } } - }.start + }.start() + if (System.getProperty("spark.speculation", "false") == "true") { + new Thread("MesosScheduler speculation check") { + setDaemon(true) + override def run { + waitForRegister() + while (true) { + try { + Thread.sleep(SPECULATION_INTERVAL) + } catch { case e: InterruptedException => {} } + checkSpeculatableTasks() + } + } + }.start() + } } def createExecutorInfo(): ExecutorInfo = { val sparkHome = 
sc.getSparkHome match { - case Some(path) => path + case Some(path) => + path case None => throw new SparkException("Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor") @@ -151,27 +174,26 @@ private class MesosScheduler( .build() } - def submitTasks(tasks: Seq[Task[_]], runId: Int) { - logInfo("Got a job with " + tasks.size + " tasks") + def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + logInfo("Adding task set " + taskSet.id + " with " + tasks.size + " tasks") waitForRegister() this.synchronized { - val jobId = newJobId() - val myJob = new SimpleJob(this, tasks, runId, jobId) - activeJobs(jobId) = myJob - activeJobsQueue += myJob - logInfo("Adding job with ID " + jobId) - jobTasks(jobId) = HashSet.empty[String] + val manager = new TaskSetManager(this, taskSet) + activeTaskSets(taskSet.id) = manager + activeTaskSetsQueue += manager + taskSetTaskIds(taskSet.id) = new HashSet() } - driver.reviveOffers(); + reviveOffers(); } - def jobFinished(job: Job) { + def taskSetFinished(manager: TaskSetManager) { this.synchronized { - activeJobs -= job.jobId - activeJobsQueue -= job - taskIdToJobId --= jobTasks(job.jobId) - taskIdToSlaveId --= jobTasks(job.jobId) - jobTasks.remove(job.jobId) + activeTaskSets -= manager.taskSet.id + activeTaskSetsQueue -= manager + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToSlaveId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) } } @@ -196,33 +218,40 @@ private class MesosScheduler( override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} /** - * Method called by Mesos to offer resources on slaves. We resond by asking our active jobs for - * tasks in FIFO order. We fill each node with tasks in a round-robin manner so that tasks are - * balanced across the cluster. + * Method called by Mesos to offer resources on slaves. We resond by asking our active task sets + * for tasks in order of priority. We fill each node with tasks in a round-robin manner so that + * tasks are balanced across the cluster. 
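+ *
+ * An illustrative trace (assuming CPUS_PER_TASK = 1, offers from hosts A and B, and one task
+ * set with three pending tasks and no locality constraints):
+ * {{{
+ *   pass 1: task 0 -> host A, task 1 -> host B
+ *   pass 2: task 2 -> host A        // interleaved, rather than packing all three onto host A
+ * }}}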
*/ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { synchronized { - val tasks = offers.map(o => new JArrayList[TaskInfo]) + // Mark each slave as alive and remember its hostname + for (o <- offers) { + slaveIdToHost(o.getSlaveId.getValue) = o.getHostname + hostsAlive += o.getHostname + } + // Build a list of tasks to assign to each slave + val tasks = offers.map(o => new JArrayList[MTaskInfo]) val availableCpus = offers.map(o => getResource(o.getResourcesList(), "cpus")) val enoughMem = offers.map(o => { val mem = getResource(o.getResourcesList(), "mem") val slaveId = o.getSlaveId.getValue - mem >= EXECUTOR_MEMORY || slavesWithExecutors.contains(slaveId) + mem >= EXECUTOR_MEMORY || slaveIdsWithExecutors.contains(slaveId) }) var launchedTask = false - for (job <- activeJobsQueue.sorted(jobOrdering)) { + for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) { do { launchedTask = false for (i <- 0 until offers.size if enoughMem(i)) { - job.slaveOffer(offers(i), availableCpus(i)) match { + val sid = offers(i).getSlaveId.getValue + val host = offers(i).getHostname + manager.slaveOffer(sid, host, availableCpus(i)) match { case Some(task) => tasks(i).add(task) val tid = task.getTaskId.getValue - val sid = offers(i).getSlaveId.getValue - taskIdToJobId(tid) = job.jobId - jobTasks(job.jobId) += tid + taskIdToTaskSetId(tid) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += tid taskIdToSlaveId(tid) = sid - slavesWithExecutors += sid + slaveIdsWithExecutors += sid availableCpus(i) -= getResource(task.getResourcesList(), "cpus") launchedTask = true @@ -256,53 +285,74 @@ private class MesosScheduler( } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - var jobToUpdate: Option[Job] = None + val tid = status.getTaskId.getValue + var taskSetToUpdate: Option[TaskSetManager] = None + var failedHost: Option[String] = None + var taskFailed = false synchronized { try { - val tid = status.getTaskId.getValue - if (status.getState == TaskState.TASK_LOST - && taskIdToSlaveId.contains(tid)) { + if (status.getState == TaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - slavesWithExecutors -= taskIdToSlaveId(tid) + val slaveId = taskIdToSlaveId(tid) + val host = slaveIdToHost(slaveId) + if (hostsAlive.contains(host)) { + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } } - taskIdToJobId.get(tid) match { - case Some(jobId) => - if (activeJobs.contains(jobId)) { - jobToUpdate = Some(activeJobs(jobId)) + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (activeTaskSets.contains(taskSetId)) { + //activeTaskSets(taskSetId).statusUpdate(status) + taskSetToUpdate = Some(activeTaskSets(taskSetId)) } if (isFinished(status.getState)) { - taskIdToJobId.remove(tid) - if (jobTasks.contains(jobId)) { - jobTasks(jobId) -= tid + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid } taskIdToSlaveId.remove(tid) } + if (status.getState == TaskState.TASK_FAILED) { + taskFailed = true + } case None => - logInfo("Ignoring update from TID " + tid + " because its job is gone") + logInfo("Ignoring update from TID " + tid + " because its task set is gone") } } catch { case e: Exception => logError("Exception in statusUpdate", e) } } - for (j <- jobToUpdate) { - j.statusUpdate(status) + // Update the task set and DAGScheduler without 
holding a lock on this, because that can deadlock + if (taskSetToUpdate != None) { + taskSetToUpdate.get.statusUpdate(status) + } + if (failedHost != None) { + listener.hostLost(failedHost.get) + reviveOffers(); + } + if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost + reviveOffers() } } override def error(d: SchedulerDriver, message: String) { logError("Mesos error: " + message) synchronized { - if (activeJobs.size > 0) { - // Have each job throw a SparkException with the error - for ((jobId, activeJob) <- activeJobs) { + if (activeTaskSets.size > 0) { + // Have each task set throw a SparkException with the error + for ((taskSetId, manager) <- activeTaskSets) { try { - activeJob.error(message) + manager.error(message) } catch { case e: Exception => logError("Exception in error callback", e) } } } else { - // No jobs are active but we still got an error. Just exit since this + // No task sets are active but we still got an error. Just exit since this // must mean the error is during registration. // It might be good to do something smarter here in the future. System.exit(1) @@ -373,41 +423,68 @@ private class MesosScheduler( return Utils.serialize(props.toArray) } - override def frameworkMessage( - d: SchedulerDriver, - e: ExecutorID, - s: SlaveID, - b: Array[Byte]) {} + override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} override def slaveLost(d: SchedulerDriver, s: SlaveID) { - slavesWithExecutors.remove(s.getValue) + var failedHost: Option[String] = None + synchronized { + val slaveId = s.getValue + val host = slaveIdToHost(slaveId) + if (hostsAlive.contains(host)) { + slaveIdsWithExecutors -= slaveId + hostsAlive -= host + activeTaskSetsQueue.foreach(_.hostLost(host)) + failedHost = Some(host) + } + } + if (failedHost != None) { + listener.hostLost(failedHost.get) + reviveOffers(); + } } override def executorLost(d: SchedulerDriver, e: ExecutorID, s: SlaveID, status: Int) { - slavesWithExecutors.remove(s.getValue) + logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) + slaveLost(d, s) } override def offerRescinded(d: SchedulerDriver, o: OfferID) {} + + // Check for speculatable tasks in all our active jobs. + def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + for (ts <- activeTaskSetsQueue) { + shouldRevive |= ts.checkSpeculatableTasks() + } + } + if (shouldRevive) { + reviveOffers() + } + } + + def reviveOffers() { + driver.reviveOffers() + } } object MesosScheduler { /** - * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. - * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM + * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. + * This is used to figure out how much memory to claim from Mesos based on the SPARK_MEM * environment variable. 
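+ *
+ * For example:
+ * {{{
+ *   MesosScheduler.memoryStringToMb("300m")     // 300
+ *   MesosScheduler.memoryStringToMb("1g")       // 1024
+ *   MesosScheduler.memoryStringToMb("2097152")  // 2   (no suffix, so interpreted as bytes)
+ * }}}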
*/ def memoryStringToMb(str: String): Int = { val lower = str.toLowerCase if (lower.endsWith("k")) { - (lower.substring(0, lower.length - 1).toLong / 1024).toInt + (lower.substring(0, lower.length-1).toLong / 1024).toInt } else if (lower.endsWith("m")) { - lower.substring(0, lower.length - 1).toInt + lower.substring(0, lower.length-1).toInt } else if (lower.endsWith("g")) { - lower.substring(0, lower.length - 1).toInt * 1024 + lower.substring(0, lower.length-1).toInt * 1024 } else if (lower.endsWith("t")) { - lower.substring(0, lower.length - 1).toInt * 1024 * 1024 - } else { - // no suffix, so it's just a number in bytes + lower.substring(0, lower.length-1).toInt * 1024 * 1024 + } else {// no suffix, so it's just a number in bytes (lower.toLong / 1024 / 1024).toInt } } diff --git a/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala new file mode 100644 index 0000000000000000000000000000000000000000..af2f80ea6671756f768c66be2f4ae2142c9f23d4 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/mesos/TaskInfo.scala @@ -0,0 +1,32 @@ +package spark.scheduler.mesos + +/** + * Information about a running task attempt. + */ +class TaskInfo(val taskId: String, val index: Int, val launchTime: Long, val host: String) { + var finishTime: Long = 0 + var failed = false + + def markSuccessful(time: Long = System.currentTimeMillis) { + finishTime = time + } + + def markFailed(time: Long = System.currentTimeMillis) { + finishTime = time + failed = true + } + + def finished: Boolean = finishTime != 0 + + def successful: Boolean = finished && !failed + + def duration: Long = { + if (!finished) { + throw new UnsupportedOperationException("duration() called on unfinished tasks") + } else { + finishTime - launchTime + } + } + + def timeRunning(currentTime: Long): Long = currentTime - launchTime +} diff --git a/core/src/main/scala/spark/SimpleJob.scala b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala similarity index 50% rename from core/src/main/scala/spark/SimpleJob.scala rename to core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala index 01c7efff1e0af2bed9c6085b0958847968441c37..535c17d9d4db78f29acca2b7e458159664a28391 100644 --- a/core/src/main/scala/spark/SimpleJob.scala +++ b/core/src/main/scala/spark/scheduler/mesos/TaskSetManager.scala @@ -1,28 +1,32 @@ -package spark +package spark.scheduler.mesos +import java.util.Arrays import java.util.{HashMap => JHashMap} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.math.max +import scala.math.min import com.google.protobuf.ByteString import org.apache.mesos._ -import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.{TaskInfo => MTaskInfo, _} + +import spark._ +import spark.scheduler._ /** - * A Job that runs a set of tasks with no interdependencies. + * Schedules the tasks within a single TaskSet in the MesosScheduler. 
*/ -class SimpleJob( +class TaskSetManager( sched: MesosScheduler, - tasksSeq: Seq[Task[_]], - runId: Int, - jobId: Int) - extends Job(runId, jobId) - with Logging { + val taskSet: TaskSet) + extends Logging { // Maximum time to wait to run a task in a preferred location (in ms) - val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "5000").toLong + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong // CPUs to request per task val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble @@ -30,18 +34,20 @@ class SimpleJob( // Maximum times a task is allowed to fail before failing the job val MAX_TASK_FAILURES = 4 + // Quantile of tasks at which to start speculation + val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble + // Serializer for closures and tasks. val ser = SparkEnv.get.closureSerializer.newInstance() - val callingThread = Thread.currentThread - val tasks = tasksSeq.toArray + val priority = taskSet.priority + val tasks = taskSet.tasks val numTasks = tasks.length - val launched = new Array[Boolean](numTasks) + val copiesRunning = new Array[Int](numTasks) val finished = new Array[Boolean](numTasks) val numFailures = new Array[Int](numTasks) - val tidToIndex = HashMap[String, Int]() - - var tasksLaunched = 0 + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) var tasksFinished = 0 // Last time when we launched a preferred task (for delay scheduling) @@ -62,6 +68,13 @@ class SimpleJob( // List containing all pending tasks (also used as a stack, as above) val allPendingTasks = new ArrayBuffer[Int] + // Tasks that can be specualted. Since these will be a small fraction of total + // tasks, we'll just hold them in a HaskSet. + val speculatableTasks = new HashSet[Int] + + // Task index, start and finish time for each task attempt (indexed by task ID) + val taskInfos = new HashMap[String, TaskInfo] + // Did the job fail? var failed = false var causeOfFailure = "" @@ -76,6 +89,12 @@ class SimpleJob( // exceptions automatically. val recentExceptions = HashMap[String, (Int, Long)]() + // Figure out the current map output tracker generation and set it on all tasks + val generation = sched.mapOutputTracker.getGeneration + for (t <- tasks) { + t.generation = generation + } + // Add all our tasks to the pending lists. We do this in reverse order // of task index so that tasks with low indices get launched first. for (i <- (0 until numTasks).reverse) { @@ -84,7 +103,7 @@ class SimpleJob( // Add a task to all the pending-task lists that it should be on. def addPendingTask(index: Int) { - val locations = tasks(index).preferredLocations + val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive if (locations.size == 0) { pendingTasksWithNoPrefs += index } else { @@ -110,13 +129,37 @@ class SimpleJob( while (!list.isEmpty) { val index = list.last list.trimEnd(1) - if (!launched(index) && !finished(index)) { + if (copiesRunning(index) == 0 && !finished(index)) { return Some(index) } } return None } + // Return a speculative task for a given host if any are available. The task should not have an + // attempt running on this host, in case the host is slow. In addition, if localOnly is set, the + // task must have a preference for this host (or no preferred locations at all). 
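+  //
+  // An illustrative call sequence, assuming task 7 is in speculatableTasks, has no live
+  // preferred locations, and already has one attempt running on "host1":
+  //
+  //   findSpeculativeTask("host1", localOnly = false)   // => None (an attempt already runs there)
+  //   findSpeculativeTask("host2", localOnly = false)   // => Some(7)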
+ def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = { + speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set + val localTask = speculatableTasks.find { index => + val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + val attemptLocs = taskAttempts(index).map(_.host) + (locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host) + } + if (localTask != None) { + speculatableTasks -= localTask.get + return localTask + } + if (!localOnly && speculatableTasks.size > 0) { + val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host)) + if (nonLocalTask != None) { + speculatableTasks -= nonLocalTask.get + return nonLocalTask + } + } + return None + } + // Dequeue a pending task for a given node and return its index. // If localOnly is set to false, allow non-local tasks as well. def findTask(host: String, localOnly: Boolean): Option[Int] = { @@ -129,10 +172,13 @@ class SimpleJob( return noPrefTask } if (!localOnly) { - return findTaskFromList(allPendingTasks) // Look for non-local task - } else { - return None + val nonLocalTask = findTaskFromList(allPendingTasks) + if (nonLocalTask != None) { + return nonLocalTask + } } + // Finally, if all else has failed, find a speculative task + return findSpeculativeTask(host, localOnly) } // Does a host count as a preferred location for a task? This is true if @@ -144,11 +190,11 @@ class SimpleJob( } // Respond to an offer of a single slave from the scheduler by finding a task - def slaveOffer(offer: Offer, availableCpus: Double): Option[TaskInfo] = { - if (tasksLaunched < numTasks && availableCpus >= CPUS_PER_TASK) { + def slaveOffer(slaveId: String, host: String, availableCpus: Double): Option[MTaskInfo] = { + if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { val time = System.currentTimeMillis - val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) - val host = offer.getHostname + var localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT) + findTask(host, localOnly) match { case Some(index) => { // Found a task; do some bookkeeping and return a Mesos task for it @@ -156,17 +202,17 @@ class SimpleJob( val taskId = sched.newTaskId() // Figure out whether this should count as a preferred launch val preferred = isPreferredLocation(task, host) - val prefStr = if(preferred) "preferred" else "non-preferred" - val message = - "Starting task %d:%d as TID %s on slave %s: %s (%s)".format( - jobId, index, taskId.getValue, offer.getSlaveId.getValue, host, prefStr) - logInfo(message) + val prefStr = if (preferred) "preferred" else "non-preferred" + logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format( + taskSet.id, index, taskId.getValue, slaveId, host, prefStr)) // Do various bookkeeping - tidToIndex(taskId.getValue) = index - launched(index) = true - tasksLaunched += 1 - if (preferred) + copiesRunning(index) += 1 + val info = new TaskInfo(taskId.getValue, index, time, host) + taskInfos(taskId.getValue) = info + taskAttempts(index) = info :: taskAttempts(index) + if (preferred) { lastPreferredLaunchTime = time + } // Create and return the Mesos task object val cpuRes = Resource.newBuilder() .setName("cpus") @@ -178,13 +224,13 @@ class SimpleJob( val serializedTask = ser.serialize(task) val timeTaken = System.currentTimeMillis - startTime - logInfo("Size of task %d:%d is %d bytes and took %d ms to serialize by %s" - .format(jobId, index, serializedTask.size, timeTaken, ser.getClass.getName)) + 
logInfo("Serialized task %s:%d as %d bytes in %d ms".format( + taskSet.id, index, serializedTask.limit, timeTaken)) - val taskName = "task %d:%d".format(jobId, index) - return Some(TaskInfo.newBuilder() + val taskName = "task %s:%d".format(taskSet.id, index) + return Some(MTaskInfo.newBuilder() .setTaskId(taskId) - .setSlaveId(offer.getSlaveId) + .setSlaveId(SlaveID.newBuilder().setValue(slaveId)) .setExecutor(sched.executorInfo) .setName(taskName) .addResources(cpuRes) @@ -213,18 +259,21 @@ class SimpleJob( def taskFinished(status: TaskStatus) { val tid = status.getTaskId.getValue - val index = tidToIndex(tid) + val info = taskInfos(tid) + val index = info.index + info.markSuccessful() if (!finished(index)) { tasksFinished += 1 - logInfo("Finished TID %s (progress: %d/%d)".format(tid, tasksFinished, numTasks)) - // Deserialize task result - val result = ser.deserialize[TaskResult[_]]( - status.getData.toByteArray, getClass.getClassLoader) - sched.taskEnded(tasks(index), Success, result.value, result.accumUpdates) + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + val result = ser.deserialize[TaskResult[_]](status.getData.asReadOnlyByteBuffer) + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates) // Mark finished and stop if we've finished all the tasks finished(index) = true - if (tasksFinished == numTasks) - sched.jobFinished(this) + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } } else { logInfo("Ignoring task-finished event for TID " + tid + " because task " + index + " is already finished") @@ -233,30 +282,29 @@ class SimpleJob( def taskLost(status: TaskStatus) { val tid = status.getTaskId.getValue - val index = tidToIndex(tid) + val info = taskInfos(tid) + val index = info.index + info.markFailed() if (!finished(index)) { - logInfo("Lost TID %s (task %d:%d)".format(tid, jobId, index)) - launched(index) = false - tasksLaunched -= 1 + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 // Check if the problem is a map output fetch failure. In that case, this // task will never succeed on any node, so tell the scheduler about it. 
if (status.getData != null && status.getData.size > 0) { - val reason = ser.deserialize[TaskEndReason]( - status.getData.toByteArray, getClass.getClassLoader) + val reason = ser.deserialize[TaskEndReason](status.getData.asReadOnlyByteBuffer) reason match { case fetchFailed: FetchFailed => - logInfo("Loss was due to fetch failure from " + fetchFailed.serverUri) - sched.taskEnded(tasks(index), fetchFailed, null, null) + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null) finished(index) = true tasksFinished += 1 - if (tasksFinished == numTasks) { - sched.jobFinished(this) - } + sched.taskSetFinished(this) return + case ef: ExceptionFailure => val key = ef.exception.toString val now = System.currentTimeMillis - val (printFull, dupCount) = + val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) if (now - printTime > EXCEPTION_PRINT_INTERVAL) { @@ -267,32 +315,28 @@ class SimpleJob( (false, dupCount + 1) } } else { - recentExceptions += Tuple(key, (0, now)) + recentExceptions(key) = (0, now) (true, 0) } - + } if (printFull) { - val stackTrace = - for (elem <- ef.exception.getStackTrace) - yield "\tat %s".format(elem.toString) - logInfo("Loss was due to %s\n%s".format( - ef.exception.toString, stackTrace.mkString("\n"))) + val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format( - ef.exception.toString, dupCount)) + logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) } + case _ => {} } } - // On other failures, re-enqueue the task as pending for a max number of retries + // On non-fetch failures, re-enqueue the task as pending for a max number of retries addPendingTask(index) - // Count attempts only on FAILED and LOST state (not on KILLED) - if (status.getState == TaskState.TASK_FAILED || - status.getState == TaskState.TASK_LOST) { + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (status.getState == TaskState.TASK_FAILED || status.getState == TaskState.TASK_LOST) { numFailures(index) += 1 if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %d:%d failed more than %d times; aborting job".format( - jobId, index, MAX_TASK_FAILURES)) + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) abort("Task %d failed more than %d times".format(index, MAX_TASK_FAILURES)) } } @@ -311,6 +355,71 @@ class SimpleJob( failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.jobFinished(this) + sched.taskSetFinished(this) + } + + def hostLost(hostname: String) { + logInfo("Re-queueing tasks for " + hostname) + // If some task has preferred locations only on hostname, put it in the no-prefs list + // to avoid the wait from delay scheduling + for (index <- getPendingTasksForHost(hostname)) { + val newLocs = tasks(index).preferredLocations.toSet & sched.hostsAlive + if (newLocs.isEmpty) { + pendingTasksWithNoPrefs += index + } + } + // Also re-enqueue any tasks that ran on the failed host if this is a shuffle map stage + if (tasks(0).isInstanceOf[ShuffleMapTask]) { + for ((tid, info) <- taskInfos if info.host == hostname) { + val index = taskInfos(tid).index + if (finished(index)) { + finished(index) = false 
+ copiesRunning(index) -= 1 + tasksFinished -= 1 + addPendingTask(index) + // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our + // stage finishes when a total of tasks.size tasks finish. + sched.listener.taskEnded(tasks(index), Resubmitted, null, null) + } + } + } + } + + /** + * Check for tasks to be speculated and return true if there are any. This is called periodically + * by the MesosScheduler. + * + * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that + * we don't scan the whole task set. It might also help to make this sorted by launch time. + */ + def checkSpeculatableTasks(): Boolean = { + // Can't speculate if we only have one task, or if all tasks have finished. + if (numTasks == 1 || tasksFinished == numTasks) { + return false + } + var foundTasks = false + val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt + logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksFinished >= minFinishedForSpeculation) { + val time = System.currentTimeMillis() + val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray + Arrays.sort(durations) + val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) + val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100) + // TODO: Threshold should also look at standard deviation of task durations and have a lower + // bound based on that. + logDebug("Task length threshold for speculation: " + threshold) + for ((tid, info) <- taskInfos) { + val index = info.index + if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && + !speculatableTasks.contains(index)) { + logInfo("Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format( + taskSet.id, index, info.host, threshold)) + speculatableTasks += index + foundTasks = true + } + } + } + return foundTasks } } diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala new file mode 100644 index 0000000000000000000000000000000000000000..367c79dd7655336188097ad07e3d792b4333374b --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -0,0 +1,507 @@ +package spark.storage + +import java.io._ +import java.nio._ +import java.nio.channels.FileChannel.MapMode +import java.util.{HashMap => JHashMap} +import java.util.LinkedHashMap +import java.util.UUID +import java.util.Collections + +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.Future +import scala.actors.Futures.future +import scala.actors.remote._ +import scala.actors.remote.RemoteActor._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.JavaConversions._ + +import it.unimi.dsi.fastutil.io._ + +import spark.CacheTracker +import spark.Logging +import spark.Serializer +import spark.SizeEstimator +import spark.SparkEnv +import spark.SparkException +import spark.Utils +import spark.network._ + +class BlockManagerId(var ip: String, var port: Int) extends Externalizable { + def this() = this(null, 0) + + override def writeExternal(out: ObjectOutput) { + out.writeUTF(ip) + out.writeInt(port) + } + + override def readExternal(in: ObjectInput) { + ip = in.readUTF() + port = in.readInt() + } + + override def toString = "BlockManagerId(" + ip + ", " + port + ")" + + override def hashCode = ip.hashCode * 41 + port + + override def equals(that: Any) = 
that match { + case id: BlockManagerId => port == id.port && ip == id.ip + case _ => false + } +} + + +case class BlockException(blockId: String, message: String, ex: Exception = null) extends Exception(message) + + +class BlockLocker(numLockers: Int) { + private val hashLocker = Array.fill(numLockers)(new Object()) + + def getLock(blockId: String): Object = { + return hashLocker(Math.abs(blockId.hashCode % numLockers)) + } +} + + +/** + * A start towards a block manager class. This will eventually be used for both RDD persistence + * and shuffle outputs. + * + * TODO: Should make the communication with Master or Peers code more robust and log friendly. + */ +class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging { + + private val NUM_LOCKS = 337 + private val locker = new BlockLocker(NUM_LOCKS) + + private val storageLevels = Collections.synchronizedMap(new JHashMap[String, StorageLevel]) + + private val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private val diskStore: BlockStore = new DiskStore(this, + System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))) + + val connectionManager = new ConnectionManager(0) + + val connectionManagerId = connectionManager.id + val blockManagerId = new BlockManagerId(connectionManagerId.host, connectionManagerId.port) + + // TODO(Haoyuan): This will be removed after cacheTracker is removed from the code base. + var cacheTracker: CacheTracker = null + + initLogging() + + initialize() + + /** + * Construct a BlockManager with a memory limit set based on system properties. + */ + def this(serializer: Serializer) = + this(BlockManager.getMaxMemoryFromSystemProperties(), serializer) + + /** + * Initialize the BlockManager. Register to the BlockManagerMaster, and start the + * BlockManagerWorker actor. + */ + def initialize() { + BlockManagerMaster.mustRegisterBlockManager( + RegisterBlockManager(blockManagerId, maxMemory, maxMemory)) + BlockManagerWorker.startBlockManagerWorker(this) + } + + /** + * Get locations of the block. 
+ */
+ def getLocations(blockId: String): Seq[String] = {
+ val startTimeMs = System.currentTimeMillis
+ var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+ val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq
+ logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs))
+ return locations
+ }
+
+ /**
+ * Get locations of an array of blocks
+ */
+ def getLocationsMultipleBlockIds(blockIds: Array[String]): Array[Seq[String]] = {
+ val startTimeMs = System.currentTimeMillis
+ val locations = BlockManagerMaster.mustGetLocationsMultipleBlockIds(
+ GetLocationsMultipleBlockIds(blockIds)).map(_.map(_.ip).toSeq).toArray
+ logDebug("Get multiple block locations in " + Utils.getUsedTimeMs(startTimeMs))
+ return locations
+ }
+
+ def getLocal(blockId: String): Option[Iterator[Any]] = {
+ logDebug("Getting block " + blockId)
+ locker.getLock(blockId).synchronized {
+
+ // Check storage level of block
+ val level = storageLevels.get(blockId)
+ if (level != null) {
+ logDebug("Level for block " + blockId + " is " + level + " on local machine")
+
+ // Look for the block in memory
+ if (level.useMemory) {
+ logDebug("Getting block " + blockId + " from memory")
+ memoryStore.getValues(blockId) match {
+ case Some(iterator) => {
+ logDebug("Block " + blockId + " found in memory")
+ return Some(iterator)
+ }
+ case None => {
+ logDebug("Block " + blockId + " not found in memory")
+ }
+ }
+ } else {
+ logDebug("Not getting block " + blockId + " from memory")
+ }
+
+ // Look for the block on disk
+ if (level.useDisk) {
+ logDebug("Getting block " + blockId + " from disk")
+ diskStore.getValues(blockId) match {
+ case Some(iterator) => {
+ logDebug("Block " + blockId + " found on disk")
+ return Some(iterator)
+ }
+ case None => {
+ throw new Exception("Block " + blockId + " not found on disk")
+ }
+ }
+ } else {
+ logDebug("Not getting block " + blockId + " from disk")
+ }
+
+ } else {
+ logDebug("Level for block " + blockId + " not found")
+ }
+ }
+ return None
+ }
+
+ def getRemote(blockId: String): Option[Iterator[Any]] = {
+ // Get locations of block
+ val locations = BlockManagerMaster.mustGetLocations(GetLocations(blockId))
+
+ // Get block from remote locations
+ for (loc <- locations) {
+ val data = BlockManagerWorker.syncGetBlock(
+ GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
+ if (data != null) {
+ logDebug("Data is not null: " + data)
+ return Some(dataDeserialize(data))
+ }
+ logDebug("Data is null")
+ }
+ logDebug("Data not found")
+ return None
+ }
+
+ /**
+ * Read a block from the block manager.
+ */
+ def get(blockId: String): Option[Iterator[Any]] = {
+ getLocal(blockId).orElse(getRemote(blockId))
+ }
+
+ /**
+ * Read many blocks from block manager using their BlockManagerIds.
+ */ + def get(blocksByAddress: Seq[(BlockManagerId, Seq[String])]): HashMap[String, Option[Iterator[Any]]] = { + logDebug("Getting " + blocksByAddress.map(_._2.size).sum + " blocks") + var startTime = System.currentTimeMillis + val blocks = new HashMap[String,Option[Iterator[Any]]]() + val localBlockIds = new ArrayBuffer[String]() + val remoteBlockIds = new ArrayBuffer[String]() + val remoteBlockIdsPerLocation = new HashMap[BlockManagerId, Seq[String]]() + + // Split local and remote blocks + for ((address, blockIds) <- blocksByAddress) { + if (address == blockManagerId) { + localBlockIds ++= blockIds + } else { + remoteBlockIds ++= blockIds + remoteBlockIdsPerLocation(address) = blockIds + } + } + + // Start getting remote blocks + val remoteBlockFutures = remoteBlockIdsPerLocation.toSeq.map { case (bmId, bIds) => + val cmId = ConnectionManagerId(bmId.ip, bmId.port) + val blockMessages = bIds.map(bId => BlockMessage.fromGetBlock(GetBlock(bId))) + val blockMessageArray = new BlockMessageArray(blockMessages) + val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + (cmId, future) + } + logDebug("Started remote gets for " + remoteBlockIds.size + " blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + + // Get the local blocks while remote blocks are being fetched + startTime = System.currentTimeMillis + localBlockIds.foreach(id => { + get(id) match { + case Some(block) => { + blocks.update(id, Some(block)) + logDebug("Got local block " + id) + } + case None => { + throw new BlockException(id, "Could not get block " + id + " from local machine") + } + } + }) + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + + // wait for and gather all the remote blocks + for ((cmId, future) <- remoteBlockFutures) { + var count = 0 + val oneBlockId = remoteBlockIdsPerLocation(new BlockManagerId(cmId.host, cmId.port)).first + future() match { + case Some(message) => { + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + blockMessageArray.foreach(blockMessage => { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + throw new BlockException(oneBlockId, "Unexpected message received from " + cmId) + } + val buffer = blockMessage.getData() + val blockId = blockMessage.getId() + val block = dataDeserialize(buffer) + blocks.update(blockId, Some(block)) + logDebug("Got remote block " + blockId + " in " + Utils.getUsedTimeMs(startTime)) + count += 1 + }) + } + case None => { + throw new BlockException(oneBlockId, "Could not get blocks from " + cmId) + } + } + logDebug("Got remote " + count + " blocks from " + cmId.host + " in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + logDebug("Got all blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + return blocks + } + + /** + * Write a new block to the block manager. 
+ */ + def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean = true) { + if (!level.useDisk && !level.useMemory) { + throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set") + } + + val startTimeMs = System.currentTimeMillis + var bytes: ByteBuffer = null + + locker.getLock(blockId).synchronized { + logDebug("Put for block " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " to get into synchronized block") + + // Check and warn if block with same id already exists + if (storageLevels.get(blockId) != null) { + logWarning("Block " + blockId + " already exists in local machine") + return + } + + // Store the storage level + storageLevels.put(blockId, level) + + if (level.useMemory && level.useDisk) { + // If saving to both memory and disk, then serialize only once + memoryStore.putValues(blockId, values, level) match { + case Left(newValues) => + diskStore.putValues(blockId, newValues, level) match { + case Right(newBytes) => bytes = newBytes + case _ => throw new Exception("Unexpected return value") + } + case Right(newBytes) => + bytes = newBytes + diskStore.putBytes(blockId, newBytes, level) + } + } else if (level.useMemory) { + // If only save to memory + memoryStore.putValues(blockId, values, level) match { + case Right(newBytes) => bytes = newBytes + case _ => + } + } else { + // If only save to disk + diskStore.putValues(blockId, values, level) match { + case Right(newBytes) => bytes = newBytes + case _ => throw new Exception("Unexpected return value") + } + } + + if (tellMaster) { + notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0)) + logDebug("Put block " + blockId + " after notifying the master " + Utils.getUsedTimeMs(startTimeMs)) + } + } + + // Replicate block if required + if (level.replication > 1) { + if (bytes == null) { + bytes = dataSerialize(values) // serialize the block if not already done + } + replicate(blockId, bytes, level) + } + + // TODO(Haoyuan): This code will be removed when CacheTracker is gone. 
+ if (blockId.startsWith("rdd")) {
+ notifyTheCacheTracker(blockId)
+ }
+ logDebug("Put block " + blockId + " after notifying the CacheTracker " + Utils.getUsedTimeMs(startTimeMs))
+ }
+
+
+ def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
+ val startTime = System.currentTimeMillis
+ if (!level.useDisk && !level.useMemory) {
+ throw new IllegalArgumentException("Storage level has neither useMemory nor useDisk set")
+ } else if (level.deserialized) {
+ throw new IllegalArgumentException("Storage level cannot have deserialized when putBytes is used")
+ }
+ val replicationFuture = if (level.replication > 1) {
+ future {
+ replicate(blockId, bytes, level)
+ }
+ } else {
+ null
+ }
+
+ locker.getLock(blockId).synchronized {
+ logDebug("PutBytes for block " + blockId + " used " + Utils.getUsedTimeMs(startTime)
+ + " to get into synchronized block")
+ if (storageLevels.get(blockId) != null) {
+ logWarning("Block " + blockId + " already exists")
+ return
+ }
+ storageLevels.put(blockId, level)
+
+ if (level.useMemory) {
+ memoryStore.putBytes(blockId, bytes, level)
+ }
+ if (level.useDisk) {
+ diskStore.putBytes(blockId, bytes, level)
+ }
+ if (tellMaster) {
+ notifyMaster(HeartBeat(blockManagerId, blockId, level, 0, 0))
+ }
+ }
+
+ if (blockId.startsWith("rdd")) {
+ notifyTheCacheTracker(blockId)
+ }
+
+ if (level.replication > 1) {
+ if (replicationFuture == null) {
+ throw new Exception("Unexpected: replication future is null although replication > 1")
+ }
+ replicationFuture()
+ }
+
+ val finishTime = System.currentTimeMillis
+ if (level.replication > 1) {
+ logDebug("PutBytes with replication took " + (finishTime - startTime) + " ms")
+ } else {
+ logDebug("PutBytes without replication took " + (finishTime - startTime) + " ms")
+ }
+
+ }
+
+ private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) {
+ val tLevel: StorageLevel =
+ new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1)
+ var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers(
+ GetPeers(blockManagerId, level.replication - 1))
+ for (peer: BlockManagerId <- peers) {
+ val start = System.nanoTime
+ logDebug("Trying to replicate BlockId " + blockId + " once; The size of the data is "
+ + data.array().length + " Bytes. To node: " + peer)
+ if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
+ new ConnectionManagerId(peer.ip, peer.port))) {
+ logError("Failed to call syncPutBlock to " + peer)
+ }
+ logDebug("Replicated BlockId " + blockId + " once used " +
+ (System.nanoTime - start) / 1e6 + " ms; The size of the data is " +
+ data.array().length + " bytes.")
+ }
+ }
+
+ // TODO(Haoyuan): This code will be removed when CacheTracker is gone.
+ def notifyTheCacheTracker(key: String) {
+ val rddInfo = key.split(":")
+ val rddId: Int = rddInfo(1).toInt
+ val splitIndex: Int = rddInfo(2).toInt
+ val host = System.getProperty("spark.hostname", Utils.localHostName)
+ cacheTracker.notifyTheCacheTrackerFromBlockManager(spark.AddedToCache(rddId, splitIndex, host))
+ }
+
+ /**
+ * Read a block consisting of a single object.
+ */
+ def getSingle(blockId: String): Option[Any] = {
+ get(blockId).map(_.next)
+ }
+
+ /**
+ * Write a block consisting of a single object.
+ */ + def putSingle(blockId: String, value: Any, level: StorageLevel) { + put(blockId, Iterator(value), level) + } + + /** + * Drop block from memory (called when memory store has reached it limit) + */ + def dropFromMemory(blockId: String) { + locker.getLock(blockId).synchronized { + val level = storageLevels.get(blockId) + if (level == null) { + logWarning("Block " + blockId + " cannot be removed from memory as it does not exist") + return + } + if (!level.useMemory) { + logWarning("Block " + blockId + " cannot be removed from memory as it is not in memory") + return + } + memoryStore.remove(blockId) + if (!level.useDisk) { + storageLevels.remove(blockId) + } else { + val newLevel = level.clone + newLevel.useMemory = false + storageLevels.remove(blockId) + storageLevels.put(blockId, newLevel) + } + } + } + + def dataSerialize(values: Iterator[Any]): ByteBuffer = { + /*serializer.newInstance().serializeMany(values)*/ + val byteStream = new FastByteArrayOutputStream(4096) + serializer.newInstance().serializeStream(byteStream).writeAll(values).close() + byteStream.trim() + ByteBuffer.wrap(byteStream.array) + } + + def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = { + /*serializer.newInstance().deserializeMany(bytes)*/ + val ser = serializer.newInstance() + return ser.deserializeStream(new FastByteArrayInputStream(bytes.array())).toIterator + } + + private def notifyMaster(heartBeat: HeartBeat) { + BlockManagerMaster.mustHeartBeat(heartBeat) + } +} + +object BlockManager extends Logging { + def getMaxMemoryFromSystemProperties(): Long = { + val memoryFraction = System.getProperty("spark.storage.memoryFraction", "0.66").toDouble + val bytes = (Runtime.getRuntime.totalMemory * memoryFraction).toLong + logInfo("Maximum memory to use: " + bytes) + bytes + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala new file mode 100644 index 0000000000000000000000000000000000000000..bd94c185e9a6287f9b3bf2dfe493611951438335 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -0,0 +1,516 @@ +package spark.storage + +import java.io._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.Random + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ +import akka.util.duration._ + +import spark.Logging +import spark.Utils + +sealed trait ToBlockManagerMaster + +case class RegisterBlockManager( + blockManagerId: BlockManagerId, + maxMemSize: Long, + maxDiskSize: Long) + extends ToBlockManagerMaster + +class HeartBeat( + var blockManagerId: BlockManagerId, + var blockId: String, + var storageLevel: StorageLevel, + var deserializedSize: Long, + var size: Long) + extends ToBlockManagerMaster + with Externalizable { + + def this() = this(null, null, null, 0, 0) // For deserialization only + + override def writeExternal(out: ObjectOutput) { + blockManagerId.writeExternal(out) + out.writeUTF(blockId) + storageLevel.writeExternal(out) + out.writeInt(deserializedSize.toInt) + out.writeInt(size.toInt) + } + + override def readExternal(in: ObjectInput) { + blockManagerId = new BlockManagerId() + blockManagerId.readExternal(in) + blockId = in.readUTF() + storageLevel = new StorageLevel() + storageLevel.readExternal(in) + deserializedSize = in.readInt() + size = in.readInt() + } +} + +object HeartBeat { + def apply(blockManagerId: BlockManagerId, + blockId: String, + storageLevel: 
StorageLevel, + deserializedSize: Long, + size: Long): HeartBeat = { + new HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) + } + + + // For pattern-matching + def unapply(h: HeartBeat): Option[(BlockManagerId, String, StorageLevel, Long, Long)] = { + Some((h.blockManagerId, h.blockId, h.storageLevel, h.deserializedSize, h.size)) + } +} + +case class GetLocations( + blockId: String) + extends ToBlockManagerMaster + +case class GetLocationsMultipleBlockIds( + blockIds: Array[String]) + extends ToBlockManagerMaster + +case class GetPeers( + blockManagerId: BlockManagerId, + size: Int) + extends ToBlockManagerMaster + +case class RemoveHost( + host: String) + extends ToBlockManagerMaster + +class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { + class BlockManagerInfo( + timeMs: Long, + maxMem: Long, + maxDisk: Long) { + private var lastSeenMs = timeMs + private var remainedMem = maxMem + private var remainedDisk = maxDisk + private val blocks = new HashMap[String, StorageLevel] + + def updateLastSeenMs() { + lastSeenMs = System.currentTimeMillis() / 1000 + } + + def addBlock(blockId: String, storageLevel: StorageLevel, deserializedSize: Long, size: Long) = + synchronized { + updateLastSeenMs() + + if (blocks.contains(blockId)) { + val oriLevel: StorageLevel = blocks(blockId) + + if (oriLevel.deserialized) { + remainedMem += deserializedSize + } + if (oriLevel.useMemory) { + remainedMem += size + } + if (oriLevel.useDisk) { + remainedDisk += size + } + } + + blocks += (blockId -> storageLevel) + + if (storageLevel.deserialized) { + remainedMem -= deserializedSize + } + if (storageLevel.useMemory) { + remainedMem -= size + } + if (storageLevel.useDisk) { + remainedDisk -= size + } + + if (!(storageLevel.deserialized || storageLevel.useMemory || storageLevel.useDisk)) { + blocks.remove(blockId) + } + } + + def getLastSeenMs(): Long = { + return lastSeenMs + } + + def getRemainedMem(): Long = { + return remainedMem + } + + def getRemainedDisk(): Long = { + return remainedDisk + } + + override def toString(): String = { + return "BlockManagerInfo " + timeMs + " " + remainedMem + " " + remainedDisk + } + } + + private val blockManagerInfo = new HashMap[BlockManagerId, BlockManagerInfo] + private val blockIdMap = new HashMap[String, Pair[Int, HashSet[BlockManagerId]]] + + initLogging() + + def removeHost(host: String) { + logInfo("Trying to remove the host: " + host + " from BlockManagerMaster.") + logInfo("Previous hosts: " + blockManagerInfo.keySet.toSeq) + val ip = host.split(":")(0) + val port = host.split(":")(1) + blockManagerInfo.remove(new BlockManagerId(ip, port.toInt)) + logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) + self.reply(true) + } + + def receive = { + case RegisterBlockManager(blockManagerId, maxMemSize, maxDiskSize) => + register(blockManagerId, maxMemSize, maxDiskSize) + + case HeartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) => + heartBeat(blockManagerId, blockId, storageLevel, deserializedSize, size) + + case GetLocations(blockId) => + getLocations(blockId) + + case GetLocationsMultipleBlockIds(blockIds) => + getLocationsMultipleBlockIds(blockIds) + + case GetPeers(blockManagerId, size) => + getPeers_Deterministic(blockManagerId, size) + /*getPeers(blockManagerId, size)*/ + + case RemoveHost(host) => + removeHost(host) + + case msg => + logInfo("Got unknown msg: " + msg) + } + + private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) { + val startTimeMs = 
System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + logDebug("Got in register 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) + logInfo("Got Register Msg from " + blockManagerId) + if (blockManagerId.ip == Utils.localHostName() && !isLocal) { + logInfo("Got Register Msg from master node, don't register it") + } else { + blockManagerInfo += (blockManagerId -> new BlockManagerInfo( + System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize)) + } + logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) + self.reply(true) + } + + private def heartBeat( + blockManagerId: BlockManagerId, + blockId: String, + storageLevel: StorageLevel, + deserializedSize: Long, + size: Long) { + + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + blockId + " " + logDebug("Got in heartBeat 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) + + if (blockId == null) { + blockManagerInfo(blockManagerId).updateLastSeenMs() + logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + self.reply(true) + } + + blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size) + logDebug("Got in heartBeat 2" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + + var locations: HashSet[BlockManagerId] = null + if (blockIdMap.contains(blockId)) { + locations = blockIdMap(blockId)._2 + } else { + locations = new HashSet[BlockManagerId] + blockIdMap += (blockId -> (storageLevel.replication, locations)) + } + logDebug("Got in heartBeat 3" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + + if (storageLevel.deserialized || storageLevel.useDisk || storageLevel.useMemory) { + locations += blockManagerId + } else { + locations.remove(blockManagerId) + } + logDebug("Got in heartBeat 4" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + + if (locations.size == 0) { + blockIdMap.remove(blockId) + } + + logDebug("Got in heartBeat 5" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) + self.reply(true) + } + + private def getLocations(blockId: String) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockId + " " + logDebug("Got in getLocations 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) + if (blockIdMap.contains(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockIdMap(blockId)._2) + logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " + + Utils.getUsedTimeMs(startTimeMs)) + self.reply(res.toSeq) + } else { + logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + self.reply(res) + } + } + + private def getLocationsMultipleBlockIds(blockIds: Array[String]) { + def getLocations(blockId: String): Seq[BlockManagerId] = { + val tmp = blockId + logDebug("Got in getLocationsMultipleBlockIds Sub 0 " + tmp) + if (blockIdMap.contains(blockId)) { + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(blockIdMap(blockId)._2) + logDebug("Got in getLocationsMultipleBlockIds Sub 1 " + tmp + " " + res.toSeq) + return res.toSeq + } else { + logDebug("Got in getLocationsMultipleBlockIds Sub 2 " + tmp) + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + return res.toSeq + } + } + + logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq) + var res: ArrayBuffer[Seq[BlockManagerId]] = new ArrayBuffer[Seq[BlockManagerId]] + for (blockId <- blockIds) { + res.append(getLocations(blockId)) + } 
+ logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) + self.reply(res.toSeq) + } + + private def getPeers(blockManagerId: BlockManagerId, size: Int) { + val startTimeMs = System.currentTimeMillis() + val tmp = " " + blockManagerId + " " + logDebug("Got in getPeers 0" + tmp + Utils.getUsedTimeMs(startTimeMs)) + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + res.appendAll(peers) + res -= blockManagerId + val rand = new Random(System.currentTimeMillis()) + logDebug("Got in getPeers 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) + while (res.length > size) { + res.remove(rand.nextInt(res.length)) + } + logDebug("Got in getPeers 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) + self.reply(res.toSeq) + } + + private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) { + val startTimeMs = System.currentTimeMillis() + var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray + var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + + val peersWithIndices = peers.zipWithIndex + val selfIndex = peersWithIndices.find(_._1 == blockManagerId).map(_._2).getOrElse(-1) + if (selfIndex == -1) { + throw new Exception("Self index for " + blockManagerId + " not found") + } + + var index = selfIndex + while (res.size < size) { + index += 1 + if (index == selfIndex) { + throw new Exception("More peer expected than available") + } + res += peers(index % peers.size) + } + val resStr = res.map(_.toString).reduceLeft(_ + ", " + _) + logDebug("Got peers for " + blockManagerId + " as [" + resStr + "]") + self.reply(res.toSeq) + } +} + +object BlockManagerMaster extends Logging { + initLogging() + + val AKKA_ACTOR_NAME: String = "BlockMasterManager" + val REQUEST_RETRY_INTERVAL_MS = 100 + val DEFAULT_MASTER_IP: String = System.getProperty("spark.master.host", "localhost") + val DEFAULT_MASTER_PORT: Int = System.getProperty("spark.master.port", "7077").toInt + val DEFAULT_MANAGER_IP: String = Utils.localHostName() + val DEFAULT_MANAGER_PORT: String = "10902" + + implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis) + var masterActor: ActorRef = null + + def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) { + if (isMaster) { + masterActor = actorOf(new BlockManagerMaster(isLocal)) + remote.register(AKKA_ACTOR_NAME, masterActor) + logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT) + masterActor.start() + } else { + masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT) + } + } + + def notifyADeadHost(host: String) { + (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match { + case Some(true) => + logInfo("Removed " + host + " successfully. @ notifyADeadHost") + case Some(oops) => + logError("Failed @ notifyADeadHost: " + oops) + case None => + logError("None @ notifyADeadHost.") + } + } + + def mustRegisterBlockManager(msg: RegisterBlockManager) { + while (! syncRegisterBlockManager(msg)) { + logWarning("Failed to register " + msg) + Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + } + } + + def syncRegisterBlockManager(msg: RegisterBlockManager): Boolean = { + //val masterActor = RemoteActor.select(node, name) + val startTimeMs = System.currentTimeMillis() + val tmp = " msg " + msg + " " + logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + + (masterActor ? 
msg).as[Any] match { + case Some(true) => + logInfo("BlockManager registered successfully @ syncRegisterBlockManager.") + logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + return true + case Some(oops) => + logError("Failed @ syncRegisterBlockManager: " + oops) + return false + case None => + logError("None @ syncRegisterBlockManager.") + return false + } + } + + def mustHeartBeat(msg: HeartBeat) { + while (! syncHeartBeat(msg)) { + logWarning("Failed to send heartbeat" + msg) + Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + } + } + + def syncHeartBeat(msg: HeartBeat): Boolean = { + val startTimeMs = System.currentTimeMillis() + val tmp = " msg " + msg + " " + logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) + + (masterActor ? msg).as[Any] match { + case Some(true) => + logInfo("Heartbeat sent successfully.") + logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) + return true + case Some(oops) => + logError("Failed: " + oops) + return false + case None => + logError("None.") + return false + } + } + + def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = { + var res: Array[BlockManagerId] = syncGetLocations(msg) + while (res == null) { + logInfo("Failed to get locations " + msg) + Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + res = syncGetLocations(msg) + } + return res + } + + def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = { + val startTimeMs = System.currentTimeMillis() + val tmp = " msg " + msg + " " + logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + + (masterActor ? msg).as[Seq[BlockManagerId]] match { + case Some(arr) => + logDebug("GetLocations successfully.") + logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + for (ele <- arr) { + res += ele + } + logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + return res.toArray + case None => + logError("GetLocations call returned None.") + return null + } + } + + def mustGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): + Seq[Seq[BlockManagerId]] = { + var res: Seq[Seq[BlockManagerId]] = syncGetLocationsMultipleBlockIds(msg) + while (res == null) { + logWarning("Failed to GetLocationsMultipleBlockIds " + msg) + Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + res = syncGetLocationsMultipleBlockIds(msg) + } + return res + } + + def syncGetLocationsMultipleBlockIds(msg: GetLocationsMultipleBlockIds): + Seq[Seq[BlockManagerId]] = { + val startTimeMs = System.currentTimeMillis + val tmp = " msg " + msg + " " + logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + + (masterActor ? 
msg).as[Any] match { + case Some(arr: Seq[Seq[BlockManagerId]]) => + logDebug("GetLocationsMultipleBlockIds successfully: " + arr) + logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + return arr + case Some(oops) => + logError("Failed: " + oops) + return null + case None => + logInfo("None.") + return null + } + } + + def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = { + var res: Array[BlockManagerId] = syncGetPeers(msg) + while ((res == null) || (res.length != msg.size)) { + logInfo("Failed to get peers " + msg) + Thread.sleep(REQUEST_RETRY_INTERVAL_MS) + res = syncGetPeers(msg) + } + + return res + } + + def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = { + val startTimeMs = System.currentTimeMillis + val tmp = " msg " + msg + " " + logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + + (masterActor ? msg).as[Seq[BlockManagerId]] match { + case Some(arr) => + logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] + logInfo("GetPeers successfully: " + arr.length) + res.appendAll(arr) + logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + return res.toArray + case None => + logError("GetPeers call returned None.") + return null + } + } +} diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala new file mode 100644 index 0000000000000000000000000000000000000000..a4cdbd8ddd3aa305263a7792022a973b265b86aa --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala @@ -0,0 +1,142 @@ +package spark.storage + +import java.nio._ + +import scala.actors._ +import scala.actors.Actor._ +import scala.actors.remote._ + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.Random + +import spark.Logging +import spark.Utils +import spark.SparkEnv +import spark.network._ + +/** + * This should be changed to use event model late. 
+ */ +class BlockManagerWorker(val blockManager: BlockManager) extends Logging { + initLogging() + + blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) + + def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { + logDebug("Handling message " + msg) + msg match { + case bufferMessage: BufferMessage => { + try { + logDebug("Handling as a buffer message " + bufferMessage) + val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) + logDebug("Parsed as a block message array") + val responseMessages = blockMessages.map(processBlockMessage _).filter(_ != None).map(_.get) + /*logDebug("Processed block messages")*/ + return Some(new BlockMessageArray(responseMessages).toBufferMessage) + } catch { + case e: Exception => logError("Exception handling buffer message: " + e.getMessage) + return None + } + } + case otherMessage: Any => { + logError("Unknown type message received: " + otherMessage) + return None + } + } + } + + def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { + blockMessage.getType() match { + case BlockMessage.TYPE_PUT_BLOCK => { + val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel()) + logInfo("Received [" + pB + "]") + putBlock(pB.id, pB.data, pB.level) + return None + } + case BlockMessage.TYPE_GET_BLOCK => { + val gB = new GetBlock(blockMessage.getId()) + logInfo("Received [" + gB + "]") + val buffer = getBlock(gB.id) + if (buffer == null) { + return None + } + return Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) + } + case _ => return None + } + } + + private def putBlock(id: String, bytes: ByteBuffer, level: StorageLevel) { + val startTimeMs = System.currentTimeMillis() + logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) + blockManager.putBytes(id, bytes, level) + logDebug("PutBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + + " with data size: " + bytes.array().length) + } + + private def getBlock(id: String): ByteBuffer = { + val startTimeMs = System.currentTimeMillis() + logDebug("Getblock " + id + " started from " + startTimeMs) + val block = blockManager.get(id) + val buffer = block match { + case Some(tValues) => { + val values = tValues.asInstanceOf[Iterator[Any]] + val buffer = blockManager.dataSerialize(values) + buffer + } + case None => { + null + } + } + logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) + + " and got buffer " + buffer) + return buffer + } +} + +object BlockManagerWorker extends Logging { + private var blockManagerWorker: BlockManagerWorker = null + private val DATA_TRANSFER_TIME_OUT_MS: Long = 500 + private val REQUEST_RETRY_INTERVAL_MS: Long = 1000 + + initLogging() + + def startBlockManagerWorker(manager: BlockManager) { + blockManagerWorker = new BlockManagerWorker(manager) + } + + def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { + val blockManager = blockManagerWorker.blockManager + val connectionManager = blockManager.connectionManager + val serializer = blockManager.serializer + val blockMessage = BlockMessage.fromPutBlock(msg) + val blockMessageArray = new BlockMessageArray(blockMessage) + val resultMessage = connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage()) + return (resultMessage != None) + } + + def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { + val blockManager = blockManagerWorker.blockManager + val connectionManager 
= blockManager.connectionManager + val serializer = blockManager.serializer + val blockMessage = BlockMessage.fromGetBlock(msg) + val blockMessageArray = new BlockMessageArray(blockMessage) + val responseMessage = connectionManager.sendMessageReliablySync( + toConnManagerId, blockMessageArray.toBufferMessage()) + responseMessage match { + case Some(message) => { + val bufferMessage = message.asInstanceOf[BufferMessage] + logDebug("Response message received " + bufferMessage) + BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { + logDebug("Found " + blockMessage) + return blockMessage.getData + }) + } + case None => logDebug("No response message received"); return null + } + return null + } +} diff --git a/core/src/main/scala/spark/storage/BlockMessage.scala b/core/src/main/scala/spark/storage/BlockMessage.scala new file mode 100644 index 0000000000000000000000000000000000000000..bb128dce7a6b8ad45c476c59d87ccf17c77ab667 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockMessage.scala @@ -0,0 +1,219 @@ +package spark.storage + +import java.nio._ + +import scala.collection.mutable.StringBuilder +import scala.collection.mutable.ArrayBuffer + +import spark._ +import spark.network._ + +case class GetBlock(id: String) +case class GotBlock(id: String, data: ByteBuffer) +case class PutBlock(id: String, data: ByteBuffer, level: StorageLevel) + +class BlockMessage() extends Logging{ + // Un-initialized: typ = 0 + // GetBlock: typ = 1 + // GotBlock: typ = 2 + // PutBlock: typ = 3 + private var typ: Int = BlockMessage.TYPE_NON_INITIALIZED + private var id: String = null + private var data: ByteBuffer = null + private var level: StorageLevel = null + + initLogging() + + def set(getBlock: GetBlock) { + typ = BlockMessage.TYPE_GET_BLOCK + id = getBlock.id + } + + def set(gotBlock: GotBlock) { + typ = BlockMessage.TYPE_GOT_BLOCK + id = gotBlock.id + data = gotBlock.data + } + + def set(putBlock: PutBlock) { + typ = BlockMessage.TYPE_PUT_BLOCK + id = putBlock.id + data = putBlock.data + level = putBlock.level + } + + def set(buffer: ByteBuffer) { + val startTime = System.currentTimeMillis + /* + println() + println("BlockMessage: ") + while(buffer.remaining > 0) { + print(buffer.get()) + } + buffer.rewind() + println() + println() + */ + typ = buffer.getInt() + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + id = idBuilder.toString() + + logDebug("Set from buffer Result: " + typ + " " + id) + logDebug("Buffer position is " + buffer.position) + if (typ == BlockMessage.TYPE_PUT_BLOCK) { + + val booleanInt = buffer.getInt() + val replication = buffer.getInt() + level = new StorageLevel(booleanInt, replication) + + val dataLength = buffer.getInt() + data = ByteBuffer.allocate(dataLength) + if (dataLength != buffer.remaining) { + throw new Exception("Error parsing buffer") + } + data.put(buffer) + data.flip() + logDebug("Set from buffer Result 2: " + level + " " + data) + } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { + + val dataLength = buffer.getInt() + logDebug("Data length is "+ dataLength) + logDebug("Buffer position is " + buffer.position) + data = ByteBuffer.allocate(dataLength) + if (dataLength != buffer.remaining) { + throw new Exception("Error parsing buffer") + } + data.put(buffer) + data.flip() + logDebug("Set from buffer Result 3: " + data) + } + + val finishTime = System.currentTimeMillis + logDebug("Converted " + id + " from bytebuffer in " + (finishTime - startTime) 
/ 1000.0 + " s") + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getType(): Int = { + return typ + } + + def getId(): String = { + return id + } + + def getData(): ByteBuffer = { + return data + } + + def getLevel(): StorageLevel = { + return level + } + + def toBufferMessage(): BufferMessage = { + val startTime = System.currentTimeMillis + val buffers = new ArrayBuffer[ByteBuffer]() + var buffer = ByteBuffer.allocate(4 + 4 + id.length() * 2) + buffer.putInt(typ).putInt(id.length()) + id.foreach((x: Char) => buffer.putChar(x)) + buffer.flip() + buffers += buffer + + if (typ == BlockMessage.TYPE_PUT_BLOCK) { + buffer = ByteBuffer.allocate(8).putInt(level.toInt()).putInt(level.replication) + buffer.flip() + buffers += buffer + + buffer = ByteBuffer.allocate(4).putInt(data.remaining) + buffer.flip() + buffers += buffer + + buffers += data + } else if (typ == BlockMessage.TYPE_GOT_BLOCK) { + buffer = ByteBuffer.allocate(4).putInt(data.remaining) + buffer.flip() + buffers += buffer + + buffers += data + } + + logDebug("Start to log buffers.") + buffers.foreach((x: ByteBuffer) => logDebug("" + x)) + /* + println() + println("BlockMessage: ") + buffers.foreach(b => { + while(b.remaining > 0) { + print(b.get()) + } + b.rewind() + }) + println() + println() + */ + val finishTime = System.currentTimeMillis + logDebug("Converted " + id + " to buffer message in " + (finishTime - startTime) / 1000.0 + " s") + return Message.createBufferMessage(buffers) + } + + override def toString(): String = { + "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" + } +} + +object BlockMessage { + val TYPE_NON_INITIALIZED: Int = 0 + val TYPE_GET_BLOCK: Int = 1 + val TYPE_GOT_BLOCK: Int = 2 + val TYPE_PUT_BLOCK: Int = 3 + + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(bufferMessage) + newBlockMessage + } + + def fromByteBuffer(buffer: ByteBuffer): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(buffer) + newBlockMessage + } + + def fromGetBlock(getBlock: GetBlock): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(getBlock) + newBlockMessage + } + + def fromGotBlock(gotBlock: GotBlock): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(gotBlock) + newBlockMessage + } + + def fromPutBlock(putBlock: PutBlock): BlockMessage = { + val newBlockMessage = new BlockMessage() + newBlockMessage.set(putBlock) + newBlockMessage + } + + def main(args: Array[String]) { + val B = new BlockMessage() + B.set(new PutBlock("ABC", ByteBuffer.allocate(10), StorageLevel.DISK_AND_MEMORY_2)) + val bMsg = B.toBufferMessage() + val C = new BlockMessage() + C.set(bMsg) + + println(B.getId() + " " + B.getLevel()) + println(C.getId() + " " + C.getLevel()) + } +} diff --git a/core/src/main/scala/spark/storage/BlockMessageArray.scala b/core/src/main/scala/spark/storage/BlockMessageArray.scala new file mode 100644 index 0000000000000000000000000000000000000000..5f411d34884e12871405b12b24dcb0765af01427 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockMessageArray.scala @@ -0,0 +1,140 @@ +package spark.storage +import java.nio._ + +import scala.collection.mutable.StringBuilder +import scala.collection.mutable.ArrayBuffer + +import spark._ +import spark.network._ + +class 
BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { + + def this(bm: BlockMessage) = this(Array(bm)) + + def this() = this(null.asInstanceOf[Seq[BlockMessage]]) + + def apply(i: Int) = blockMessages(i) + + def iterator = blockMessages.iterator + + def length = blockMessages.length + + initLogging() + + def set(bufferMessage: BufferMessage) { + val startTime = System.currentTimeMillis + val newBlockMessages = new ArrayBuffer[BlockMessage]() + val buffer = bufferMessage.buffers(0) + buffer.clear() + /* + println() + println("BlockMessageArray: ") + while(buffer.remaining > 0) { + print(buffer.get()) + } + buffer.rewind() + println() + println() + */ + while(buffer.remaining() > 0) { + val size = buffer.getInt() + logDebug("Creating block message of size " + size + " bytes") + val newBuffer = buffer.slice() + newBuffer.clear() + newBuffer.limit(size) + logDebug("Trying to convert buffer " + newBuffer + " to block message") + val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) + logDebug("Created " + newBlockMessage) + newBlockMessages += newBlockMessage + buffer.position(buffer.position() + size) + } + val finishTime = System.currentTimeMillis + logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") + this.blockMessages = newBlockMessages + } + + def toBufferMessage(): BufferMessage = { + val buffers = new ArrayBuffer[ByteBuffer]() + + blockMessages.foreach(blockMessage => { + val bufferMessage = blockMessage.toBufferMessage + logDebug("Adding " + blockMessage) + val sizeBuffer = ByteBuffer.allocate(4).putInt(bufferMessage.size) + sizeBuffer.flip + buffers += sizeBuffer + buffers ++= bufferMessage.buffers + logDebug("Added " + bufferMessage) + }) + + logDebug("Buffer list:") + buffers.foreach((x: ByteBuffer) => logDebug("" + x)) + /* + println() + println("BlockMessageArray: ") + buffers.foreach(b => { + while(b.remaining > 0) { + print(b.get()) + } + b.rewind() + }) + println() + println() + */ + return Message.createBufferMessage(buffers) + } +} + +object BlockMessageArray { + + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { + val newBlockMessageArray = new BlockMessageArray() + newBlockMessageArray.set(bufferMessage) + newBlockMessageArray + } + + def main(args: Array[String]) { + val blockMessages = + (0 until 10).map(i => { + if (i % 2 == 0) { + val buffer = ByteBuffer.allocate(100) + buffer.clear + BlockMessage.fromPutBlock(PutBlock(i.toString, buffer, StorageLevel.MEMORY_ONLY)) + } else { + BlockMessage.fromGetBlock(GetBlock(i.toString)) + } + }) + val blockMessageArray = new BlockMessageArray(blockMessages) + println("Block message array created") + + val bufferMessage = blockMessageArray.toBufferMessage + println("Converted to buffer message") + + val totalSize = bufferMessage.size + val newBuffer = ByteBuffer.allocate(totalSize) + newBuffer.clear() + bufferMessage.buffers.foreach(buffer => { + newBuffer.put(buffer) + buffer.rewind() + }) + newBuffer.flip + val newBufferMessage = Message.createBufferMessage(newBuffer) + println("Copied to new buffer message, size = " + newBufferMessage.size) + + val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) + println("Converted back to block message array") + newBlockMessageArray.foreach(blockMessage => { + blockMessage.getType() match { + case BlockMessage.TYPE_PUT_BLOCK => { + val pB = PutBlock(blockMessage.getId(), blockMessage.getData(), blockMessage.getLevel()) + println(pB) + 
} + case BlockMessage.TYPE_GET_BLOCK => { + val gB = new GetBlock(blockMessage.getId()) + println(gB) + } + } + }) + } +} + + diff --git a/core/src/main/scala/spark/storage/BlockStore.scala b/core/src/main/scala/spark/storage/BlockStore.scala new file mode 100644 index 0000000000000000000000000000000000000000..0584cc2d4f3992db7b43f445342f0c59c7eed835 --- /dev/null +++ b/core/src/main/scala/spark/storage/BlockStore.scala @@ -0,0 +1,282 @@ +package spark.storage + +import spark.{Utils, Logging, Serializer, SizeEstimator} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{File, RandomAccessFile} +import java.nio.ByteBuffer +import java.nio.channels.FileChannel.MapMode +import java.util.{UUID, LinkedHashMap} +import java.util.concurrent.Executors + +import it.unimi.dsi.fastutil.io._ + +/** + * Abstract class to store blocks + */ +abstract class BlockStore(blockManager: BlockManager) extends Logging { + initLogging() + + def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) + + def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] + + def getBytes(blockId: String): Option[ByteBuffer] + + def getValues(blockId: String): Option[Iterator[Any]] + + def remove(blockId: String) + + def dataSerialize(values: Iterator[Any]): ByteBuffer = blockManager.dataSerialize(values) + + def dataDeserialize(bytes: ByteBuffer): Iterator[Any] = blockManager.dataDeserialize(bytes) +} + +/** + * Class to store blocks in memory + */ +class MemoryStore(blockManager: BlockManager, maxMemory: Long) + extends BlockStore(blockManager) { + + class Entry(var value: Any, val size: Long, val deserialized: Boolean) + + private val memoryStore = new LinkedHashMap[String, Entry](32, 0.75f, true) + private var currentMemory = 0L + + private val blockDropper = Executors.newSingleThreadExecutor() + + def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + if (level.deserialized) { + bytes.rewind() + val values = dataDeserialize(bytes) + val elements = new ArrayBuffer[Any] + elements ++= values + val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) + ensureFreeSpace(sizeEstimate) + val entry = new Entry(elements, sizeEstimate, true) + memoryStore.synchronized { memoryStore.put(blockId, entry) } + currentMemory += sizeEstimate + logDebug("Block " + blockId + " stored as values to memory") + } else { + val entry = new Entry(bytes, bytes.array().length, false) + ensureFreeSpace(bytes.array.length) + memoryStore.synchronized { memoryStore.put(blockId, entry) } + currentMemory += bytes.array().length + logDebug("Block " + blockId + " stored as " + bytes.array().length + " bytes to memory") + } + } + + def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { + if (level.deserialized) { + val elements = new ArrayBuffer[Any] + elements ++= values + val sizeEstimate = SizeEstimator.estimate(elements.asInstanceOf[AnyRef]) + ensureFreeSpace(sizeEstimate) + val entry = new Entry(elements, sizeEstimate, true) + memoryStore.synchronized { memoryStore.put(blockId, entry) } + currentMemory += sizeEstimate + logDebug("Block " + blockId + " stored as values to memory") + return Left(elements.iterator) + } else { + val bytes = dataSerialize(values) + ensureFreeSpace(bytes.array().length) + val entry = new Entry(bytes, bytes.array().length, false) + memoryStore.synchronized { memoryStore.put(blockId, entry) } + currentMemory += bytes.array().length + logDebug("Block " + 
blockId + " stored as " + bytes.array.length + " bytes to memory") + return Right(bytes) + } + } + + def getBytes(blockId: String): Option[ByteBuffer] = { + throw new UnsupportedOperationException("Not implemented") + } + + def getValues(blockId: String): Option[Iterator[Any]] = { + val entry = memoryStore.synchronized { memoryStore.get(blockId) } + if (entry == null) { + return None + } + if (entry.deserialized) { + return Some(entry.value.asInstanceOf[ArrayBuffer[Any]].toIterator) + } else { + return Some(dataDeserialize(entry.value.asInstanceOf[ByteBuffer])) + } + } + + def remove(blockId: String) { + memoryStore.synchronized { + val entry = memoryStore.get(blockId) + if (entry != null) { + memoryStore.remove(blockId) + currentMemory -= entry.size + logDebug("Block " + blockId + " of size " + entry.size + " dropped from memory") + } else { + logWarning("Block " + blockId + " could not be removed as it doesn't exist") + } + } + } + + private def drop(blockId: String) { + blockDropper.submit(new Runnable() { + def run() { + blockManager.dropFromMemory(blockId) + } + }) + } + + private def ensureFreeSpace(space: Long) { + logInfo("ensureFreeSpace(%d) called with curMem=%d, maxMem=%d".format( + space, currentMemory, maxMemory)) + + val droppedBlockIds = new ArrayBuffer[String]() + var droppedMemory = 0L + + memoryStore.synchronized { + val iter = memoryStore.entrySet().iterator() + while (maxMemory - (currentMemory - droppedMemory) < space && iter.hasNext) { + val pair = iter.next() + val blockId = pair.getKey + droppedBlockIds += blockId + droppedMemory += pair.getValue.size + logDebug("Decided to drop " + blockId) + } + } + + for (blockId <- droppedBlockIds) { + drop(blockId) + } + + droppedBlockIds.clear + } +} + + +/** + * Class to store blocks on disk + */ +class DiskStore(blockManager: BlockManager, rootDirs: String) + extends BlockStore(blockManager) { + + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + val localDirs = createLocalDirs() + var lastLocalDirUsed = 0 + + addShutdownHook() + + def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + logDebug("Attempting to put block " + blockId) + val startTime = System.currentTimeMillis + val file = createFile(blockId) + if (file != null) { + val channel = new RandomAccessFile(file, "rw").getChannel() + val buffer = channel.map(MapMode.READ_WRITE, 0, bytes.array.length) + buffer.put(bytes.array) + channel.close() + val finishTime = System.currentTimeMillis + logDebug("Block " + blockId + " stored to file of " + bytes.array.length + " bytes to disk in " + (finishTime - startTime) + " ms") + } else { + logError("File not created for block " + blockId) + } + } + + def putValues(blockId: String, values: Iterator[Any], level: StorageLevel): Either[Iterator[Any], ByteBuffer] = { + val bytes = dataSerialize(values) + logDebug("Converted block " + blockId + " to " + bytes.array.length + " bytes") + putBytes(blockId, bytes, level) + return Right(bytes) + } + + def getBytes(blockId: String): Option[ByteBuffer] = { + val file = getFile(blockId) + val length = file.length().toInt + val channel = new RandomAccessFile(file, "r").getChannel() + val bytes = ByteBuffer.allocate(length) + bytes.put(channel.map(MapMode.READ_ONLY, 0, length)) + return Some(bytes) + } + + def getValues(blockId: String): Option[Iterator[Any]] = { + val file = getFile(blockId) + val length = file.length().toInt + val channel = new RandomAccessFile(file, "r").getChannel() + val bytes = channel.map(MapMode.READ_ONLY, 0, length) + val buffer = dataDeserialize(bytes) +
channel.close() + return Some(buffer) + } + + def remove(blockId: String) { + throw new UnsupportedOperationException("Not implemented") + } + + private def createFile(blockId: String): File = { + val file = getFile(blockId) + if (file == null) { + lastLocalDirUsed = (lastLocalDirUsed + 1) % localDirs.size + val newFile = new File(localDirs(lastLocalDirUsed), blockId) + newFile.getParentFile.mkdirs() + return newFile + } else { + logError("File for block " + blockId + " already exists on disk, " + file) + return null + } + } + + private def getFile(blockId: String): File = { + logDebug("Getting file for block " + blockId) + // Search for the file in all the local directories, only one of them should have the file + val files = localDirs.map(localDir => new File(localDir, blockId)).filter(_.exists) + if (files.size > 1) { + throw new Exception("Multiple files for same block " + blockId + " exists: " + + files.map(_.toString).reduceLeft(_ + ", " + _)) + return null + } else if (files.size == 0) { + return null + } else { + logDebug("Got file " + files(0) + " of size " + files(0).length + " bytes") + return files(0) + } + } + + private def createLocalDirs(): Seq[File] = { + logDebug("Creating local directories at root dirs '" + rootDirs + "'") + rootDirs.split("[;,:]").map(rootDir => { + var foundLocalDir: Boolean = false + var localDir: File = null + var localDirUuid: UUID = null + var tries = 0 + while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + tries += 1 + try { + localDirUuid = UUID.randomUUID() + localDir = new File(rootDir, "spark-local-" + localDirUuid) + if (!localDir.exists) { + localDir.mkdirs() + foundLocalDir = true + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create local dir failed", e) + } + } + if (!foundLocalDir) { + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + + " attempts to create local dir in " + rootDir) + System.exit(1) + } + logDebug("Created local directory at " + localDir) + localDir + }) + } + + private def addShutdownHook() { + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { + override def run() { + logDebug("Shutdown hook called") + localDirs.foreach(localDir => Utils.deleteRecursively(localDir)) + } + }) + } +} diff --git a/core/src/main/scala/spark/storage/StorageLevel.scala b/core/src/main/scala/spark/storage/StorageLevel.scala new file mode 100644 index 0000000000000000000000000000000000000000..a2833a709063986d773dd356976392cf3a9c5a08 --- /dev/null +++ b/core/src/main/scala/spark/storage/StorageLevel.scala @@ -0,0 +1,78 @@ +package spark.storage + +import java.io._ + +class StorageLevel( + var useDisk: Boolean, + var useMemory: Boolean, + var deserialized: Boolean, + var replication: Int = 1) + extends Externalizable { + + // TODO: Also add fields for caching priority, dataset ID, and flushing. 
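+  // Flag encoding used by toInt, writeExternal and readExternal below: useDisk = 4, useMemory = 2, deserialized = 1; the replication count is written as a second byte (e.g. MEMORY_ONLY_DESER_2 serializes to the bytes 0x03 0x02).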
+ + def this(booleanInt: Int, replication: Int) { + this(((booleanInt & 4) != 0), + ((booleanInt & 2) != 0), + ((booleanInt & 1) != 0), + replication) + } + + def this() = this(false, true, false) // For deserialization + + override def clone(): StorageLevel = new StorageLevel( + this.useDisk, this.useMemory, this.deserialized, this.replication) + + override def equals(other: Any): Boolean = other match { + case s: StorageLevel => + s.useDisk == useDisk && + s.useMemory == useMemory && + s.deserialized == deserialized && + s.replication == replication + case _ => + false + } + + def toInt(): Int = { + var ret = 0 + if (useDisk) { + ret += 4 + } + if (useMemory) { + ret += 2 + } + if (deserialized) { + ret += 1 + } + return ret + } + + override def writeExternal(out: ObjectOutput) { + out.writeByte(toInt().toByte) + out.writeByte(replication.toByte) + } + + override def readExternal(in: ObjectInput) { + val flags = in.readByte() + useDisk = (flags & 4) != 0 + useMemory = (flags & 2) != 0 + deserialized = (flags & 1) != 0 + replication = in.readByte() + } + + override def toString(): String = + "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) +} + +object StorageLevel { + val NONE = new StorageLevel(false, false, false) + val DISK_ONLY = new StorageLevel(true, false, false) + val MEMORY_ONLY = new StorageLevel(false, true, false) + val MEMORY_ONLY_2 = new StorageLevel(false, true, false, 2) + val MEMORY_ONLY_DESER = new StorageLevel(false, true, true) + val MEMORY_ONLY_DESER_2 = new StorageLevel(false, true, true, 2) + val DISK_AND_MEMORY = new StorageLevel(true, true, false) + val DISK_AND_MEMORY_2 = new StorageLevel(true, true, false, 2) + val DISK_AND_MEMORY_DESER = new StorageLevel(true, true, true) + val DISK_AND_MEMORY_DESER_2 = new StorageLevel(true, true, true, 2) +} diff --git a/core/src/main/scala/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/spark/util/ByteBufferInputStream.scala new file mode 100644 index 0000000000000000000000000000000000000000..abe2d99dd8a5f6814aa57c4ee2fc15fb08b09ac2 --- /dev/null +++ b/core/src/main/scala/spark/util/ByteBufferInputStream.scala @@ -0,0 +1,30 @@ +package spark.util + +import java.io.InputStream +import java.nio.ByteBuffer + +class ByteBufferInputStream(buffer: ByteBuffer) extends InputStream { + override def read(): Int = { + if (buffer.remaining() == 0) { + -1 + } else { + buffer.get() & 0xFF + } + } + + override def read(dest: Array[Byte]): Int = { + read(dest, 0, dest.length) + } + + override def read(dest: Array[Byte], offset: Int, length: Int): Int = { + val amountToGet = math.min(buffer.remaining(), length) + buffer.get(dest, offset, amountToGet) + return if (amountToGet == 0 && length > 0) -1 else amountToGet + } + + override def skip(bytes: Long): Long = { + val amountToSkip = math.min(bytes, buffer.remaining).toInt + buffer.position(buffer.position + amountToSkip) + return amountToSkip + } +} diff --git a/core/src/main/scala/spark/util/StatCounter.scala b/core/src/main/scala/spark/util/StatCounter.scala new file mode 100644 index 0000000000000000000000000000000000000000..efb1ae75290f5482cb44d46b3222d34b283d9270 --- /dev/null +++ b/core/src/main/scala/spark/util/StatCounter.scala @@ -0,0 +1,89 @@ +package spark.util + +/** + * A class for tracking the statistics of a set of numbers (count, mean and variance) in a + * numerically robust way. Includes support for merging two StatCounters. Based on Welford and + * Chan's algorithms described at http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance.
+ */ +class StatCounter(values: TraversableOnce[Double]) { + private var n: Long = 0 // Running count of our values + private var mu: Double = 0 // Running mean of our values + private var m2: Double = 0 // Running variance numerator (sum of (x - mean)^2) + + merge(values) + + def this() = this(Nil) + + def merge(value: Double): StatCounter = { + val delta = value - mu + n += 1 + mu += delta / n + m2 += delta * (value - mu) + this + } + + def merge(values: TraversableOnce[Double]): StatCounter = { + values.foreach(v => merge(v)) + this + } + + def merge(other: StatCounter): StatCounter = { + if (other == this) { + merge(other.copy()) // Avoid overwriting fields in a weird order + } else { + val delta = other.mu - mu + if (other.n * 10 < n) { + mu = mu + (delta * other.n) / (n + other.n) + } else if (n * 10 < other.n) { + mu = other.mu - (delta * n) / (n + other.n) + } else { + mu = (mu * n + other.mu * other.n) / (n + other.n) + } + m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n) + n += other.n + this + } + } + + def copy(): StatCounter = { + val other = new StatCounter + other.n = n + other.mu = mu + other.m2 = m2 + other + } + + def count: Long = n + + def mean: Double = mu + + def sum: Double = n * mu + + def variance: Double = { + if (n == 0) + Double.NaN + else + m2 / n + } + + def sampleVariance: Double = { + if (n <= 1) + Double.NaN + else + m2 / (n - 1) + } + + def stdev: Double = math.sqrt(variance) + + def sampleStdev: Double = math.sqrt(sampleVariance) + + override def toString: String = { + "(count: %d, mean: %f, stdev: %f)".format(count, mean, stdev) + } +} + +object StatCounter { + def apply(values: TraversableOnce[Double]) = new StatCounter(values) + + def apply(values: Double*) = new StatCounter(values) +} diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala index 60290d14cab69427a771004f5e1270f00708eaa4..3d170a6e22ef0cec8544454e5622d4432cb0c78c 100644 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ b/core/src/test/scala/spark/CacheTrackerSuite.scala @@ -1,95 +1,103 @@ package spark import org.scalatest.FunSuite -import collection.mutable.HashMap + +import scala.collection.mutable.HashMap + +import akka.actor._ +import akka.actor.Actor +import akka.actor.Actor._ class CacheTrackerSuite extends FunSuite { test("CacheTrackerActor slave initialization & cache status") { - System.setProperty("spark.master.port", "1345") + //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = new CacheTrackerActor + val tracker = actorOf(new CacheTrackerActor) tracker.start() - tracker !? SlaveCacheStarted("host001", initialSize) + tracker !! SlaveCacheStarted("host001", initialSize) - assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 0L))) + assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L))) - tracker !? StopCacheTracker + tracker !! StopCacheTracker } test("RegisterRDD") { - System.setProperty("spark.master.port", "1345") + //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = new CacheTrackerActor + val tracker = actorOf(new CacheTrackerActor) tracker.start() - tracker !? SlaveCacheStarted("host001", initialSize) + tracker !! SlaveCacheStarted("host001", initialSize) - tracker !? RegisterRDD(1, 3) - tracker !? RegisterRDD(2, 1) + tracker !! RegisterRDD(1, 3) + tracker !! 
RegisterRDD(2, 1) - assert(getCacheLocations(tracker) == Map(1 -> List(List(), List(), List()), 2 -> List(List()))) + assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List()))) - tracker !? StopCacheTracker + tracker !! StopCacheTracker } test("AddedToCache") { - System.setProperty("spark.master.port", "1345") + //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = new CacheTrackerActor + val tracker = actorOf(new CacheTrackerActor) tracker.start() - tracker !? SlaveCacheStarted("host001", initialSize) + tracker !! SlaveCacheStarted("host001", initialSize) - tracker !? RegisterRDD(1, 2) - tracker !? RegisterRDD(2, 1) + tracker !! RegisterRDD(1, 2) + tracker !! RegisterRDD(2, 1) - tracker !? AddedToCache(1, 0, "host001", 2L << 15) - tracker !? AddedToCache(1, 1, "host001", 2L << 11) - tracker !? AddedToCache(2, 0, "host001", 3L << 10) + tracker !! AddedToCache(1, 0, "host001", 2L << 15) + tracker !! AddedToCache(1, 1, "host001", 2L << 11) + tracker !! AddedToCache(2, 0, "host001", 3L << 10) - assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L))) + assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) + assert(getCacheLocations(tracker) === + Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - tracker !? StopCacheTracker + tracker !! StopCacheTracker } test("DroppedFromCache") { - System.setProperty("spark.master.port", "1345") + //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = new CacheTrackerActor + val tracker = actorOf(new CacheTrackerActor) tracker.start() - tracker !? SlaveCacheStarted("host001", initialSize) + tracker !! SlaveCacheStarted("host001", initialSize) - tracker !? RegisterRDD(1, 2) - tracker !? RegisterRDD(2, 1) + tracker !! RegisterRDD(1, 2) + tracker !! RegisterRDD(2, 1) - tracker !? AddedToCache(1, 0, "host001", 2L << 15) - tracker !? AddedToCache(1, 1, "host001", 2L << 11) - tracker !? AddedToCache(2, 0, "host001", 3L << 10) + tracker !! AddedToCache(1, 0, "host001", 2L << 15) + tracker !! AddedToCache(1, 1, "host001", 2L << 11) + tracker !! AddedToCache(2, 0, "host001", 3L << 10) - assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 72704L))) - assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) + assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L))) + assert(getCacheLocations(tracker) === + Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - tracker !? DroppedFromCache(1, 1, "host001", 2L << 11) + tracker !! DroppedFromCache(1, 1, "host001", 2L << 11) - assert(tracker !? GetCacheStatus == Seq(("host001", 2097152L, 68608L))) - assert(getCacheLocations(tracker) == Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) + assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L))) + assert(getCacheLocations(tracker) === + Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - tracker !? StopCacheTracker + tracker !! StopCacheTracker } /** * Helper function to get cacheLocations from CacheTracker */ - def getCacheLocations(tracker: CacheTrackerActor) = tracker !? GetCacheLocations match { + def getCacheLocations(tracker: ActorRef) = (tracker ? 
GetCacheLocations).get match { case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map { case (i, arr) => (i -> arr.toList) } diff --git a/core/src/test/scala/spark/MesosSchedulerSuite.scala b/core/src/test/scala/spark/MesosSchedulerSuite.scala index 0e6820cbdcf31b0135d57283ef6b2b78681a5569..54421225d881e9b9e1f84b0cd1373498e64fa749 100644 --- a/core/src/test/scala/spark/MesosSchedulerSuite.scala +++ b/core/src/test/scala/spark/MesosSchedulerSuite.scala @@ -2,6 +2,8 @@ package spark import org.scalatest.FunSuite +import spark.scheduler.mesos.MesosScheduler + class MesosSchedulerSuite extends FunSuite { test("memoryStringToMb"){ diff --git a/core/src/test/scala/spark/UtilsSuite.scala b/core/src/test/scala/spark/UtilsSuite.scala index f31251e509a9c14460a573f7584f42d206362e4e..1ac4737f046d35294a89e7165692fe10f809c966 100644 --- a/core/src/test/scala/spark/UtilsSuite.scala +++ b/core/src/test/scala/spark/UtilsSuite.scala @@ -2,7 +2,7 @@ package spark import org.scalatest.FunSuite import java.io.{ByteArrayOutputStream, ByteArrayInputStream} -import util.Random +import scala.util.Random class UtilsSuite extends FunSuite { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 08c5a990b489ad3da43c60a73d61d7d3c5e48947..a2faf7399c44225a3df71f13d1fe330674a29a39 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -33,6 +33,7 @@ object SparkBuild extends Build { "org.scalatest" %% "scalatest" % "1.6.1" % "test", "org.scala-tools.testing" %% "scalacheck" % "1.9" % "test" ), + parallelExecution in Test := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) } @@ -57,8 +58,12 @@ object SparkBuild extends Build { "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.9", + "se.scalablesolutions.akka" % "akka-actor" % "1.3.1", + "se.scalablesolutions.akka" % "akka-remote" % "1.3.1", + "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1", "org.jboss.netty" % "netty" % "3.2.6.Final", - "it.unimi.dsi" % "fastutil" % "6.4.2" + "it.unimi.dsi" % "fastutil" % "6.4.4", + "colt" % "colt" % "1.2.0" ) ) ++ assemblySettings ++ Seq(test in assembly := {}) @@ -68,8 +73,7 @@ object SparkBuild extends Build { ) ++ assemblySettings ++ Seq(test in assembly := {}) def examplesSettings = sharedSettings ++ Seq( - name := "spark-examples", - libraryDependencies += "colt" % "colt" % "1.2.0" + name := "spark-examples" ) def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
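
The StorageLevel class added above packs its three boolean flags into one byte (useDisk = 4, useMemory = 2, deserialized = 1) and writes the replication count as a second byte. The sketch below round-trips a level through that Externalizable encoding, assuming only the class as defined in the diff; the object name and the use of plain Java object streams here are illustrative.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import spark.storage.StorageLevel

object StorageLevelRoundTrip {
  def main(args: Array[String]) {
    val level = StorageLevel.MEMORY_ONLY_DESER_2        // toInt == 3, replication == 2

    // writeObject delegates to writeExternal: one flag byte followed by one replication byte
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(level)
    out.close()

    // readObject uses the no-arg constructor and then readExternal
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val restored = in.readObject().asInstanceOf[StorageLevel]
    in.close()

    assert(restored == level)   // equals compares useDisk, useMemory, deserialized and replication
    println(restored)           // StorageLevel(false, true, true, 2)
  }
}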
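
ByteBufferInputStream exposes a ByteBuffer through the standard java.io.InputStream interface. A small usage sketch follows; the object name and sample bytes are illustrative.

import java.nio.ByteBuffer

import spark.util.ByteBufferInputStream

object ByteBufferInputStreamDemo {
  def main(args: Array[String]) {
    val in = new ByteBufferInputStream(ByteBuffer.wrap("spark".getBytes("UTF-8")))

    val dest = new Array[Byte](3)
    val n = in.read(dest)                      // fills at most dest.length bytes
    println(new String(dest, 0, n, "UTF-8"))   // spa

    in.skip(1)                                 // skips 'r'
    println(in.read().toChar)                  // k
    println(in.read())                         // -1 once the buffer is exhausted
  }
}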
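
StatCounter tracks count, mean and variance with Welford/Chan-style updates, so counters built independently (for example one per partition) can be merged without losing numerical stability. A minimal usage sketch relying only on the API in the diff; the sample values and object name are illustrative.

import spark.util.StatCounter

object StatCounterDemo {
  def main(args: Array[String]) {
    // Build one counter per group of values, then merge them into a single summary.
    val merged = StatCounter(Seq(1.0, 2.0, 3.0)).merge(StatCounter(Seq(10.0, 20.0)))

    println(merged)                  // (count: 5, mean: 7.200000, stdev: 7.138627)
    println(merged.count)            // 5
    println(merged.sum)              // ~36.0
    println(merged.sampleVariance)   // ~63.7 (uses the n - 1 denominator)
  }
}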