diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index e01cf1a29e95b873016d957223256b73112cbd6b..284284eb805b7ca92067e474e352df968182ce14 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -20,6 +20,7 @@ import java.io._
 import java.net.{InetSocketAddress, URI}
 import java.nio.ByteBuffer
 import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -70,12 +71,30 @@ private[netty] class NettyRpcEnv(
   // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
   // to implement non-blocking send/ask.
   // TODO: a non-blocking TransportClientFactory.createClient in future
-  private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
+  private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
     "netty-rpc-connection",
     conf.getInt("spark.rpc.connect.threads", 64))
 
   @volatile private var server: TransportServer = _
 
+  private val stopped = new AtomicBoolean(false)
+
+  /**
+   * A map from [[RpcAddress]] to [[Outbox]]. When we are connecting to a remote [[RpcAddress]],
+   * we just put messages into its [[Outbox]] to implement a non-blocking `send` method.
+   */
+  private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
+
+  /**
+   * Remove the [[Outbox]] for the given address and stop it.
+   */
+  private[netty] def removeOutbox(address: RpcAddress): Unit = {
+    val outbox = outboxes.remove(address)
+    if (outbox != null) {
+      outbox.stop()
+    }
+  }
+
   def start(port: Int): Unit = {
     val bootstraps: java.util.List[TransportServerBootstrap] =
       if (securityManager.isAuthenticationEnabled()) {
@@ -116,6 +135,30 @@ private[netty] class NettyRpcEnv(
     dispatcher.stop(endpointRef)
   }
 
+  private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
+    val targetOutbox = {
+      val outbox = outboxes.get(address)
+      if (outbox == null) {
+        val newOutbox = new Outbox(this, address)
+        val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
+        if (oldOutbox == null) {
+          newOutbox
+        } else {
+          oldOutbox
+        }
+      } else {
+        outbox
+      }
+    }
+    if (stopped.get) {
+      // It's possible that `targetOutbox` was put in after we stopped, so we need to clean it up.
+      outboxes.remove(address)
+      targetOutbox.stop()
+    } else {
+      targetOutbox.send(message)
+    }
+  }
+
   private[netty] def send(message: RequestMessage): Unit = {
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
@@ -127,37 +170,28 @@ private[netty] class NettyRpcEnv(
           val ack = response.asInstanceOf[Ack]
           logTrace(s"Received ack from ${ack.sender}")
         case Failure(e) =>
-          logError(s"Exception when sending $message", e)
+          logWarning(s"Exception when sending $message", e)
       }(ThreadUtils.sameThread)
     } else {
       // Message to a remote RPC endpoint.
-      try {
-        // `createClient` will block if it cannot find a known connection, so we should run it in
-        // clientConnectionExecutor
-        clientConnectionExecutor.execute(new Runnable {
-          override def run(): Unit = Utils.tryLogNonFatalError {
-            val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
-            client.sendRpc(serialize(message), new RpcResponseCallback {
-
-              override def onFailure(e: Throwable): Unit = {
-                logError(s"Exception when sending $message", e)
-              }
-
-              override def onSuccess(response: Array[Byte]): Unit = {
-                val ack = deserialize[Ack](response)
-                logDebug(s"Receive ack from ${ack.sender}")
-              }
-            })
-          }
-        })
-      } catch {
-        case e: RejectedExecutionException =>
-          // `send` after shutting clientConnectionExecutor down, ignore it
-          logWarning(s"Cannot send $message because RpcEnv is stopped")
-      }
+      postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
+
+        override def onFailure(e: Throwable): Unit = {
+          logWarning(s"Exception when sending $message", e)
+        }
+
+        override def onSuccess(response: Array[Byte]): Unit = {
+          val ack = deserialize[Ack](response)
+          logDebug(s"Receive ack from ${ack.sender}")
+        }
+      }))
     }
   }
 
+  private[netty] def createClient(address: RpcAddress): TransportClient = {
+    clientFactory.createClient(address.host, address.port)
+  }
+
   private[netty] def ask(message: RequestMessage): Future[Any] = {
     val promise = Promise[Any]()
     val remoteAddr = message.receiver.address
@@ -180,39 +214,25 @@ private[netty] class NettyRpcEnv(
         }
       }(ThreadUtils.sameThread)
     } else {
-      try {
-        // `createClient` will block if it cannot find a known connection, so we should run it in
-        // clientConnectionExecutor
-        clientConnectionExecutor.execute(new Runnable {
-          override def run(): Unit = {
-            val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
-            client.sendRpc(serialize(message), new RpcResponseCallback {
-
-              override def onFailure(e: Throwable): Unit = {
-                if (!promise.tryFailure(e)) {
-                  logWarning("Ignore Exception", e)
-                }
-              }
-
-              override def onSuccess(response: Array[Byte]): Unit = {
-                val reply = deserialize[AskResponse](response)
-                if (reply.reply.isInstanceOf[RpcFailure]) {
-                  if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
-                    logWarning(s"Ignore failure: ${reply.reply}")
-                  }
-                } else if (!promise.trySuccess(reply.reply)) {
-                  logWarning(s"Ignore message: ${reply}")
-                }
-              }
-            })
-          }
-        })
-      } catch {
-        case e: RejectedExecutionException =>
+      postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
+
+        override def onFailure(e: Throwable): Unit = {
           if (!promise.tryFailure(e)) {
-            logWarning(s"Ignore failure", e)
+            logWarning("Ignore Exception", e)
           }
-      }
+        }
+
+        override def onSuccess(response: Array[Byte]): Unit = {
+          val reply = deserialize[AskResponse](response)
+          if (reply.reply.isInstanceOf[RpcFailure]) {
+            if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
+              logWarning(s"Ignore failure: ${reply.reply}")
+            }
+          } else if (!promise.trySuccess(reply.reply)) {
+            logWarning(s"Ignore message: ${reply}")
+          }
+        }
+      }))
     }
     promise.future
   }
@@ -245,6 +265,16 @@ private[netty] class NettyRpcEnv(
   }
 
   private def cleanup(): Unit = {
+    if (!stopped.compareAndSet(false, true)) {
+      return
+    }
+
+    val iter = outboxes.values().iterator()
+    while (iter.hasNext()) {
+      val outbox = iter.next()
+      outboxes.remove(outbox.address)
+      outbox.stop()
+    }
     if (timeoutScheduler != null) {
       timeoutScheduler.shutdownNow()
     }
@@ -463,6 +493,7 @@ private[netty] class NettyRpcHandler(
     val addr =
       client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
     if (addr != null) {
       val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
+      nettyEnv.removeOutbox(clientAddr)
       val messageOpt: Option[RemoteProcessDisconnected] = synchronized {
         remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
new file mode 100644
index 0000000000000000000000000000000000000000..7d9d593b362412e7d075c7abc8ebd19c36624aff
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.Callable
+import javax.annotation.concurrent.GuardedBy
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkException
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.rpc.RpcAddress
+
+private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
+
+private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
+
+  outbox => // Give this an alias so we can use it more clearly in closures.
+
+  @GuardedBy("this")
+  private val messages = new java.util.LinkedList[OutboxMessage]
+
+  @GuardedBy("this")
+  private var client: TransportClient = null
+
+  /**
+   * connectFuture points to the connect task. If there is no connect task, connectFuture will be
+   * null.
+   */
+  @GuardedBy("this")
+  private var connectFuture: java.util.concurrent.Future[Unit] = null
+
+  @GuardedBy("this")
+  private var stopped = false
+
+  /**
+   * Whether there is any thread draining the message queue.
+   */
+  @GuardedBy("this")
+  private var draining = false
+
+  /**
+   * Send a message. If there is no active connection, cache the message and launch a new
+   * connection. If [[Outbox]] is stopped, the sender will be notified with a [[SparkException]].
+   */
+  def send(message: OutboxMessage): Unit = {
+    val dropped = synchronized {
+      if (stopped) {
+        true
+      } else {
+        messages.add(message)
+        false
+      }
+    }
+    if (dropped) {
+      message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+    } else {
+      drainOutbox()
+    }
+  }
+
+  /**
+   * Drain the message queue. If another thread is already draining it, just exit. If the
+   * connection has not been established, launch a task in the `nettyEnv.clientConnectionExecutor`
+   * to set up the connection.
+   */
+  private def drainOutbox(): Unit = {
+    var message: OutboxMessage = null
+    synchronized {
+      if (stopped) {
+        return
+      }
+      if (connectFuture != null) {
+        // We are connecting to the remote address, so just exit
+        return
+      }
+      if (client == null) {
+        // There is no connect task and the client is null, so we need to launch a connect task.
+        launchConnectTask()
+        return
+      }
+      if (draining) {
+        // Some other thread is draining, so just exit
+        return
+      }
+      message = messages.poll()
+      if (message == null) {
+        return
+      }
+      draining = true
+    }
+    while (true) {
+      try {
+        val _client = synchronized { client }
+        if (_client != null) {
+          _client.sendRpc(message.content, message.callback)
+        } else {
+          assert(stopped)
+        }
+      } catch {
+        case NonFatal(e) =>
+          handleNetworkFailure(e)
+          return
+      }
+      synchronized {
+        if (stopped) {
+          return
+        }
+        message = messages.poll()
+        if (message == null) {
+          draining = false
+          return
+        }
+      }
+    }
+  }
+
+  private def launchConnectTask(): Unit = {
+    connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {
+
+      override def call(): Unit = {
+        try {
+          val _client = nettyEnv.createClient(address)
+          outbox.synchronized {
+            client = _client
+            if (stopped) {
+              closeClient()
+            }
+          }
+        } catch {
+          case ie: InterruptedException =>
+            // exit
+            return
+          case NonFatal(e) =>
+            outbox.synchronized { connectFuture = null }
+            handleNetworkFailure(e)
+            return
+        }
+        outbox.synchronized { connectFuture = null }
+        // It's possible that no thread is draining now. If we don't drain here, we cannot send the
+        // pending messages until the next message arrives.
+        drainOutbox()
+      }
+    })
+  }
+
+  /**
+   * Stop [[Outbox]] and notify the pending messages with the failure cause.
+   */
+  private def handleNetworkFailure(e: Throwable): Unit = {
+    synchronized {
+      assert(connectFuture == null)
+      if (stopped) {
+        return
+      }
+      stopped = true
+      closeClient()
+    }
+    // Remove this Outbox from nettyEnv so that further messages will create a new Outbox along
+    // with a new connection
+    nettyEnv.removeOutbox(address)
+
+    // Notify the remaining messages of the connection failure
+    //
+    // We always check `stopped` before updating messages, so here we can make sure no thread will
+    // update messages and it's safe to just drain the queue.
+    var message = messages.poll()
+    while (message != null) {
+      message.callback.onFailure(e)
+      message = messages.poll()
+    }
+    assert(messages.isEmpty)
+  }
+
+  private def closeClient(): Unit = synchronized {
+    // Not sure if `client.close` is idempotent. Just for safety.
+    if (client != null) {
+      client.close()
+    }
+    client = null
+  }
+
+  /**
+   * Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a
+   * [[SparkException]].
+   */
+  def stop(): Unit = {
+    synchronized {
+      if (stopped) {
+        return
+      }
+      stopped = true
+      if (connectFuture != null) {
+        connectFuture.cancel(true)
+      }
+      closeClient()
+    }
+
+    // We always check `stopped` before updating messages, so here we can make sure no thread will
+    // update messages and it's safe to just drain the queue.
+    var message = messages.poll()
+    while (message != null) {
+      message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+      message = messages.poll()
+    }
+  }
+}
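
Reviewer note: the heart of this patch is the enqueue-and-drain idiom in `Outbox` — `send` only appends to a per-address queue under the lock, and whichever thread wins the `draining` flag delivers the queued messages outside the lock, one at a time. Because every message to a given peer funnels through that single FIFO queue, messages reach `sendRpc` in send order, which the replaced code path (one `clientConnectionExecutor` task per message) did not obviously guarantee. The sketch below is a minimal, self-contained illustration of just this idiom; `MiniOutbox`, `transmit`, and `MiniOutboxDemo` are hypothetical names for illustration, and the connection management, stop semantics, and failure notification of the real `Outbox` are deliberately omitted.

```scala
// Minimal sketch of the enqueue-and-drain idiom, assuming a String payload and a
// synchronous `transmit` function standing in for TransportClient.sendRpc.
class MiniOutbox(transmit: String => Unit) {

  // Guarded by `this`, like Outbox.messages.
  private val messages = new java.util.LinkedList[String]

  // Guarded by `this`, like Outbox.draining: true while one thread owns the queue.
  private var draining = false

  /** Non-blocking send: enqueue under the lock, then try to drain. */
  def send(msg: String): Unit = {
    synchronized { messages.add(msg) }
    drainOutbox()
  }

  private def drainOutbox(): Unit = {
    var msg: String = null
    synchronized {
      if (draining) return     // another thread is already draining; just exit
      msg = messages.poll()
      if (msg == null) return  // nothing to send
      draining = true          // this thread now owns the queue
    }
    while (true) {
      transmit(msg)            // delivered outside the lock, like client.sendRpc
      synchronized {
        msg = messages.poll()
        if (msg == null) {
          draining = false     // release ownership before exiting
          return
        }
      }
    }
  }
}

object MiniOutboxDemo extends App {
  val outbox = new MiniOutbox(m => println(s"sent $m"))
  (1 to 3).foreach(i => outbox.send(s"message-$i"))  // prints in FIFO order
}
```

The real `drainOutbox` additionally bails out while a connect task is in flight (`connectFuture != null`) and launches one via `launchConnectTask` when there is no client yet; once the connection is established, the connect task calls `drainOutbox` itself, since no other thread may be draining at that point.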