diff --git a/core/src/main/scala/spark/deploy/master/Master.scala b/core/src/main/scala/spark/deploy/master/Master.scala index 8fdecfa08c99639eec1ebc029eddd690dabe0bb4..cb0208e0b6f192eaf9bbb09134ee6c8165b47e8f 100644 --- a/core/src/main/scala/spark/deploy/master/Master.scala +++ b/core/src/main/scala/spark/deploy/master/Master.scala @@ -2,12 +2,19 @@ package spark.deploy.master import scala.collection.mutable.HashMap -import akka.actor.{Terminated, ActorRef, Props, Actor} +import akka.actor._ import spark.{Logging, Utils} import spark.util.AkkaUtils import java.text.SimpleDateFormat import java.util.Date import spark.deploy.{RegisteredSlave, RegisterSlave} +import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} +import akka.remote.RemoteClientShutdown +import spark.deploy.RegisteredSlave +import akka.remote.RemoteClientDisconnected +import akka.actor.Terminated +import scala.Some +import spark.deploy.RegisterSlave class SlaveInfo( val id: Int, @@ -30,10 +37,12 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { var nextJobId = 0 val slaves = new HashMap[Int, SlaveInfo] val actorToSlave = new HashMap[ActorRef, SlaveInfo] + val addressToSlave = new HashMap[Address, SlaveInfo] override def preStart() { logInfo("Starting Spark master at spark://" + ip + ":" + port) logInfo("Cluster ID: " + clusterId) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) startWebUi() } @@ -52,24 +61,37 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging { case RegisterSlave(host, slavePort, cores, memory) => { logInfo("Registering slave %s:%d with %d cores, %s RAM".format( host, slavePort, cores, Utils.memoryMegabytesToString(memory))) - val id = newSlaveId() - slaves(id) = new SlaveInfo(id, host, slavePort, cores, memory, sender) - actorToSlave(sender) = slaves(id) - context.watch(sender) - sender ! RegisteredSlave(clusterId, id) + val slave = addSlave(host, slavePort, cores, memory) + context.watch(sender) // This doesn't work with remote actors but helps for testing + sender ! RegisteredSlave(clusterId, slave.id) } - case Terminated(actor) => { + case RemoteClientDisconnected(transport, address) => + logInfo("Remote client disconnected: " + address) + addressToSlave.get(address).foreach(s => removeSlave(s)) // Remove slave, if any, at address + + case RemoteClientShutdown(transport, address) => + logInfo("Remote client shutdown: " + address) + addressToSlave.get(address).foreach(s => removeSlave(s)) // Remove slave, if any, at address + + case Terminated(actor) => logInfo("Slave disconnected: " + actor) - actorToSlave.get(actor) match { - case Some(slave) => - logInfo("Removing slave " + slave.id) - slaves -= slave.id - actorToSlave -= actor - case None => - logError("Did not have any slave registered for " + actor) - } - } + actorToSlave.get(actor).foreach(s => removeSlave(s)) // Remove slave, if any, at actor + } + + def addSlave(host: String, slavePort: Int, cores: Int, memory: Int): SlaveInfo = { + val slave = new SlaveInfo(newSlaveId(), host, slavePort, cores, memory, sender) + slaves(slave.id) = slave + actorToSlave(sender) = slave + addressToSlave(sender.path.address) = slave + return slave + } + + def removeSlave(slave: SlaveInfo) { + logInfo("Removing slave " + slave.id + " on " + slave.host + ":" + slave.port) + slaves -= slave.id + actorToSlave -= slave.actor + addressToSlave -= slave.actor.path.address } def newClusterId(): String = { diff --git a/core/src/main/scala/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/spark/deploy/master/MasterArguments.scala index ca4b8a143f2be40035ce18456822f716734387c0..5d975cd546ac8aecf2be653731d0e76373e244c2 100644 --- a/core/src/main/scala/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/spark/deploy/master/MasterArguments.scala @@ -41,11 +41,11 @@ class MasterArguments(args: Array[String]) { def printUsageAndExit(exitCode: Int) { System.err.println( "Usage: spark-master [options]\n" + - "\n" + - "Options:\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + - " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") + "\n" + + "Options:\n" + + " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + + " --webui-port PORT Port for web UI (default: 8080)") System.exit(exitCode) } } \ No newline at end of file diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala index fd49223798d5c23dc0356a43e943b4a2b359145d..22b070658d8896dedd9e55d3a7d3b4f4cba8af9e 100644 --- a/core/src/main/scala/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/spark/deploy/worker/Worker.scala @@ -1,17 +1,24 @@ package spark.deploy.worker -import scala.collection.mutable.HashMap -import akka.actor.{Terminated, ActorRef, Props, Actor} -import spark.{Logging, Utils} -import spark.util.AkkaUtils -import java.text.SimpleDateFormat -import java.util.Date -import spark.deploy.{RegisteredSlave, RegisterSlave} +import akka.actor.{ActorRef, Terminated, Props, Actor} +import akka.pattern.ask +import akka.util.duration._ +import spark.{SparkException, Logging, Utils} +import spark.util.{IntParam, AkkaUtils} +import spark.deploy.{RegisterSlave, RegisteredSlave} +import akka.dispatch.Await +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} -class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int) +class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, masterUrl: String) extends Actor with Logging { + val MASTER_REGEX = "spark://([^:]+):([0-9]+)".r + + var master: ActorRef = null + var clusterId: String = null + var slaveId: Int = 0 + var coresUsed = 0 var memoryUsed = 0 @@ -21,9 +28,32 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int) override def preStart() { logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( ip, port, cores, Utils.memoryMegabytesToString(memory))) + connectToMaster() startWebUi() } + def connectToMaster() { + masterUrl match { + case MASTER_REGEX(masterHost, masterPort) => + logInfo("Connecting to master spark://" + masterHost + ":" + masterPort) + val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort) + try { + master = context.actorFor(akkaUrl) + master ! RegisterSlave(ip, port, cores, memory) + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(master) // Doesn't work with remote actors, but useful for testing + } catch { + case e: Exception => + logError("Failed to connect to master", e) + System.exit(1) + } + + case _ => + logError("Invalid master URL: " + masterUrl) + System.exit(1) + } + } + def startWebUi() { val webUi = new WorkerWebUI(context.system, self) try { @@ -36,13 +66,25 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int) } override def receive = { - case RegisteredSlave(clusterId, slaveId) => { - logInfo("Registered with cluster ID " + clusterId + ", slave ID " + slaveId) - } + case RegisteredSlave(clusterId_, slaveId_) => + this.clusterId = clusterId_ + this.slaveId = slaveId_ + logInfo("Registered with master, cluster ID = " + clusterId + ", slave ID = " + slaveId) - case Terminated(actor) => { - logError("Master disconnected!") - } + case RemoteClientDisconnected(_, _) => + masterDisconnected() + + case RemoteClientShutdown(_, _) => + masterDisconnected() + + case Terminated(_) => + masterDisconnected() + } + + def masterDisconnected() { + // Not sure what to do here exactly, so just shut down for now. + logError("Connection to master failed! Shutting down.") + System.exit(1) } } @@ -51,7 +93,7 @@ object Worker { val args = new WorkerArguments(argStrings) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port) val actor = actorSystem.actorOf( - Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory)), + Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, args.master)), name = "Worker") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala index cd112b7fa387783fb1061307f681488aefb235f1..ab764aa87762aebc377c968e42dbeee81e42845d 100644 --- a/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/spark/deploy/worker/WorkerArguments.scala @@ -14,6 +14,7 @@ class WorkerArguments(args: Array[String]) { var webUiPort = 8081 var cores = inferDefaultCores() var memory = inferDefaultMemory() + var master: String = null parse(args.toList) @@ -41,7 +42,17 @@ class WorkerArguments(args: Array[String]) { case ("--help" | "-h") :: tail => printUsageAndExit(0) - case Nil => {} + case value :: tail => + if (master != null) { // Two positional arguments were given + printUsageAndExit(1) + } + master = value + parse(tail) + + case Nil => + if (master == null) { // No positional argument was given + printUsageAndExit(1) + } case _ => printUsageAndExit(1) @@ -52,14 +63,16 @@ class WorkerArguments(args: Array[String]) { */ def printUsageAndExit(exitCode: Int) { System.err.println( - "Usage: spark-worker [options]\n" + - "\n" + - "Options:\n" + - " -c CORES, --cores CORES Number of cores to use\n" + - " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + - " -i IP, --ip IP IP address or DNS name to listen on\n" + - " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") + "Usage: spark-worker [options] <master>\n" + + "\n" + + "Master must be a URL of the form spark://hostname:port\n" + + "\n" + + "Options:\n" + + " -c CORES, --cores CORES Number of cores to use\n" + + " -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" + + " -i IP, --ip IP IP address or DNS name to listen on\n" + + " -p PORT, --port PORT Port to listen on (default: random)\n" + + " --webui-port PORT Port for web UI (default: 8081)") System.exit(exitCode) } diff --git a/core/src/main/scala/spark/util/AkkaUtils.scala b/core/src/main/scala/spark/util/AkkaUtils.scala index 84e942e5b7d889faee3d4b8395feb6b841fcfae8..3cf12ebe0e65cd9c784ab1c513b4238ecb9ebda8 100644 --- a/core/src/main/scala/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/spark/util/AkkaUtils.scala @@ -30,6 +30,7 @@ object AkkaUtils { akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" akka.remote.netty.hostname = "%s" akka.remote.netty.port = %d + akka.remote.netty.connection-timeout = 1s """.format(host, port)) val actorSystem = ActorSystem("spark", akkaConf, getClass.getClassLoader) @@ -39,8 +40,6 @@ object AkkaUtils { val provider = actorSystem.asInstanceOf[ActorSystemImpl].provider val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get return (actorSystem, boundPort) - - return (null, 0) } /**