From 3920189932e95f78f84ab400e3e779c1601f90f1 Mon Sep 17 00:00:00 2001 From: Matei Zaharia <matei@eecs.berkeley.edu> Date: Thu, 28 Jun 2012 23:51:28 -0700 Subject: [PATCH] Upgraded to Akka 2 and fixed test execution (which was still parallel across projects). --- core/src/main/scala/spark/CacheTracker.scala | 116 +++++---- .../main/scala/spark/MapOutputTracker.scala | 50 +++- core/src/main/scala/spark/SparkContext.scala | 13 +- core/src/main/scala/spark/SparkEnv.scala | 40 +++- .../src/main/scala/spark/SparkException.scala | 6 +- .../mesos/CoarseMesosScheduler.scala | 42 ++-- .../scala/spark/storage/BlockManager.scala | 7 +- .../spark/storage/BlockManagerMaster.scala | 224 ++++++++++-------- .../test/scala/spark/CacheTrackerSuite.scala | 106 +++++---- .../spark/storage/BlockManagerSuite.scala | 23 +- project/SparkBuild.scala | 8 +- 11 files changed, 364 insertions(+), 271 deletions(-) diff --git a/core/src/main/scala/spark/CacheTracker.scala b/core/src/main/scala/spark/CacheTracker.scala index 64b4af0ae2..65e3803144 100644 --- a/core/src/main/scala/spark/CacheTracker.scala +++ b/core/src/main/scala/spark/CacheTracker.scala @@ -1,8 +1,11 @@ package spark import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ +import akka.dispatch._ +import akka.pattern.ask +import akka.remote._ +import akka.util.Duration +import akka.util.Timeout import akka.util.duration._ import scala.collection.mutable.ArrayBuffer @@ -44,12 +47,12 @@ class CacheTrackerActor extends Actor with Logging { Utils.memoryBytesToString(size), host)) slaveCapacity.put(host, size) slaveUsage.put(host, 0) - self.reply(true) + sender ! true case RegisterRDD(rddId: Int, numPartitions: Int) => logInfo("Registering RDD " + rddId + " with " + numPartitions + " partitions") locs(rddId) = Array.fill[List[String]](numPartitions)(Nil) - self.reply(true) + sender ! true case AddedToCache(rddId, partition, host, size) => slaveUsage.put(host, getCacheUsage(host) + size) @@ -57,7 +60,7 @@ class CacheTrackerActor extends Actor with Logging { rddId, partition, host, Utils.memoryBytesToString(size), Utils.memoryBytesToString(getCacheAvailable(host)))) locs(rddId)(partition) = host :: locs(rddId)(partition) - self.reply(true) + sender ! true case DroppedFromCache(rddId, partition, host, size) => logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format( @@ -70,7 +73,7 @@ class CacheTrackerActor extends Actor with Logging { logError("Cache usage on %s is negative (%d)".format(host, usage)) } locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host) - self.reply(true) + sender ! true case MemoryCacheLost(host) => logInfo("Memory cache lost on " + host) @@ -79,48 +82,67 @@ class CacheTrackerActor extends Actor with Logging { locations(i) = locations(i).filterNot(_ == host) } } - self.reply(true) + sender ! true case GetCacheLocations => logInfo("Asked for current cache locations") - self.reply(locs.map{case (rrdId, array) => (rrdId -> array.clone())}) + sender ! locs.map{case (rrdId, array) => (rrdId -> array.clone())} case GetCacheStatus => val status = slaveCapacity.map { case (host, capacity) => (host, capacity, getCacheUsage(host)) }.toSeq - self.reply(status) + sender ! status case StopCacheTracker => - logInfo("CacheTrackerActor Server stopped!") - self.reply(true) - self.exit() + logInfo("Stopping CacheTrackerActor") + sender ! true + context.stop(self) } } -class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Logging { +class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager) + extends Logging { + // Tracker actor on the master, or remote reference to it on workers val ip: String = System.getProperty("spark.master.host", "localhost") val port: Int = System.getProperty("spark.master.port", "7077").toInt - val aName: String = "CacheTracker" - - if (isMaster) { - } + val actorName: String = "CacheTracker" + + val timeout = 10.seconds var trackerActor: ActorRef = if (isMaster) { - val actor = actorOf(new CacheTrackerActor) - remote.register(aName, actor) - actor.start() - logInfo("Registered CacheTrackerActor actor @ " + ip + ":" + port) + val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName) + logInfo("Registered CacheTrackerActor actor") actor } else { - remote.actorFor(aName, ip, port) + val url = "akka://spark@%s:%s/%s".format(ip, port, actorName) + actorSystem.actorFor(url) } val registeredRddIds = new HashSet[Int] // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[String] + + // Send a message to the trackerActor and get its result within a default timeout, or + // throw a SparkException if this fails. + def askTracker(message: Any): Any = { + try { + val future = trackerActor.ask(message)(timeout) + return Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error communicating with CacheTracker", e) + } + } + + // Send a one-way message to the trackerActor, to which we expect it to reply with true. + def communicate(message: Any) { + if (askTracker(message) != true) { + throw new SparkException("Error reply received from CacheTracker") + } + } // Registers an RDD (on master only) def registerRDD(rddId: Int, numPartitions: Int) { @@ -128,62 +150,33 @@ class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Loggin if (!registeredRddIds.contains(rddId)) { logInfo("Registering RDD ID " + rddId + " with cache") registeredRddIds += rddId - (trackerActor ? RegisterRDD(rddId, numPartitions)).as[Any] match { - case Some(true) => - logInfo("CacheTracker registerRDD " + RegisterRDD(rddId, numPartitions) + " successfully.") - case Some(oops) => - logError("CacheTracker registerRDD" + RegisterRDD(rddId, numPartitions) + " failed: " + oops) - case None => - logError("CacheTracker registerRDD None. " + RegisterRDD(rddId, numPartitions)) - throw new SparkException("Internal error: CacheTracker registerRDD None.") - } + communicate(RegisterRDD(rddId, numPartitions)) + logInfo(RegisterRDD(rddId, numPartitions) + " successful") } } } // For BlockManager.scala only def cacheLost(host: String) { - (trackerActor ? MemoryCacheLost(host)).as[Any] match { - case Some(true) => - logInfo("CacheTracker successfully removed entries on " + host) - case _ => - logError("CacheTracker did not reply to MemoryCacheLost") - } + communicate(MemoryCacheLost(host)) + logInfo("CacheTracker successfully removed entries on " + host) } // Get the usage status of slave caches. Each tuple in the returned sequence // is in the form of (host name, capacity, usage). def getCacheStatus(): Seq[(String, Long, Long)] = { - (trackerActor ? GetCacheStatus) match { - case h: Seq[(String, Long, Long)] => h.asInstanceOf[Seq[(String, Long, Long)]] - - case _ => - throw new SparkException( - "Internal error: CacheTrackerActor did not reply with a Seq[Tuple3[String, Long, Long]") - } + askTracker(GetCacheStatus).asInstanceOf[Seq[(String, Long, Long)]] } // For BlockManager.scala only def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) { - (trackerActor ? t).as[Any] match { - case Some(true) => - logInfo("CacheTracker notifyTheCacheTrackerFromBlockManager successfully.") - case Some(oops) => - logError("CacheTracker notifyTheCacheTrackerFromBlockManager failed: " + oops) - case None => - logError("CacheTracker notifyTheCacheTrackerFromBlockManager None.") - } + communicate(t) + logInfo("notifyTheCacheTrackerFromBlockManager successful") } // Get a snapshot of the currently known locations def getLocationsSnapshot(): HashMap[Int, Array[List[String]]] = { - (trackerActor ? GetCacheLocations).as[Any] match { - case Some(h: HashMap[_, _]) => - h.asInstanceOf[HashMap[Int, Array[List[String]]]] - - case _ => - throw new SparkException("Internal error: CacheTrackerActor did not reply with a HashMap") - } + askTracker(GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] } // Gets or computes an RDD split @@ -245,12 +238,11 @@ class CacheTracker(isMaster: Boolean, blockManager: BlockManager) extends Loggin // Called by the Cache to report that an entry has been dropped from it def dropEntry(rddId: Int, partition: Int) { - //TODO - do we really want to use '!!' when nobody checks returned future? '!' seems to enough here. - trackerActor !! DroppedFromCache(rddId, partition, Utils.localHostName()) + communicate(DroppedFromCache(rddId, partition, Utils.localHostName())) } def stop() { - trackerActor !! StopCacheTracker + communicate(StopCacheTracker) registeredRddIds.clear() trackerActor = null } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index d938a6eb62..d18ecb921d 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -3,8 +3,11 @@ package spark import java.util.concurrent.ConcurrentHashMap import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ +import akka.dispatch._ +import akka.pattern.ask +import akka.remote._ +import akka.util.Duration +import akka.util.Timeout import akka.util.duration._ import scala.collection.mutable.HashSet @@ -20,19 +23,21 @@ extends Actor with Logging { def receive = { case GetMapOutputLocations(shuffleId: Int) => logInfo("Asked to get map output locations for shuffle " + shuffleId) - self.reply(bmAddresses.get(shuffleId)) + sender ! bmAddresses.get(shuffleId) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") - self.reply(true) - self.exit() + sender ! true + context.stop(self) } } -class MapOutputTracker(isMaster: Boolean) extends Logging { +class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging { val ip: String = System.getProperty("spark.master.host", "localhost") val port: Int = System.getProperty("spark.master.port", "7077").toInt - val aName: String = "MapOutputTracker" + val actorName: String = "MapOutputTracker" + + val timeout = 10.seconds private var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]] @@ -42,12 +47,31 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { private var generationLock = new java.lang.Object var trackerActor: ActorRef = if (isMaster) { - val actor = actorOf(new MapOutputTrackerActor(bmAddresses)) - remote.register(aName, actor) - logInfo("Registered MapOutputTrackerActor actor @ " + ip + ":" + port) + val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(bmAddresses)), name = actorName) + logInfo("Registered MapOutputTrackerActor actor") actor } else { - remote.actorFor(aName, ip, port) + val url = "akka://spark@%s:%s/%s".format(ip, port, actorName) + actorSystem.actorFor(url) + } + + // Send a message to the trackerActor and get its result within a default timeout, or + // throw a SparkException if this fails. + def askTracker(message: Any): Any = { + try { + val future = trackerActor.ask(message)(timeout) + return Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error communicating with MapOutputTracker", e) + } + } + + // Send a one-way message to the trackerActor, to which we expect it to reply with true. + def communicate(message: Any) { + if (askTracker(message) != true) { + throw new SparkException("Error reply received from MapOutputTracker") + } } def registerShuffle(shuffleId: Int, numMaps: Int) { @@ -110,7 +134,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { } // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val fetched = (trackerActor ? GetMapOutputLocations(shuffleId)).as[Array[BlockManagerId]].get + val fetched = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[BlockManagerId]] logInfo("Got the output locations") bmAddresses.put(shuffleId, fetched) @@ -125,7 +149,7 @@ class MapOutputTracker(isMaster: Boolean) extends Logging { } def stop() { - trackerActor !! StopMapOutputTracker + communicate(StopMapOutputTracker) bmAddresses.clear() trackerActor = null } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 8bb60b9845..0272040080 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -6,7 +6,6 @@ import java.util.concurrent.atomic.AtomicInteger import akka.actor.Actor import akka.actor.Actor._ -import scala.actors.remote.RemoteActor import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path @@ -64,14 +63,6 @@ class SparkContext( System.setProperty("spark.master.port", "7077") } - // Make sure a proper class loader is set for remote actors (unless user set one) - if (RemoteActor.classLoader == null) { - RemoteActor.classLoader = getClass.getClassLoader - } - - remote.start(System.getProperty("spark.master.host"), - System.getProperty("spark.master.port").toInt) - private val isLocal = master.startsWith("local") // TODO: better check for local // Create the Spark execution environment (cache, map output tracker, etc) @@ -260,7 +251,6 @@ class SparkContext( // Stop the SparkContext def stop() { - remote.shutdownServerModule() dagScheduler.stop() dagScheduler = null taskScheduler = null @@ -271,8 +261,11 @@ class SparkContext( env.shuffleManager.stop() env.blockManager.stop() BlockManagerMaster.stopBlockManagerMaster() + env.actorSystem.shutdown() + env.actorSystem.awaitTermination() SparkEnv.set(null) ShuffleMapTask.clearCache() + logInfo("Successfully stopped SparkContext") } // Wait for the scheduler to be registered with the cluster manager diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala index 897a5ef82d..974cb5f401 100644 --- a/core/src/main/scala/spark/SparkEnv.scala +++ b/core/src/main/scala/spark/SparkEnv.scala @@ -1,12 +1,15 @@ package spark -import akka.actor.Actor +import akka.actor.ActorSystem + +import com.typesafe.config.ConfigFactory import spark.storage.BlockManager import spark.storage.BlockManagerMaster import spark.network.ConnectionManager class SparkEnv ( + val actorSystem: ActorSystem, val cache: Cache, val serializer: Serializer, val closureSerializer: Serializer, @@ -19,7 +22,7 @@ class SparkEnv ( ) { /** No-parameter constructor for unit tests. */ - def this() = this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null) + def this() = this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null) } object SparkEnv { @@ -34,10 +37,24 @@ object SparkEnv { } def createFromSystemProperties(isMaster: Boolean, isLocal: Boolean): SparkEnv = { + val host = System.getProperty("spark.master.host") + val port = System.getProperty("spark.master.port").toInt + if (port == 0) { + throw new IllegalArgumentException("Setting spark.master.port to 0 is not yet supported") + } + val akkaConf = ConfigFactory.parseString(""" + akka.daemonic = on + akka.actor.provider = "akka.remote.RemoteActorRefProvider" + akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" + akka.remote.netty.hostname = "%s" + akka.remote.netty.port = %d + """.format(host, port)) + val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) + val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer") val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer] - BlockManagerMaster.startBlockManagerMaster(isMaster, isLocal) + BlockManagerMaster.startBlockManagerMaster(actorSystem, isMaster, isLocal) var blockManager = new BlockManager(serializer) @@ -52,10 +69,10 @@ object SparkEnv { val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache") val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache] - val cacheTracker = new CacheTracker(isMaster, blockManager) + val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager) blockManager.cacheTracker = cacheTracker - val mapOutputTracker = new MapOutputTracker(isMaster) + val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster) val shuffleFetcherClass = System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher") @@ -81,7 +98,16 @@ object SparkEnv { } */ - new SparkEnv(cache, serializer, closureSerializer, cacheTracker, mapOutputTracker, shuffleFetcher, - shuffleManager, blockManager, connectionManager) + new SparkEnv( + actorSystem, + cache, + serializer, + closureSerializer, + cacheTracker, + mapOutputTracker, + shuffleFetcher, + shuffleManager, + blockManager, + connectionManager) } } diff --git a/core/src/main/scala/spark/SparkException.scala b/core/src/main/scala/spark/SparkException.scala index 6f9be1a94f..aa7a16d7dd 100644 --- a/core/src/main/scala/spark/SparkException.scala +++ b/core/src/main/scala/spark/SparkException.scala @@ -1,3 +1,7 @@ package spark -class SparkException(message: String) extends Exception(message) {} +class SparkException(message: String, cause: Throwable) + extends Exception(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala index 8182901ce3..525cf9747f 100644 --- a/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala +++ b/core/src/main/scala/spark/scheduler/mesos/CoarseMesosScheduler.scala @@ -15,10 +15,12 @@ import scala.collection.JavaConversions._ import scala.math.Ordering import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ -import akka.actor.Channel -import akka.serialization.RemoteActorSerialization._ +import akka.dispatch._ +import akka.pattern.ask +import akka.remote._ +import akka.util.Duration +import akka.util.Timeout +import akka.util.duration._ import com.google.protobuf.ByteString @@ -30,7 +32,7 @@ import spark._ import spark.scheduler._ sealed trait CoarseMesosSchedulerMessage -case class RegisterSlave(slaveId: String, host: String, port: Int) extends CoarseMesosSchedulerMessage +case class RegisterSlave(slaveId: String, host: String) extends CoarseMesosSchedulerMessage case class StatusUpdate(slaveId: String, status: TaskStatus) extends CoarseMesosSchedulerMessage case class LaunchTask(slaveId: String, task: MTaskInfo) extends CoarseMesosSchedulerMessage case class ReviveOffers() extends CoarseMesosSchedulerMessage @@ -50,7 +52,9 @@ class CoarseMesosScheduler( frameworkName: String) extends MesosScheduler(sc, master, frameworkName) { - val CORES_PER_SLAVE = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt + val actorSystem = sc.env.actorSystem + val actorName = "CoarseMesosScheduler" + val coresPerSlave = System.getProperty("spark.coarseMesosScheduler.coresPerSlave", "4").toInt class MasterActor extends Actor { val slaveActor = new HashMap[String, ActorRef] @@ -58,11 +62,11 @@ class CoarseMesosScheduler( val freeCores = new HashMap[String, Int] def receive = { - case RegisterSlave(slaveId, host, port) => - slaveActor(slaveId) = remote.actorFor("WorkerActor", host, port) - logInfo("Slave actor: " + slaveActor(slaveId)) + case RegisterSlave(slaveId, host) => + slaveActor(slaveId) = sender + logInfo("Slave actor: " + sender) slaveHost(slaveId) = host - freeCores(slaveId) = CORES_PER_SLAVE + freeCores(slaveId) = coresPerSlave makeFakeOffers() case StatusUpdate(slaveId, status) => @@ -92,9 +96,7 @@ class CoarseMesosScheduler( } } - val masterActor: ActorRef = actorOf(new MasterActor) - remote.register("MasterActor", masterActor) - masterActor.start() + val masterActor: ActorRef = actorSystem.actorOf(Props[MasterActor], name = actorName) val taskIdsOnSlave = new HashMap[String, HashSet[String]] @@ -282,12 +284,8 @@ class WorkerTask(slaveId: String, host: String) extends Task[Unit](-1) { generation = 0 def run(id: Int): Unit = { - val actor = actorOf(new WorkerActor(slaveId, host)) - if (!remote.isRunning) { - remote.start(Utils.localIpAddress, 7078) - } - remote.register("WorkerActor", actor) - actor.start() + val actorSystem = SparkEnv.get.actorSystem + val actor = actorSystem.actorOf(Props(new WorkerActor(slaveId, host)), name = "WorkerActor") while (true) { Thread.sleep(10000) } @@ -302,7 +300,8 @@ class WorkerActor(slaveId: String, host: String) extends Actor with Logging { val masterIp: String = System.getProperty("spark.master.host", "localhost") val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt - val masterActor = remote.actorFor("MasterActor", masterIp, masterPort) + val masterActor = env.actorSystem.actorFor( + "akka://spark@%s:%s/%s".format(masterIp, masterPort, "CoarseMesosScheduler")) class TaskRunner(desc: MTaskInfo) extends Runnable { @@ -352,9 +351,8 @@ class WorkerActor(slaveId: String, host: String) extends Actor with Logging { } override def preStart { - val ref = toRemoteActorRefProtocol(self).toByteArray logInfo("Registering with master") - masterActor ! RegisterSlave(slaveId, host, remote.address.getPort) + masterActor ! RegisterSlave(slaveId, host) } override def receive = { diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala index 9e4816f7ce..0a807f0582 100644 --- a/core/src/main/scala/spark/storage/BlockManager.scala +++ b/core/src/main/scala/spark/storage/BlockManager.scala @@ -154,8 +154,8 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging */ def getLocations(blockId: String): Seq[String] = { val startTimeMs = System.currentTimeMillis - var managers: Array[BlockManagerId] = BlockManagerMaster.mustGetLocations(GetLocations(blockId)) - val locations = managers.map((manager: BlockManagerId) => { manager.ip }).toSeq + var managers = BlockManagerMaster.mustGetLocations(GetLocations(blockId)) + val locations = managers.map(_.ip) logDebug("Get block locations in " + Utils.getUsedTimeMs(startTimeMs)) return locations } @@ -490,8 +490,7 @@ class BlockManager(maxMemory: Long, val serializer: Serializer) extends Logging private def replicate(blockId: String, data: ByteBuffer, level: StorageLevel) { val tLevel: StorageLevel = new StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) - var peers: Array[BlockManagerId] = BlockManagerMaster.mustGetPeers( - GetPeers(blockManagerId, level.replication - 1)) + var peers = BlockManagerMaster.mustGetPeers(GetPeers(blockManagerId, level.replication - 1)) for (peer: BlockManagerId <- peers) { val start = System.nanoTime logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is " diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala index d8400a1f65..5fe0e22dd0 100644 --- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala @@ -9,11 +9,15 @@ import scala.collection.mutable.HashSet import scala.util.Random import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ +import akka.dispatch._ +import akka.pattern.ask +import akka.remote._ +import akka.util.Duration +import akka.util.Timeout import akka.util.duration._ import spark.Logging +import spark.SparkException import spark.Utils sealed trait ToBlockManagerMaster @@ -70,22 +74,15 @@ object HeartBeat { } } -case class GetLocations( - blockId: String) - extends ToBlockManagerMaster +case class GetLocations(blockId: String) extends ToBlockManagerMaster -case class GetLocationsMultipleBlockIds( - blockIds: Array[String]) - extends ToBlockManagerMaster +case class GetLocationsMultipleBlockIds(blockIds: Array[String]) extends ToBlockManagerMaster -case class GetPeers( - blockManagerId: BlockManagerId, - size: Int) - extends ToBlockManagerMaster +case class GetPeers(blockManagerId: BlockManagerId, size: Int) extends ToBlockManagerMaster -case class RemoveHost( - host: String) - extends ToBlockManagerMaster +case class RemoveHost(host: String) extends ToBlockManagerMaster + +case object StopBlockManagerMaster extends ToBlockManagerMaster class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { @@ -170,7 +167,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { val port = host.split(":")(1) blockManagerInfo.remove(new BlockManagerId(ip, port.toInt)) logInfo("Current hosts: " + blockManagerInfo.keySet.toSeq) - self.reply(true) + sender ! true } def receive = { @@ -187,14 +184,20 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { getLocationsMultipleBlockIds(blockIds) case GetPeers(blockManagerId, size) => - getPeers_Deterministic(blockManagerId, size) + getPeersDeterministic(blockManagerId, size) /*getPeers(blockManagerId, size)*/ case RemoveHost(host) => removeHost(host) + sender ! true - case msg => - logInfo("Got unknown msg: " + msg) + case StopBlockManagerMaster => + logInfo("Stopping BlockManagerMaster") + sender ! true + context.stop(self) + + case other => + logInfo("Got unknown message: " + other) } private def register(blockManagerId: BlockManagerId, maxMemSize: Long, maxDiskSize: Long) { @@ -209,7 +212,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { System.currentTimeMillis() / 1000, maxMemSize, maxDiskSize)) } logDebug("Got in register 1" + tmp + Utils.getUsedTimeMs(startTimeMs)) - self.reply(true) + sender ! true } private def heartBeat( @@ -225,7 +228,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { if (blockId == null) { blockManagerInfo(blockManagerId).updateLastSeenMs() logDebug("Got in heartBeat 1" + tmp + " used " + Utils.getUsedTimeMs(startTimeMs)) - self.reply(true) + sender ! true } blockManagerInfo(blockManagerId).addBlock(blockId, storageLevel, deserializedSize, size) @@ -247,7 +250,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { if (locations.size == 0) { blockInfo.remove(blockId) } - self.reply(true) + sender ! true } private def getLocations(blockId: String) { @@ -259,11 +262,11 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { res.appendAll(blockInfo.get(blockId)._2) logDebug("Got in getLocations 1" + tmp + " as "+ res.toSeq + " at " + Utils.getUsedTimeMs(startTimeMs)) - self.reply(res.toSeq) + sender ! res.toSeq } else { logDebug("Got in getLocations 2" + tmp + Utils.getUsedTimeMs(startTimeMs)) var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - self.reply(res) + sender ! res } } @@ -289,7 +292,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { res.append(getLocations(blockId)) } logDebug("Got in getLocationsMultipleBlockIds " + blockIds.toSeq + " : " + res.toSeq) - self.reply(res.toSeq) + sender ! res.toSeq } private def getPeers(blockManagerId: BlockManagerId, size: Int) { @@ -301,10 +304,10 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { while (res.length > size) { res.remove(rand.nextInt(res.length)) } - self.reply(res.toSeq) + sender ! res.toSeq } - private def getPeers_Deterministic(blockManagerId: BlockManagerId, size: Int) { + private def getPeersDeterministic(blockManagerId: BlockManagerId, size: Int) { var peers: Array[BlockManagerId] = blockManagerInfo.keySet.toArray var res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] @@ -322,8 +325,7 @@ class BlockManagerMaster(val isLocal: Boolean) extends Actor with Logging { } res += peers(index % peers.size) } - val resStr = res.map(_.toString).reduceLeft(_ + ", " + _) - self.reply(res.toSeq) + sender ! res.toSeq } } @@ -337,37 +339,51 @@ object BlockManagerMaster extends Logging { val DEFAULT_MANAGER_IP: String = Utils.localHostName() val DEFAULT_MANAGER_PORT: String = "10902" - implicit val TIME_OUT_SEC = Actor.Timeout(3000 millis) + val timeout = 10.seconds var masterActor: ActorRef = null - def startBlockManagerMaster(isMaster: Boolean, isLocal: Boolean) { + def startBlockManagerMaster(actorSystem: ActorSystem, isMaster: Boolean, isLocal: Boolean) { if (isMaster) { - masterActor = actorOf(new BlockManagerMaster(isLocal)) - remote.register(AKKA_ACTOR_NAME, masterActor) - logInfo("Registered BlockManagerMaster Actor: " + DEFAULT_MASTER_IP + ":" + DEFAULT_MASTER_PORT) - masterActor.start() + masterActor = actorSystem.actorOf( + Props(new BlockManagerMaster(isLocal)), name = AKKA_ACTOR_NAME) + logInfo("Registered BlockManagerMaster Actor") } else { - masterActor = remote.actorFor(AKKA_ACTOR_NAME, DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT) + val url = "akka://spark@%s:%s/%s".format( + DEFAULT_MASTER_IP, DEFAULT_MASTER_PORT, AKKA_ACTOR_NAME) + masterActor = actorSystem.actorFor(url) } } def stopBlockManagerMaster() { if (masterActor != null) { - masterActor.stop() + communicate(StopBlockManagerMaster) masterActor = null logInfo("BlockManagerMaster stopped") } } + + // Send a message to the master actor and get its result within a default timeout, or + // throw a SparkException if this fails. + def askMaster(message: Any): Any = { + try { + val future = masterActor.ask(message)(timeout) + return Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error communicating with BlockManagerMaster", e) + } + } + + // Send a one-way message to the master actor, to which we expect it to reply with true. + def communicate(message: Any) { + if (askMaster(message) != true) { + throw new SparkException("Error reply received from BlockManagerMaster") + } + } def notifyADeadHost(host: String) { - (masterActor ? RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)).as[Any] match { - case Some(true) => - logInfo("Removed " + host + " successfully. @ notifyADeadHost") - case Some(oops) => - logError("Failed @ notifyADeadHost: " + oops) - case None => - logError("None @ notifyADeadHost.") - } + communicate(RemoveHost(host + ":" + DEFAULT_MANAGER_PORT)) + logInfo("Removed " + host + " successfully in notifyADeadHost") } def mustRegisterBlockManager(msg: RegisterBlockManager) { @@ -383,16 +399,14 @@ object BlockManagerMaster extends Logging { val tmp = " msg " + msg + " " logDebug("Got in syncRegisterBlockManager 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - (masterActor ? msg).as[Any] match { - case Some(true) => - logInfo("BlockManager registered successfully @ syncRegisterBlockManager.") - logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return true - case Some(oops) => - logError("Failed @ syncRegisterBlockManager: " + oops) - return false - case None => - logError("None @ syncRegisterBlockManager.") + try { + communicate(msg) + logInfo("BlockManager registered successfully @ syncRegisterBlockManager") + logDebug("Got in syncRegisterBlockManager 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) + return true + } catch { + case e: Exception => + logError("Failed in syncRegisterBlockManager", e) return false } } @@ -409,22 +423,20 @@ object BlockManagerMaster extends Logging { val tmp = " msg " + msg + " " logDebug("Got in syncHeartBeat " + tmp + " 0 " + Utils.getUsedTimeMs(startTimeMs)) - (masterActor ? msg).as[Any] match { - case Some(true) => - logInfo("Heartbeat sent successfully.") - logDebug("Got in syncHeartBeat " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) - return true - case Some(oops) => - logError("Failed: " + oops) - return false - case None => - logError("None.") + try { + communicate(msg) + logInfo("Heartbeat sent successfully") + logDebug("Got in syncHeartBeat 1 " + tmp + " 1 " + Utils.getUsedTimeMs(startTimeMs)) + return true + } catch { + case e: Exception => + logError("Failed in syncHeartBeat", e) return false } } - def mustGetLocations(msg: GetLocations): Array[BlockManagerId] = { - var res: Array[BlockManagerId] = syncGetLocations(msg) + def mustGetLocations(msg: GetLocations): Seq[BlockManagerId] = { + var res = syncGetLocations(msg) while (res == null) { logInfo("Failed to get locations " + msg) Thread.sleep(REQUEST_RETRY_INTERVAL_MS) @@ -433,23 +445,24 @@ object BlockManagerMaster extends Logging { return res } - def syncGetLocations(msg: GetLocations): Array[BlockManagerId] = { + def syncGetLocations(msg: GetLocations): Seq[BlockManagerId] = { val startTimeMs = System.currentTimeMillis() val tmp = " msg " + msg + " " logDebug("Got in syncGetLocations 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? msg).as[Seq[BlockManagerId]] match { - case Some(arr) => - logDebug("GetLocations successfully.") + + try { + val answer = askMaster(msg).asInstanceOf[ArrayBuffer[BlockManagerId]] + if (answer != null) { + logDebug("GetLocations successful") logDebug("Got in syncGetLocations 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - for (ele <- arr) { - res += ele - } - logDebug("Got in syncGetLocations 2 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return res.toArray - case None => - logError("GetLocations call returned None.") + return answer + } else { + logError("Master replied null in response to GetLocations") + return null + } + } catch { + case e: Exception => + logError("GetLocations failed", e) return null } } @@ -471,22 +484,26 @@ object BlockManagerMaster extends Logging { val tmp = " msg " + msg + " " logDebug("Got in syncGetLocationsMultipleBlockIds 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - (masterActor ? msg).as[Any] match { - case Some(arr: Seq[Seq[BlockManagerId]]) => - logDebug("GetLocationsMultipleBlockIds successfully: " + arr) - logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return arr - case Some(oops) => - logError("Failed: " + oops) + try { + val answer = askMaster(msg).asInstanceOf[Seq[Seq[BlockManagerId]]] + if (answer != null) { + logDebug("GetLocationsMultipleBlockIds successful") + logDebug("Got in syncGetLocationsMultipleBlockIds 1 " + tmp + + Utils.getUsedTimeMs(startTimeMs)) + return answer + } else { + logError("Master replied null in response to GetLocationsMultipleBlockIds") return null - case None => - logInfo("None.") + } + } catch { + case e: Exception => + logError("GetLocationsMultipleBlockIds failed", e) return null } } - def mustGetPeers(msg: GetPeers): Array[BlockManagerId] = { - var res: Array[BlockManagerId] = syncGetPeers(msg) + def mustGetPeers(msg: GetPeers): Seq[BlockManagerId] = { + var res = syncGetPeers(msg) while ((res == null) || (res.length != msg.size)) { logInfo("Failed to get peers " + msg) Thread.sleep(REQUEST_RETRY_INTERVAL_MS) @@ -496,21 +513,24 @@ object BlockManagerMaster extends Logging { return res } - def syncGetPeers(msg: GetPeers): Array[BlockManagerId] = { + def syncGetPeers(msg: GetPeers): Seq[BlockManagerId] = { val startTimeMs = System.currentTimeMillis val tmp = " msg " + msg + " " logDebug("Got in syncGetPeers 0 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - - (masterActor ? msg).as[Seq[BlockManagerId]] match { - case Some(arr) => + + try { + val answer = askMaster(msg).asInstanceOf[Seq[BlockManagerId]] + if (answer != null) { + logDebug("GetPeers successful") logDebug("Got in syncGetPeers 1 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - val res: ArrayBuffer[BlockManagerId] = new ArrayBuffer[BlockManagerId] - logInfo("GetPeers successfully: " + arr.length) - res.appendAll(arr) - logDebug("Got in syncGetPeers 2 " + tmp + Utils.getUsedTimeMs(startTimeMs)) - return res.toArray - case None => - logError("GetPeers call returned None.") + return answer + } else { + logError("Master replied null in response to GetPeers") + return null + } + } catch { + case e: Exception => + logError("GetPeers failed", e) return null } } diff --git a/core/src/test/scala/spark/CacheTrackerSuite.scala b/core/src/test/scala/spark/CacheTrackerSuite.scala index 3d170a6e22..426c0d26e9 100644 --- a/core/src/test/scala/spark/CacheTrackerSuite.scala +++ b/core/src/test/scala/spark/CacheTrackerSuite.scala @@ -5,101 +5,127 @@ import org.scalatest.FunSuite import scala.collection.mutable.HashMap import akka.actor._ -import akka.actor.Actor -import akka.actor.Actor._ +import akka.dispatch._ +import akka.pattern.ask +import akka.remote._ +import akka.util.Duration +import akka.util.Timeout +import akka.util.duration._ class CacheTrackerSuite extends FunSuite { + // Send a message to an actor and wait for a reply, in a blocking manner + private def ask(actor: ActorRef, message: Any): Any = { + try { + val timeout = 10.seconds + val future = actor.ask(message)(timeout) + return Await.result(future, timeout) + } catch { + case e: Exception => + throw new SparkException("Error communicating with actor", e) + } + } test("CacheTrackerActor slave initialization & cache status") { //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) - tracker.start() + val actorSystem = ActorSystem("test") + val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - tracker !! SlaveCacheStarted("host001", initialSize) + assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 0L))) + assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 0L))) - tracker !! StopCacheTracker + assert(ask(tracker, StopCacheTracker) === true) + + actorSystem.shutdown() + actorSystem.awaitTermination() } test("RegisterRDD") { //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) - tracker.start() + val actorSystem = ActorSystem("test") + val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - tracker !! SlaveCacheStarted("host001", initialSize) + assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - tracker !! RegisterRDD(1, 3) - tracker !! RegisterRDD(2, 1) + assert(ask(tracker, RegisterRDD(1, 3)) === true) + assert(ask(tracker, RegisterRDD(2, 1)) === true) - assert(getCacheLocations(tracker) === Map(1 -> List(List(), List(), List()), 2 -> List(List()))) + assert(getCacheLocations(tracker) === Map(1 -> List(Nil, Nil, Nil), 2 -> List(Nil))) - tracker !! StopCacheTracker + assert(ask(tracker, StopCacheTracker) === true) + + actorSystem.shutdown() + actorSystem.awaitTermination() } test("AddedToCache") { //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) - tracker.start() + val actorSystem = ActorSystem("test") + val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - tracker !! SlaveCacheStarted("host001", initialSize) + assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - tracker !! RegisterRDD(1, 2) - tracker !! RegisterRDD(2, 1) + assert(ask(tracker, RegisterRDD(1, 2)) === true) + assert(ask(tracker, RegisterRDD(2, 1)) === true) - tracker !! AddedToCache(1, 0, "host001", 2L << 15) - tracker !! AddedToCache(1, 1, "host001", 2L << 11) - tracker !! AddedToCache(2, 0, "host001", 3L << 10) + assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) + assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) + assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L))) + assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) assert(getCacheLocations(tracker) === Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - tracker !! StopCacheTracker + assert(ask(tracker, StopCacheTracker) === true) + + actorSystem.shutdown() + actorSystem.awaitTermination() } test("DroppedFromCache") { //System.setProperty("spark.master.port", "1345") val initialSize = 2L << 20 - val tracker = actorOf(new CacheTrackerActor) - tracker.start() + val actorSystem = ActorSystem("test") + val tracker = actorSystem.actorOf(Props[CacheTrackerActor]) - tracker !! SlaveCacheStarted("host001", initialSize) + assert(ask(tracker, SlaveCacheStarted("host001", initialSize)) === true) - tracker !! RegisterRDD(1, 2) - tracker !! RegisterRDD(2, 1) + assert(ask(tracker, RegisterRDD(1, 2)) === true) + assert(ask(tracker, RegisterRDD(2, 1)) === true) - tracker !! AddedToCache(1, 0, "host001", 2L << 15) - tracker !! AddedToCache(1, 1, "host001", 2L << 11) - tracker !! AddedToCache(2, 0, "host001", 3L << 10) + assert(ask(tracker, AddedToCache(1, 0, "host001", 2L << 15)) === true) + assert(ask(tracker, AddedToCache(1, 1, "host001", 2L << 11)) === true) + assert(ask(tracker, AddedToCache(2, 0, "host001", 3L << 10)) === true) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 72704L))) + assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 72704L))) assert(getCacheLocations(tracker) === Map(1 -> List(List("host001"), List("host001")), 2 -> List(List("host001")))) - tracker !! DroppedFromCache(1, 1, "host001", 2L << 11) + assert(ask(tracker, DroppedFromCache(1, 1, "host001", 2L << 11)) === true) - assert((tracker ? GetCacheStatus).get === Seq(("host001", 2097152L, 68608L))) + assert(ask(tracker, GetCacheStatus) === Seq(("host001", 2097152L, 68608L))) assert(getCacheLocations(tracker) === Map(1 -> List(List("host001"),List()), 2 -> List(List("host001")))) - tracker !! StopCacheTracker + assert(ask(tracker, StopCacheTracker) === true) + + actorSystem.shutdown() + actorSystem.awaitTermination() } /** * Helper function to get cacheLocations from CacheTracker */ - def getCacheLocations(tracker: ActorRef) = (tracker ? GetCacheLocations).get match { - case h: HashMap[_, _] => h.asInstanceOf[HashMap[Int, Array[List[String]]]].map { - case (i, arr) => (i -> arr.toList) - } + def getCacheLocations(tracker: ActorRef): HashMap[Int, List[List[String]]] = { + val answer = ask(tracker, GetCacheLocations).asInstanceOf[HashMap[Int, Array[List[String]]]] + answer.map { case (i, arr) => (i, arr.toList) } } } diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala index 14ff5f8e3d..027d1423d4 100644 --- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala @@ -1,16 +1,27 @@ package spark.storage +import java.nio.ByteBuffer + +import akka.actor._ + +import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfterEach + import spark.KryoSerializer import spark.util.ByteBufferInputStream -import org.scalatest.FunSuite -import org.scalatest.BeforeAndAfter +class BlockManagerSuite extends FunSuite with BeforeAndAfterEach { + var actorSystem: ActorSystem = null -import java.nio.ByteBuffer + override def beforeEach() { + actorSystem = ActorSystem("test") + BlockManagerMaster.startBlockManagerMaster(actorSystem, true, true) + } -class BlockManagerSuite extends FunSuite with BeforeAndAfter{ - before { - BlockManagerMaster.startBlockManagerMaster(true, true) + override def afterEach() { + actorSystem.shutdown() + actorSystem.awaitTermination() + actorSystem = null } test("manager-master interaction") { diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 985de3cde3..10380e9397 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -33,7 +33,7 @@ object SparkBuild extends Build { "org.scalatest" %% "scalatest" % "1.6.1" % "test", "org.scala-tools.testing" %% "scalacheck" % "1.9" % "test" ), - parallelExecution in Test := false, + parallelExecution := false, /* Workaround for issue #206 (fixed after SBT 0.11.0) */ watchTransitiveSources <<= Defaults.inDependencies[Task[Seq[File]]](watchSources.task, const(std.TaskExtra.constant(Nil)), aggregate = true, includeRoot = true) apply { _.join.map(_.flatten) } @@ -58,9 +58,9 @@ object SparkBuild extends Build { "asm" % "asm-all" % "3.3.1", "com.google.protobuf" % "protobuf-java" % "2.4.1", "de.javakaffee" % "kryo-serializers" % "0.9", - "se.scalablesolutions.akka" % "akka-actor" % "1.3.1", - "se.scalablesolutions.akka" % "akka-remote" % "1.3.1", - "se.scalablesolutions.akka" % "akka-slf4j" % "1.3.1", + "com.typesafe.akka" % "akka-actor" % "2.0.2", + "com.typesafe.akka" % "akka-remote" % "2.0.2", + "com.typesafe.akka" % "akka-slf4j" % "2.0.2", "org.jboss.netty" % "netty" % "3.2.6.Final", "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0" -- GitLab