Commit a88c66ca authored by zsxwing, committed by Reynold Xin

[SPARK-11098][CORE] Add Outbox to cache the sending messages to resolve the message disorder issue

The current NettyRpc implementation has a message-ordering issue because it uses a thread pool to send messages. For example, run the following two lines in the same thread:

```
ref.send("A")
ref.send("B")
```

The remote endpoint may see "B" before "A", because "A" and "B" are sent in parallel.
To resolve this issue, this PR adds an outbox for each connection. If we are still connecting to the remote node when a message is sent, the message is cached in the outbox; once the connection is established, the cached messages are sent one by one.
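As a rough illustration of the approach (a minimal, self-contained sketch, not the code in this PR; `SimpleOutbox`, `sendNow`, and `markConnected` are made-up stand-ins for the real `Outbox` and the Netty transport):

```
// Sketch of the outbox idea: queue messages until the connection is up, then
// let a single "drainer" flush them in FIFO order, so the remote side always
// sees "A" before "B".
object OutboxSketch {
  class SimpleOutbox(sendNow: String => Unit) {
    private val queue = new java.util.LinkedList[String]
    private var connected = false
    private var draining = false

    def send(msg: String): Unit = { synchronized { queue.add(msg) }; drain() }

    def markConnected(): Unit = { synchronized { connected = true }; drain() }

    private def drain(): Unit = {
      synchronized {
        // Exit if we are not connected yet, another thread is already
        // draining, or there is nothing to send.
        if (!connected || draining || queue.isEmpty) return
        draining = true
      }
      while (true) {
        val msg = synchronized {
          val m = queue.poll()
          if (m == null) { draining = false; return }
          m
        }
        sendNow(msg) // send outside the lock, as the real Outbox does
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val outbox = new SimpleOutbox(msg => println(s"sent $msg"))
    outbox.send("A")
    outbox.send("B")       // both queued: not connected yet
    outbox.markConnected() // prints "sent A" then "sent B", in order
  }
}
```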

Author: zsxwing <zsxwing@gmail.com>

Closes #9197 from zsxwing/rpc-outbox.
parent 34e71c6d
@@ -20,6 +20,7 @@ import java.io._
import java.net.{InetSocketAddress, URI}
import java.nio.ByteBuffer
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
@@ -70,12 +71,30 @@ private[netty] class NettyRpcEnv(
// Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool
// to implement non-blocking send/ask.
// TODO: a non-blocking TransportClientFactory.createClient in future
private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
private[netty] val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool(
"netty-rpc-connection",
conf.getInt("spark.rpc.connect.threads", 64))
@volatile private var server: TransportServer = _
private val stopped = new AtomicBoolean(false)
/**
* A map from [[RpcAddress]] to [[Outbox]]. While we are connecting to a remote [[RpcAddress]],
* we just put messages into its [[Outbox]] to implement a non-blocking `send` method.
*/
private val outboxes = new ConcurrentHashMap[RpcAddress, Outbox]()
/**
* Remove the address's Outbox and stop it.
*/
private[netty] def removeOutbox(address: RpcAddress): Unit = {
val outbox = outboxes.remove(address)
if (outbox != null) {
outbox.stop()
}
}
def start(port: Int): Unit = {
val bootstraps: java.util.List[TransportServerBootstrap] =
if (securityManager.isAuthenticationEnabled()) {
@@ -116,6 +135,30 @@ private[netty] class NettyRpcEnv(
dispatcher.stop(endpointRef)
}
private def postToOutbox(address: RpcAddress, message: OutboxMessage): Unit = {
val targetOutbox = {
val outbox = outboxes.get(address)
if (outbox == null) {
val newOutbox = new Outbox(this, address)
val oldOutbox = outboxes.putIfAbsent(address, newOutbox)
if (oldOutbox == null) {
newOutbox
} else {
oldOutbox
}
} else {
outbox
}
}
if (stopped.get) {
// It's possible that `targetOutbox` was put into `outboxes` after the env was stopped, so we
// need to clean it up.
outboxes.remove(address)
targetOutbox.stop()
} else {
targetOutbox.send(message)
}
}
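A side note on the get-or-create idiom above: on Java 8+ the same lookup could be written with `ConcurrentHashMap.computeIfAbsent`, which also guarantees the factory runs at most once per key (a sketch only, not what this diff does; Spark still supported Java 7 at the time, and pre-2.12 Scala needs an explicit anonymous class for the Java functional interface):

```
// Hypothetical Java-8 variant of the lookup in postToOutbox; `outboxes`,
// `address`, and `Outbox` are the names used in this diff.
val targetOutbox = outboxes.computeIfAbsent(address,
  new java.util.function.Function[RpcAddress, Outbox] {
    override def apply(a: RpcAddress): Outbox = new Outbox(NettyRpcEnv.this, a)
  })
```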
private[netty] def send(message: RequestMessage): Unit = {
val remoteAddr = message.receiver.address
if (remoteAddr == address) {
@@ -127,37 +170,28 @@ private[netty] class NettyRpcEnv(
val ack = response.asInstanceOf[Ack]
logTrace(s"Received ack from ${ack.sender}")
case Failure(e) =>
logError(s"Exception when sending $message", e)
logWarning(s"Exception when sending $message", e)
}(ThreadUtils.sameThread)
} else {
// Message to a remote RPC endpoint.
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
clientConnectionExecutor.execute(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
client.sendRpc(serialize(message), new RpcResponseCallback {
override def onFailure(e: Throwable): Unit = {
logError(s"Exception when sending $message", e)
}
override def onSuccess(response: Array[Byte]): Unit = {
val ack = deserialize[Ack](response)
logDebug(s"Receive ack from ${ack.sender}")
}
})
}
})
} catch {
case e: RejectedExecutionException =>
// `send` after shutting clientConnectionExecutor down, ignore it
logWarning(s"Cannot send $message because RpcEnv is stopped")
}
postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
override def onFailure(e: Throwable): Unit = {
logWarning(s"Exception when sending $message", e)
}
override def onSuccess(response: Array[Byte]): Unit = {
val ack = deserialize[Ack](response)
logDebug(s"Receive ack from ${ack.sender}")
}
}))
}
}
private[netty] def createClient(address: RpcAddress): TransportClient = {
clientFactory.createClient(address.host, address.port)
}
private[netty] def ask(message: RequestMessage): Future[Any] = {
val promise = Promise[Any]()
val remoteAddr = message.receiver.address
@@ -180,39 +214,25 @@ private[netty] class NettyRpcEnv(
}
}(ThreadUtils.sameThread)
} else {
try {
// `createClient` will block if it cannot find a known connection, so we should run it in
// clientConnectionExecutor
clientConnectionExecutor.execute(new Runnable {
override def run(): Unit = {
val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
client.sendRpc(serialize(message), new RpcResponseCallback {
override def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
logWarning("Ignore Exception", e)
}
}
override def onSuccess(response: Array[Byte]): Unit = {
val reply = deserialize[AskResponse](response)
if (reply.reply.isInstanceOf[RpcFailure]) {
if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
logWarning(s"Ignore failure: ${reply.reply}")
}
} else if (!promise.trySuccess(reply.reply)) {
logWarning(s"Ignore message: ${reply}")
}
}
})
}
})
} catch {
case e: RejectedExecutionException =>
postToOutbox(remoteAddr, OutboxMessage(serialize(message), new RpcResponseCallback {
override def onFailure(e: Throwable): Unit = {
if (!promise.tryFailure(e)) {
logWarning(s"Ignore failure", e)
logWarning("Ignore Exception", e)
}
}
}
override def onSuccess(response: Array[Byte]): Unit = {
val reply = deserialize[AskResponse](response)
if (reply.reply.isInstanceOf[RpcFailure]) {
if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
logWarning(s"Ignore failure: ${reply.reply}")
}
} else if (!promise.trySuccess(reply.reply)) {
logWarning(s"Ignore message: ${reply}")
}
}
}))
}
promise.future
}
@@ -245,6 +265,16 @@ private[netty] class NettyRpcEnv(
}
private def cleanup(): Unit = {
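// `stopped.compareAndSet` makes cleanup idempotent: only the first caller proceeds.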
if (!stopped.compareAndSet(false, true)) {
return
}
val iter = outboxes.values().iterator()
while (iter.hasNext()) {
val outbox = iter.next()
outboxes.remove(outbox.address)
outbox.stop()
}
if (timeoutScheduler != null) {
timeoutScheduler.shutdownNow()
}
@@ -463,6 +493,7 @@ private[netty] class NettyRpcHandler(
val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress]
if (addr != null) {
val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
nettyEnv.removeOutbox(clientAddr)
val messageOpt: Option[RemoteProcessDisconnected] =
synchronized {
remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress =>
core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala (new file)
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rpc.netty
import java.util.concurrent.Callable
import javax.annotation.concurrent.GuardedBy
import scala.util.control.NonFatal
import org.apache.spark.SparkException
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.rpc.RpcAddress
private[netty] case class OutboxMessage(content: Array[Byte], callback: RpcResponseCallback)
private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
outbox => // Give this an alias so we can use it more clearly in closures.
@GuardedBy("this")
private val messages = new java.util.LinkedList[OutboxMessage]
@GuardedBy("this")
private var client: TransportClient = null
/**
* connectFuture points to the connect task. If there is no connect task, connectFuture will be
* null.
*/
@GuardedBy("this")
private var connectFuture: java.util.concurrent.Future[Unit] = null
@GuardedBy("this")
private var stopped = false
/**
* Whether any thread is currently draining the message queue.
*/
@GuardedBy("this")
private var draining = false
/**
* Send a message. If there is no active connection, cache the message and launch a new
* connection. If the [[Outbox]] is stopped, the sender will be notified with a [[SparkException]].
*/
def send(message: OutboxMessage): Unit = {
val dropped = synchronized {
if (stopped) {
true
} else {
messages.add(message)
false
}
}
if (dropped) {
message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
} else {
drainOutbox()
}
}
/**
* Drain the message queue. If another thread is already draining, just exit. If the connection
* has not been established, launch a task in the `nettyEnv.clientConnectionExecutor` to set up
* the connection.
*/
private def drainOutbox(): Unit = {
var message: OutboxMessage = null
synchronized {
if (stopped) {
return
}
if (connectFuture != null) {
// We are connecting to the remote address, so just exit
return
}
if (client == null) {
// There is no connect task and the client is null, so we need to launch the connect task.
launchConnectTask()
return
}
if (draining) {
// Another thread is already draining, so just exit
return
}
message = messages.poll()
if (message == null) {
return
}
draining = true
}
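// Send outside of `synchronized`: only reading `client` requires the lock, so concurrent
// `send` calls can keep enqueueing messages while an RPC is in flight.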
while (true) {
try {
val _client = synchronized { client }
if (_client != null) {
_client.sendRpc(message.content, message.callback)
} else {
assert(stopped)
}
} catch {
case NonFatal(e) =>
handleNetworkFailure(e)
return
}
synchronized {
if (stopped) {
return
}
message = messages.poll()
if (message == null) {
draining = false
return
}
}
}
}
private def launchConnectTask(): Unit = {
connectFuture = nettyEnv.clientConnectionExecutor.submit(new Callable[Unit] {
override def call(): Unit = {
try {
val _client = nettyEnv.createClient(address)
outbox.synchronized {
client = _client
if (stopped) {
closeClient()
}
}
} catch {
case ie: InterruptedException =>
// exit
return
case NonFatal(e) =>
outbox.synchronized { connectFuture = null }
handleNetworkFailure(e)
return
}
outbox.synchronized { connectFuture = null }
// It's possible that no thread is draining now. If we don't drain here, the queued
// messages won't be sent until the next message arrives.
drainOutbox()
}
})
}
/**
* Stop this [[Outbox]] due to a network failure, and fail the pending messages with the cause.
*/
private def handleNetworkFailure(e: Throwable): Unit = {
synchronized {
assert(connectFuture == null)
if (stopped) {
return
}
stopped = true
closeClient()
}
// Remove this Outbox from nettyEnv so that further messages will create a new Outbox along
// with a new connection
nettyEnv.removeOutbox(address)
// Notify the remaining messages of the connection failure.
//
// Every writer checks `stopped` before updating `messages`, so once `stopped` is set we can
// be sure no other thread will update `messages` and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
message.callback.onFailure(e)
message = messages.poll()
}
assert(messages.isEmpty)
}
private def closeClient(): Unit = synchronized {
// Not sure if `client.close` is idempotent. Just for safety.
if (client != null) {
client.close()
}
client = null
}
/**
* Stop [[Outbox]]. The remaining messages in the [[Outbox]] will be notified with a
* [[SparkException]].
*/
def stop(): Unit = {
synchronized {
if (stopped) {
return
}
stopped = true
if (connectFuture != null) {
connectFuture.cancel(true)
}
closeClient()
}
// Every writer checks `stopped` before updating `messages`, so once `stopped` is set we can
// be sure no other thread will update `messages` and it's safe to just drain the queue.
var message = messages.poll()
while (message != null) {
message.callback.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
message = messages.poll()
}
}
}