diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 1afb1870f1617924791426fa32ddb958cb903b06..6590e9779e09e6ad0758a41a6f344c1e02de3af3 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -60,7 +60,7 @@ private[spark] class MapOutputTracker extends Logging {
   private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
 
   // Set to the MapOutputTrackerActor living on the driver
-  var trackerActor: ActorRef = _
+  var trackerActor: Either[ActorRef, ActorSelection] = _
 
   private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
 
@@ -79,7 +79,11 @@ private[spark] class MapOutputTracker extends Logging {
   // throw a SparkException if this fails.
   def askTracker(message: Any): Any = {
     try {
-      val future = trackerActor.ask(message)(timeout)
+      val future = if (trackerActor.isLeft) {
+        trackerActor.left.get.ask(message)(timeout)
+      } else {
+        trackerActor.right.get.ask(message)(timeout)
+      }
       return Await.result(future, timeout)
     } catch {
       case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index a267407c673f0b54bb74ecb7fe9ea6bb2a76155c..0d9bd500e487f3fa70986cf70402e5ab6c46da9c 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,7 +20,7 @@ package org.apache.spark
 import collection.mutable
 import serializer.Serializer
 
-import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
+import akka.actor._
 import akka.remote.RemoteActorRefProvider
 
 import org.apache.spark.broadcast.BroadcastManager
@@ -161,17 +161,17 @@ object SparkEnv extends Logging {
     val closureSerializer = serializerManager.get(
       System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"))
 
-    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+    def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = {
       if (isDriver) {
         logInfo("Registering " + name)
-        actorSystem.actorOf(Props(newActor), name = name)
+        Left(actorSystem.actorOf(Props(newActor), name = name))
       } else {
         val driverHost: String = System.getProperty("spark.driver.host", "localhost")
         val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
         Utils.checkHost(driverHost, "Expected hostname")
         val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
         logInfo("Connecting to " + name + ": " + url)
-        actorSystem.actorFor(url)
+        Right(actorSystem.actorSelection(url))
      }
    }
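Both askTracker above and the retry loop in BlockManagerMaster further down repeat the same isLeft/left.get branching. Since akka.pattern gives both ActorRef and ActorSelection an ask with the same shape, the branching can be collapsed with Either.fold. A minimal sketch, not part of this patch; the standalone helper and its signature are illustration-only assumptions:

import scala.concurrent.{Await, Future}
import scala.concurrent.duration.FiniteDuration
import akka.actor.{ActorRef, ActorSelection}
import akka.pattern.ask
import akka.util.Timeout

object EitherAsk {
  // fold applies the first function to a Left(ActorRef) and the second to a
  // Right(ActorSelection); both ask variants return Future[Any]
  def askEither(target: Either[ActorRef, ActorSelection],
                message: Any, askTimeout: FiniteDuration): Any = {
    implicit val timeout: Timeout = Timeout(askTimeout)
    val future: Future[Any] = target.fold(_ ? message, _ ? message)
    Await.result(future, askTimeout)
  }
}

Either is needed at all because ActorRef and ActorSelection share no common ask-capable supertype; the ask capability comes from the akka.pattern implicits.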
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index 164386782ccb99a059fcd5859385e875af1440aa..000d1ee9f8f1f754712955da7a2254a97a077f8a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -49,18 +49,14 @@ private[spark] class Client(
   var appId: String = null
 
   class ClientActor extends Actor with Logging {
-    var master: ActorRef = null
-    var masterAddress: Address = null
+    var master: ActorSelection = null
     var alreadyDisconnected = false  // To avoid calling listener.disconnected() multiple times
 
     override def preStart() {
       logInfo("Connecting to master " + masterUrl)
       try {
-        master = context.actorFor(Master.toAkkaUrl(masterUrl))
-        masterAddress = master.path.address
+        master = context.actorSelection(Master.toAkkaUrl(masterUrl))
         master ! RegisterApplication(appDescription)
-        context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
-        context.watch(master)  // Doesn't work with remote actors, but useful for testing
       } catch {
         case e: Exception =>
           logError("Failed to connect to master", e)
@@ -71,6 +67,7 @@ private[spark] class Client(
 
     override def receive = {
       case RegisteredApplication(appId_) =>
+        context.watch(sender)
         appId = appId_
         listener.connected(appId)
 
@@ -92,18 +89,8 @@ private[spark] class Client(
           listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
         }
 
-      case Terminated(actor_) if actor_ == master =>
-        logError("Connection to master failed; stopping client")
-        markDisconnected()
-        context.stop(self)
-
-      case DisassociatedEvent(_, address, _) if address == masterAddress =>
-        logError("Connection to master failed; stopping client")
-        markDisconnected()
-        context.stop(self)
-
-      case AssociationErrorEvent(_, _, address, _) if address == masterAddress =>
-        logError("Connection to master failed; stopping client")
+      case Terminated(actor_) =>
+        logError(s"Connection to $actor_ dropped; stopping client")
         markDisconnected()
         context.stop(self)
 
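Collapsing the three failure cases into one Terminated handler relies on Akka 2.2's remote death watch. The catch is that context.watch needs a concrete ActorRef, and an ActorSelection is only a path; hence the watch is deferred until the master's first reply, whose sender is a real remote reference. A minimal sketch of that pattern, with the class name and string messages as stand-ins:

import akka.actor.{Actor, ActorSelection, Terminated}

class MasterClient(masterUrl: String) extends Actor {
  var master: ActorSelection = _

  override def preStart() {
    master = context.actorSelection(masterUrl)  // a path, not a watchable ref
    master ! "register"                         // stand-in for RegisterApplication
  }

  def receive = {
    case "registered" =>
      context.watch(sender)  // sender is the master's concrete ActorRef
    case Terminated(ref) =>
      // fires on master shutdown and, via remote death watch, on network failure
      context.stop(self)
  }
}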
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 3904b701b2f2848243da8ef5505529f9f8fb068a..400d6f26ea45cbb8353a83acb361422361164857 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -24,7 +24,7 @@ import java.io.File
 import scala.collection.mutable.HashMap
 import scala.concurrent.duration._
 
-import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
+import akka.actor._
 import akka.remote.{RemotingLifecycleEvent, AssociationErrorEvent, DisassociatedEvent}
 
 import org.apache.spark.Logging
@@ -34,6 +34,16 @@
 import org.apache.spark.deploy.master.Master
 import org.apache.spark.deploy.worker.ui.WorkerWebUI
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.deploy.DeployMessages.WorkerStateResponse
+import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
+import org.apache.spark.deploy.DeployMessages.KillExecutor
+import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import scala.Some
+import org.apache.spark.deploy.DeployMessages.Heartbeat
+import org.apache.spark.deploy.DeployMessages.RegisteredWorker
+import akka.remote.DisassociatedEvent
+import org.apache.spark.deploy.DeployMessages.LaunchExecutor
+import org.apache.spark.deploy.DeployMessages.RegisterWorker
 
 private[spark] class Worker(
@@ -54,7 +64,7 @@
   // Send a heartbeat every (heartbeat timeout) / 4 milliseconds
   val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
 
-  var master: ActorRef = null
+  var master: ActorSelection = null
   var masterWebUiUrl : String = ""
   val workerId = generateWorkerId()
   var sparkHome: File = null
@@ -111,10 +121,8 @@
 
   def connectToMaster() {
     logInfo("Connecting to master " + masterUrl)
-    master = context.actorFor(Master.toAkkaUrl(masterUrl))
+    master = context.actorSelection(Master.toAkkaUrl(masterUrl))
     master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
-    context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
-    context.watch(master)  // Doesn't work with remote actors, but useful for testing
   }
 
   import context.dispatcher
@@ -123,6 +131,8 @@
     case RegisteredWorker(url) =>
       masterWebUiUrl = url
       logInfo("Successfully registered with master")
+      context.watch(sender)  // remote death watch for the master
+      // TODO: Is this heartbeat still necessary? Akka's failure detector covers liveness.
       context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
         master ! Heartbeat(workerId)
       }
@@ -165,7 +175,8 @@
       logInfo("Asked to kill unknown executor " + fullId)
     }
 
-    case DisassociatedEvent(_, _, _) =>
+    case Terminated(actor_) =>
+      logInfo(s"$actor_ terminated!")
       masterDisconnected()
 
     case RequestWorkerState => {
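When there is no application-level reply to piggyback on, Akka 2.2's Identify/ActorIdentity handshake is the standard way to resolve an ActorSelection into a watchable ActorRef. The Worker above does not need it, because RegisteredWorker already supplies a sender to watch, but for reference here is a sketch under assumed names:

import akka.actor.{Actor, ActorIdentity, Identify, Terminated}

class MasterResolver(masterUrl: String) extends Actor {
  override def preStart() {
    // Any correlation id works; it comes back in the ActorIdentity reply
    context.actorSelection(masterUrl) ! Identify("master")
  }

  def receive = {
    case ActorIdentity("master", Some(ref)) =>
      context.watch(ref)  // now a concrete ActorRef is available
    case ActorIdentity("master", None) =>
      context.stop(self)  // nothing is live at that path
    case Terminated(_) =>
      context.stop(self)
  }
}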
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 0c977f05d1d22dcc730c438bdbee13c11167b62a..c1aa43d59c7c51315427ec84304fe3f50e30ccc1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,14 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io._
-import java.util.{HashMap => JHashMap}
-
-import scala.collection.JavaConverters._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.util.Random
-
-import akka.actor.{Actor, ActorRef, ActorSystem, Props}
+import akka.actor._
 import scala.concurrent.Await
 import scala.concurrent.Future
 import scala.concurrent.ExecutionContext.Implicits.global
@@ -34,8 +27,16 @@ import scala.concurrent.duration._
 
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.storage.BlockManagerMessages._
+import org.apache.spark.storage.BlockManagerMessages.GetLocations
+import org.apache.spark.storage.BlockManagerMessages.GetLocationsMultipleBlockIds
+import org.apache.spark.storage.BlockManagerMessages.RegisterBlockManager
+import org.apache.spark.storage.BlockManagerMessages.HeartBeat
+import org.apache.spark.storage.BlockManagerMessages.RemoveExecutor
+import org.apache.spark.storage.BlockManagerMessages.GetPeers
+import org.apache.spark.storage.BlockManagerMessages.RemoveBlock
+import org.apache.spark.storage.BlockManagerMessages.RemoveRdd
 
-private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
+private[spark] class BlockManagerMaster(var driverActor: Either[ActorRef, ActorSelection]) extends Logging {
 
   val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
   val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
@@ -165,7 +166,11 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
     while (attempts < AKKA_RETRY_ATTEMPTS) {
       attempts += 1
       try {
-        val future = driverActor.ask(message)(timeout)
+        val future = if (driverActor.isLeft) {
+          driverActor.left.get.ask(message)(timeout)
+        } else {
+          driverActor.right.get.ask(message)(timeout)
+        }
         val result = Await.result(future, timeout)
         if (result == null) {
           throw new SparkException("BlockManagerMaster returned null")
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index f2ae8dd97dab1758308e3b36ffd1db1eb03b0955..1e6da269f29aaeda377fa47c8b7e6a293dda1593 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -93,7 +93,7 @@ private[spark] object ThreadingTest {
     val actorSystem = ActorSystem("test")
     val serializer = new KryoSerializer
     val blockManagerMaster = new BlockManagerMaster(
-      actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+      Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
     val blockManager = new BlockManager(
       "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
     val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 18fb1bf590ab4639ec0328a144d641da894f3c1c..955f6cdadccc7450fcb4c654ac0cf14f38b90b49 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -49,14 +49,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTracker()
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+    tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))))
     tracker.stop()
   }
 
   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTracker()
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+    tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -75,7 +75,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTracker()
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
+    tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker))))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -103,13 +103,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     System.setProperty("spark.hostPort", hostname + ":" + boundPort)
 
     val masterTracker = new MapOutputTracker()
-    masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+    masterTracker.trackerActor = Left(actorSystem.actorOf(
+      Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker"))
 
     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
     val slaveTracker = new MapOutputTracker()
-    slaveTracker.trackerActor = slaveSystem.actorFor(
-      "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+    slaveTracker.trackerActor = Right(slaveSystem.actorSelection(
+      "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker"))
 
     masterTracker.registerShuffle(10, 1)
     masterTracker.incrementEpoch()
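The suites build Left(...) from a local actorOf and Right(...) from a remote actorSelection. The Right branch can also be exercised without standing up remoting, because a selection can be made from a local actor's path. A sketch reusing the tracker classes above; the test scaffolding around it is assumed:

import akka.actor.{ActorSystem, Props}

val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
val ref = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)), "MapOutputTracker")
// A local ActorSelection pointing at the same actor exercises the Right path
tracker.trackerActor = Right(actorSystem.actorSelection(ref.path))
tracker.stop()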
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 038a9acb8593c07f45c0e660bf0b719e97a8b888..4fdc43cc227523f6428a2d8eb1c88270f2c339c0 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -53,7 +53,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     System.setProperty("spark.hostPort", "localhost:" + boundPort)
 
     master = new BlockManagerMaster(
-      actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+      Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
 
     // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
     oldArch = System.setProperty("os.arch", "amd64")
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
index 08e399f9ee16803d10b90f5d70ead4a27b9af99b..128711aacda76624e571a75225bd5e2779dd7f1f 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
@@ -86,7 +86,7 @@ class FeederActor extends Actor {
 class SampleActorReceiver[T: ClassTag](urlOfPublisher: String)
   extends Actor with Receiver {
 
-  lazy private val remotePublisher = context.actorFor(urlOfPublisher)
+  lazy private val remotePublisher = context.actorSelection(urlOfPublisher)
 
   override def preStart = remotePublisher ! SubscribeReceiver(context.self)
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 394a39fbb0af427274ba39c0c9765f79e8773806..b2f9f8b224daa99ca01301938f86812baa5f195d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -178,7 +178,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
     val ip = System.getProperty("spark.driver.host", "localhost")
     val port = System.getProperty("spark.driver.port", "7077").toInt
     val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
-    val tracker = env.actorSystem.actorFor(url)
+    val tracker = env.actorSystem.actorSelection(url)
     val timeout = 5.seconds
 
     override def preStart() {
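Like the receiver's tracker selection above, any ActorSelection supports ask once akka.pattern.ask is in scope, which is what the Either-based plumbing earlier in this patch relies on. A self-contained sketch; the actor system name, the URL's host and port, and the "status" message are stand-ins:

import scala.concurrent.Await
import scala.concurrent.duration._
import akka.actor.ActorSystem
import akka.pattern.ask
import akka.util.Timeout

object TrackerPing extends App {
  val system = ActorSystem("probe")
  implicit val timeout: Timeout = Timeout(5.seconds)
  // Same URL shape as the one NetworkReceiver builds above
  val tracker = system.actorSelection("akka.tcp://spark@localhost:7077/user/NetworkInputTracker")
  val reply = Await.result(tracker ? "status", timeout.duration)
  println(reply)
  system.shutdown()  // Akka 2.2-era shutdown
}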