diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 398e0936906a3077a6f1c289ebb880dd72ded3aa..23ae9360f6a22bb6aa383d4ba63ce09636fca473 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -252,7 +252,8 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, + clientMode = !isDriver) val actorSystem: ActorSystem = if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem @@ -262,9 +263,11 @@ object SparkEnv extends Logging { } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. + // In the non-driver case, the RPC env's address may be null since it may not be listening + // for incoming connections. if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else { + } else if (rpcEnv.address != null) { conf.set("spark.executor.port", rpcEnv.address.port.toString) } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a9c6a05ecd4345133d2fe293db096d0a7e28efd3..c2ebf30596215f270ca9634729a33a26d781e873 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -45,8 +45,6 @@ private[spark] class CoarseGrainedExecutorBackend( env: SparkEnv) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { - Utils.checkHostPort(hostPort, "Expected hostport") - var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -80,9 +78,8 @@ private[spark] class CoarseGrainedExecutorBackend( } override def receive: PartialFunction[Any, Unit] = { - case RegisteredExecutor => + case RegisteredExecutor(hostname) => logInfo("Successfully registered with driver") - val (hostname, _) = Utils.parseHostPort(hostPort) executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false) case RegisterExecutorFailed(message) => @@ -163,7 +160,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { hostname, port, executorConf, - new SecurityManager(executorConf)) + new SecurityManager(executorConf), + clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) @@ -188,12 +186,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { val env = SparkEnv.createExecutorEnv( driverConf, executorId, hostname, port, cores, isLocal = false) - // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore. - val boundPort = env.conf.getInt("spark.executor.port", 0) - assert(boundPort != 0) - - // Start the CoarseGrainedExecutorBackend endpoint. - val sparkHostPort = hostname + ":" + boundPort + // SparkEnv will set spark.executor.port if the rpc env is listening for incoming + // connections (e.g., if it's using akka). Otherwise, the executor is running in + // client mode only, and does not accept incoming connections. + val sparkHostPort = env.conf.getOption("spark.executor.port").map { port => + hostname + ":" + port + }.orNull env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env)) workerUrl.foreach { url => diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 2c4a8b9a0a87821b192d9b622287a063d5dfee40..a560fd10cdf76c9140723aac04949f07453477e0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -43,9 +43,10 @@ private[spark] object RpcEnv { host: String, port: Int, conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { + securityManager: SecurityManager, + clientMode: Boolean = false): RpcEnv = { // Using Reflection to create the RpcEnv to avoid to depend on Akka directly - val config = RpcEnvConfig(conf, name, host, port, securityManager) + val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) getRpcEnvFactory(conf).create(config) } } @@ -139,4 +140,5 @@ private[spark] case class RpcEnvConfig( name: String, host: String, port: Int, - securityManager: SecurityManager) + securityManager: SecurityManager, + clientMode: Boolean) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 7bf44a6565b618d20327101dc31ae047ac4f4d88..eb25d6c7b721b15e3cd8db68e0d851f6c88c62f7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -55,7 +55,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private var stopped = false def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { - val addr = new RpcEndpointAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val addr = RpcEndpointAddress(nettyEnv.address, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) synchronized { if (stopped) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 284284eb805b7ca92067e474e352df968182ce14..09093819bb22c232aa31a367819eb7f7c541969a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -17,10 +17,12 @@ package org.apache.spark.rpc.netty import java.io._ +import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -29,6 +31,7 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal +import com.google.common.base.Preconditions import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -45,15 +48,14 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - // Override numConnectionsPerPeer to 1 for RPC. private val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) - private val transportContext = - new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) + private val transportContext = new TransportContext(transportConf, + new NettyRpcHandler(dispatcher, this)) private val clientFactory = { val bootstraps: java.util.List[TransportClientBootstrap] = @@ -95,7 +97,7 @@ private[netty] class NettyRpcEnv( } } - def start(port: Int): Unit = { + def startServer(port: Int): Unit = { val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) @@ -107,9 +109,9 @@ private[netty] class NettyRpcEnv( RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher)) } + @Nullable override lazy val address: RpcAddress = { - require(server != null, "NettyRpcEnv has not yet started") - RpcAddress(host, server.getPort) + if (server != null) RpcAddress(host, server.getPort()) else null } override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { @@ -120,7 +122,7 @@ private[netty] class NettyRpcEnv( val addr = RpcEndpointAddress(uri) val endpointRef = new NettyRpcEndpointRef(conf, addr, this) val verifier = new NettyRpcEndpointRef( - conf, RpcEndpointAddress(addr.host, addr.port, RpcEndpointVerifier.NAME), this) + conf, RpcEndpointAddress(addr.rpcAddress, RpcEndpointVerifier.NAME), this) verifier.ask[Boolean](RpcEndpointVerifier.CheckExistence(endpointRef.name)).flatMap { find => if (find) { Future.successful(endpointRef) @@ -135,28 +137,34 @@ private[netty] class NettyRpcEnv( dispatcher.stop(endpointRef) } - private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = { - val targetOutbox = { - val outbox = outboxes.get(address) - if (outbox == null) { - val newOutbox = new Outbox(this, address) - val oldOutbox = outboxes.putIfAbsent(address, newOutbox) - if (oldOutbox == null) { - newOutbox + private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = { + if (receiver.client != null) { + receiver.client.sendRpc(message.content, message.createCallback(receiver.client)); + } else { + require(receiver.address != null, + "Cannot send message to client endpoint with no listen address.") + val targetOutbox = { + val outbox = outboxes.get(receiver.address) + if (outbox == null) { + val newOutbox = new Outbox(this, receiver.address) + val oldOutbox = outboxes.putIfAbsent(receiver.address, newOutbox) + if (oldOutbox == null) { + newOutbox + } else { + oldOutbox + } } else { - oldOutbox + outbox } + } + if (stopped.get) { + // It's possible that we put `targetOutbox` after stopping. So we need to clean it. + outboxes.remove(receiver.address) + targetOutbox.stop() } else { - outbox + targetOutbox.send(message) } } - if (stopped.get) { - // It's possible that we put `targetOutbox` after stopping. So we need to clean it. - outboxes.remove(address) - targetOutbox.stop() - } else { - targetOutbox.send(message) - } } private[netty] def send(message: RequestMessage): Unit = { @@ -174,17 +182,14 @@ private[netty] class NettyRpcEnv( }(ThreadUtils.sameThread) } else { // Message to a remote RPC endpoint. - postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { + postToOutbox(message.receiver, OutboxMessage(serialize(message), + (e) => { logWarning(s"Exception when sending $message", e) - } - - override def onSuccess(response: Array[Byte]): Unit = { - val ack = deserialize[Ack](response) + }, + (client, response) => { + val ack = deserialize[Ack](client, response) logDebug(s"Receive ack from ${ack.sender}") - } - })) + })) } } @@ -214,16 +219,14 @@ private[netty] class NettyRpcEnv( } }(ThreadUtils.sameThread) } else { - postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { + postToOutbox(message.receiver, OutboxMessage(serialize(message), + (e) => { if (!promise.tryFailure(e)) { logWarning("Ignore Exception", e) } - } - - override def onSuccess(response: Array[Byte]): Unit = { - val reply = deserialize[AskResponse](response) + }, + (client, response) => { + val reply = deserialize[AskResponse](client, response) if (reply.reply.isInstanceOf[RpcFailure]) { if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { logWarning(s"Ignore failure: ${reply.reply}") @@ -231,8 +234,7 @@ private[netty] class NettyRpcEnv( } else if (!promise.trySuccess(reply.reply)) { logWarning(s"Ignore message: ${reply}") } - } - })) + })) } promise.future } @@ -243,9 +245,11 @@ private[netty] class NettyRpcEnv( buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) } - private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = { - deserialize { () => - javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = { + NettyRpcEnv.currentClient.withValue(client) { + deserialize { () => + javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) + } } } @@ -254,7 +258,7 @@ private[netty] class NettyRpcEnv( } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new RpcEndpointAddress(address.host, address.port, endpointName).toString + new RpcEndpointAddress(address, endpointName).toString override def shutdown(): Unit = { cleanup() @@ -297,6 +301,7 @@ private[netty] class NettyRpcEnv( deserializationAction() } } + } private[netty] object NettyRpcEnv extends Logging { @@ -312,6 +317,13 @@ private[netty] object NettyRpcEnv extends Logging { * }}} */ private[netty] val currentEnv = new DynamicVariable[NettyRpcEnv](null) + + /** + * Similar to `currentEnv`, this variable references the client instance associated with an + * RPC, in case it's needed to find out the remote address during deserialization. + */ + private[netty] val currentClient = new DynamicVariable[TransportClient](null) + } private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { @@ -324,47 +336,68 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] val nettyEnv = new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) - val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => - nettyEnv.start(actualPort) - (nettyEnv, actualPort) - } - try { - Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 - } catch { - case NonFatal(e) => - nettyEnv.shutdown() - throw e + if (!config.clientMode) { + val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => + nettyEnv.startServer(actualPort) + (nettyEnv, actualPort) + } + try { + Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1 + } catch { + case NonFatal(e) => + nettyEnv.shutdown() + throw e + } } + nettyEnv } } -private[netty] class NettyRpcEndpointRef(@transient private val conf: SparkConf) +/** + * The NettyRpcEnv version of RpcEndpointRef. + * + * This class behaves differently depending on where it's created. On the node that "owns" the + * RpcEndpoint, it's a simple wrapper around the RpcEndpointAddress instance. + * + * On other machines that receive a serialized version of the reference, the behavior changes. The + * instance will keep track of the TransportClient that sent the reference, so that messages + * to the endpoint are sent over the client connection, instead of needing a new connection to + * be opened. + * + * The RpcAddress of this ref can be null; what that means is that the ref can only be used through + * a client connection, since the process hosting the endpoint is not listening for incoming + * connections. These refs should not be shared with 3rd parties, since they will not be able to + * send messages to the endpoint. + * + * @param conf Spark configuration. + * @param endpointAddress The address where the endpoint is listening. + * @param nettyEnv The RpcEnv associated with this ref. + * @param local Whether the referenced endpoint lives in the same process. + */ +private[netty] class NettyRpcEndpointRef( + @transient private val conf: SparkConf, + endpointAddress: RpcEndpointAddress, + @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) with Serializable with Logging { - @transient @volatile private var nettyEnv: NettyRpcEnv = _ + @transient @volatile var client: TransportClient = _ - @transient @volatile private var _address: RpcEndpointAddress = _ + private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null + private val _name = endpointAddress.name - def this(conf: SparkConf, _address: RpcEndpointAddress, nettyEnv: NettyRpcEnv) { - this(conf) - this._address = _address - this.nettyEnv = nettyEnv - } - - override def address: RpcAddress = _address.toRpcAddress + override def address: RpcAddress = if (_address != null) _address.rpcAddress else null private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() - _address = in.readObject().asInstanceOf[RpcEndpointAddress] nettyEnv = NettyRpcEnv.currentEnv.value + client = NettyRpcEnv.currentClient.value } private def writeObject(out: ObjectOutputStream): Unit = { out.defaultWriteObject() - out.writeObject(_address) } - override def name: String = _address.name + override def name: String = _name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() @@ -429,41 +462,43 @@ private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessa private[netty] case class RpcFailure(e: Throwable) /** - * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast - * network events and forward messages to [[Dispatcher]]. + * Dispatches incoming RPCs to registered endpoints. + * + * The handler keeps track of all client instances that communicate with it, so that the RpcEnv + * knows which `TransportClient` instance to use when sending RPCs to a client endpoint (i.e., + * one that is not listening for incoming connections, but rather needs to be contacted via the + * client socket). + * + * Events are sent on a per-connection basis, so if a client opens multiple connections to the + * RpcEnv, multiple connection / disconnection events will be created for that client (albeit + * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { - private type ClientAddress = RpcAddress - private type RemoteEnvAddress = RpcAddress - - // Store all client addresses and their NettyRpcEnv addresses. - // TODO: Is this even necessary? - @GuardedBy("this") - private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() + // TODO: Can we add connection callback (channel registered) to the underlying framework? + // A variable to track whether we should dispatch the RemoteProcessConnected message. + private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() override def receive( - client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { - val requestMessage = nettyEnv.deserialize[RequestMessage](message) - val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] + client: TransportClient, + message: Array[Byte], + callback: RpcResponseCallback): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) - val remoteEnvAddress = requestMessage.senderAddress val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. - var dispatchRemoteProcessConnected = false - synchronized { - if (remoteAddresses.put(clientAddr, remoteEnvAddress).isEmpty) { - // clientAddr connects at the first time, fire "RemoteProcessConnected" - dispatchRemoteProcessConnected = true - } + if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) } - if (dispatchRemoteProcessConnected) { - dispatcher.postToAll(RemoteProcessConnected(remoteEnvAddress)) - } - dispatcher.postRemoteMessage(requestMessage, callback) + val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) + val messageToDispatch = if (requestMessage.senderAddress == null) { + // Create a new message with the socket address of the client as the sender. + RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, + requestMessage.needReply) + } else { + requestMessage + } + dispatcher.postRemoteMessage(messageToDispatch, callback) } override def getStreamManager: StreamManager = new OneForOneStreamManager @@ -472,15 +507,7 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - remoteAddresses.get(clientAddr).map(RemoteProcessConnectionError(cause, _)) - } - if (broadcastMessage.isEmpty) { - logError(cause.getMessage, cause) - } else { - dispatcher.postToAll(broadcastMessage.get) - } + dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) } else { // If the channel is closed before connecting, its remoteAddress will be null. // See java.net.Socket.getRemoteSocketAddress @@ -493,15 +520,9 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + clients.remove(client) nettyEnv.removeOutbox(clientAddr) - val messageOpt: Option[RemoteProcessDisconnected] = - synchronized { - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - Some(RemoteProcessDisconnected(remoteEnvAddress)) - } - } - messageOpt.foreach(dispatcher.postToAll) + dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 7d9d593b362412e7d075c7abc8ebd19c36624aff..2f6817f2eb935f67008daa1f37b961777bf82883 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -26,7 +26,21 @@ import org.apache.spark.SparkException import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.rpc.RpcAddress -private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback) +private[netty] case class OutboxMessage(content: Array[Byte], + _onFailure: (Throwable) => Unit, + _onSuccess: (TransportClient, Array[Byte]) => Unit) { + + def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() { + override def onFailure(e: Throwable): Unit = { + _onFailure(e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + _onSuccess(client, response) + } + } + +} private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { @@ -68,7 +82,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { } } if (dropped) { - message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) } else { drainOutbox() } @@ -108,7 +122,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { try { val _client = synchronized { client } if (_client != null) { - _client.sendRpc(message.content, message.callback) + _client.sendRpc(message.content, message.createCallback(_client)) } else { assert(stopped == true) } @@ -181,7 +195,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message.callback.onFailure(e) + message._onFailure(e) message = messages.poll() } assert(messages.isEmpty) @@ -215,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) { // update messages and it's safe to just drain the queue. var message = messages.poll() while (message != null) { - message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped")) + message._onFailure(new SparkException("Message is dropped because Outbox is stopped")) message = messages.poll() } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala index 87b62369368174fcfc4b26bc61730fc88cb5adb4..d2e94f943aba5f13104a56e99b019dd2ae2cec20 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala @@ -23,15 +23,25 @@ import org.apache.spark.rpc.RpcAddress /** * An address identifier for an RPC endpoint. * - * @param host host name of the remote process. - * @param port the port the remote RPC environment binds to. - * @param name name of the remote endpoint. + * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only + * connection and can only be reached via the client that sent the endpoint reference. + * + * @param rpcAddress The socket address of the endpint. + * @param name Name of the endpoint. */ -private[netty] case class RpcEndpointAddress(host: String, port: Int, name: String) { +private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { + + require(name != null, "RpcEndpoint name must be provided.") - def toRpcAddress: RpcAddress = RpcAddress(host, port) + def this(host: String, port: Int, name: String) = { + this(RpcAddress(host, port), name) + } - override val toString = s"spark://$name@$host:$port" + override val toString = if (rpcAddress != null) { + s"spark://$name@${rpcAddress.host}:${rpcAddress.port}" + } else { + s"spark-client://$name" + } } private[netty] object RpcEndpointAddress { @@ -51,7 +61,7 @@ private[netty] object RpcEndpointAddress { uri.getQuery != null) { throw new SparkException("Invalid Spark URL: " + sparkUrl) } - RpcEndpointAddress(host, port, name) + new RpcEndpointAddress(host, port, name) } catch { case e: java.net.URISyntaxException => throw new SparkException("Invalid Spark URL: " + sparkUrl, e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 8103efa7302e7f8270c3505bb9cf61331e9fc57b..f3d0d8547677226dc285a4ada2542232f376530c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -38,7 +38,7 @@ private[spark] object CoarseGrainedClusterMessages { sealed trait RegisterExecutorResponse - case object RegisteredExecutor extends CoarseGrainedClusterMessage + case class RegisteredExecutor(hostname: String) extends CoarseGrainedClusterMessage with RegisterExecutorResponse case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage @@ -51,9 +51,7 @@ private[spark] object CoarseGrainedClusterMessages { hostPort: String, cores: Int, logUrls: Map[String, String]) - extends CoarseGrainedClusterMessage { - Utils.checkHostPort(hostPort, "Expected host port") - } + extends CoarseGrainedClusterMessage case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer) extends CoarseGrainedClusterMessage @@ -107,8 +105,4 @@ private[spark] object CoarseGrainedClusterMessages { // Used internally by executors to shut themselves down. case object Shutdown extends CoarseGrainedClusterMessage - // SPARK-10987: workaround for netty RPC issue; forces a connection from the driver back - // to the AM. - case object DriverHello extends CoarseGrainedClusterMessage - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 55a564b5c8eac3a6a896855342cff1e1bef88245..439a11927026b6161309b306ede70f58be1be785 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -131,16 +131,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterExecutor(executorId, executorRef, hostPort, cores, logUrls) => - Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { - logInfo("Registered executor: " + executorRef + " with ID " + executorId) - addressToExecutorId(executorRef.address) = executorId + // If the executor's rpc env is not listening for incoming connections, `hostPort` + // will be null, and the client connection should be used to contact the executor. + val executorAddress = if (executorRef.address != null) { + executorRef.address + } else { + context.senderAddress + } + logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId") + addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(executorRef, executorRef.address, host, cores, cores, logUrls) + val data = new ExecutorData(executorRef, executorRef.address, executorAddress.host, + cores, cores, logUrls) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -151,7 +157,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } // Note: some tests expect the reply to come after we put the executor in the map - context.reply(RegisteredExecutor) + context.reply(RegisteredExecutor(executorAddress.host)) listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index e483688edef5f68f75b4b57cf5eade2bc7c39e5f..cb24072d7d94152e899287740b27de4bc9396441 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -170,8 +170,6 @@ private[spark] abstract class YarnSchedulerBackend( case RegisterClusterManager(am) => logInfo(s"ApplicationMaster registered as $am") amEndpoint = Option(am) - // See SPARK-10987. - am.send(DriverHello) case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 3bead6395d3843d59507d43c2e1f0ae4c676033b..834e4743df866ce7e22c2466142b5d2e47b6063d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -48,7 +48,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv + def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv test("send a message locally") { @volatile var message: String = null @@ -76,7 +76,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { @@ -130,7 +130,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { @@ -158,7 +158,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val shortProp = "spark.rpc.short.timeout" conf.set("spark.rpc.retry.wait", "0") conf.set("spark.rpc.numRetries", "1") - val anotherEnv = createRpcEnv(conf, "remote", 13345) + val anotherEnv = createRpcEnv(conf, "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { @@ -417,7 +417,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") try { @@ -457,7 +457,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-remotely-error") @@ -497,26 +497,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "network-events") val remoteAddress = anotherEnv.address rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List(("onConnected", remoteAddress))) + // anotherEnv is connected in client mode, so the remote address may be unknown depending on + // the implementation. Account for that when doing checks. + if (remoteAddress != null) { + assert(events === List(("onConnected", remoteAddress))) + } else { + assert(events.size === 1) + assert(events(0)._1 === "onConnected") + } } anotherEnv.shutdown() anotherEnv.awaitTermination() eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) + // Account for anotherEnv not having an address due to running in client mode. + if (remoteAddress != null) { + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress)) || + events === List( + ("onConnected", remoteAddress), + ("onDisconnected", remoteAddress))) + } else { + val eventNames = events.map(_._1) + assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || + eventNames === List("onConnected", "onDisconnected")) + } } } @@ -529,7 +543,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef( "local", env.address, "sendWithReply-unserializable-error") @@ -558,7 +572,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate.secret", "good") val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) try { @volatile var message: String = null @@ -589,7 +603,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate.secret", "good") val localEnv = createRpcEnv(conf, "authentication-local", 13345) - val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345, clientMode = true) try { localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 4aa75c9230b2c0114791b3a925bba3c9e3d4e2cb..6478ab51c4da2a247f9b5ecc72eb9fa06357d574 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -22,9 +22,12 @@ import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { - override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + override def createRpcEnv(conf: SparkConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) } test("setupEndpointRef: systemName, address, endpointName") { @@ -37,7 +40,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { }) val conf = new SparkConf() val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf), false)) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === @@ -56,7 +59,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val conf = SSLSampleConfigs.sparkSSLConfig() val securityManager = new SecurityManager(conf) val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + RpcEnvConfig(conf, "test", "localhost", 12346, securityManager, false)) try { val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 973a07a0bde3a813bc912ec162765b6c209bc994..56743ba650b418e3484d18e6faaf8334644d23ec 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -22,8 +22,13 @@ import org.apache.spark.SparkFunSuite class NettyRpcAddressSuite extends SparkFunSuite { test("toString") { - val addr = RpcEndpointAddress("localhost", 12345, "test") + val addr = new RpcEndpointAddress("localhost", 12345, "test") assert(addr.toString === "spark://test@localhost:12345") } + test("toString for client mode") { + val addr = RpcEndpointAddress(null, "test") + assert(addr.toString === "spark-client://test") + } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index be19668e17c04a0d3155779120d153fb2206e38e..ce83087ec04d6070a57fea1d52845d1f22d6ae4b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -22,8 +22,13 @@ import org.apache.spark.rpc._ class NettyRpcEnvSuite extends RpcEnvSuite { - override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { - val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf)) + override def createRpcEnv( + conf: SparkConf, + name: String, + port: Int, + clientMode: Boolean = false): RpcEnv = { + val config = RpcEnvConfig(conf, "test", "localhost", port, new SecurityManager(conf), + clientMode) new NettyRpcEnvFactory().create(config) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 5430e4c0c4d6cd886e34916d97be13e6e8c2a3a7..f9d8e80c98b669fd5ffb9439d6d9250ede8fc6fa 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[Array[Byte]]))(any())). + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { @@ -42,7 +42,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.receive(client, null, null) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } test("connectionTerminated") { @@ -57,9 +57,9 @@ class NettyRpcHandlerSuite extends SparkFunSuite { when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) nettyRpcHandler.connectionTerminated(client) - verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 12345))) + verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( - RemoteProcessDisconnected(RpcAddress("localhost", 12345))) + RemoteProcessDisconnected(RpcAddress("localhost", 40000))) } } diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml index 541ed9a8d0ab6f438aaf6decff881a78d30e0c8e..e2360eff5cfe11ec04c9d9ce092d305b722c9441 100644 --- a/network/yarn/pom.xml +++ b/network/yarn/pom.xml @@ -54,6 +54,11 @@ <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-client</artifactId> </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + <scope>provided</scope> + </dependency> </dependencies> <build> diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c6a6d7ac56bf3bec5a904d3ee3f403629c0bec2f..12ae350e4cef6ab2cb90030720f6da950858f717 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -321,7 +321,8 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) - rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) + rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr, + clientMode = true) val driverRef = waitForSparkDriver() addAmIpFilter() registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -574,9 +575,6 @@ private[spark] class ApplicationMaster( case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") driver.send(x) - - case DriverHello => - // SPARK-10987: no action needed for this message. } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {