From 4e81783e92f464d479baaf93eccc3adb1496989a Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin <vanzin@cloudera.com>
Date: Wed, 25 Nov 2015 12:58:18 -0800
Subject: [PATCH] [SPARK-11866][NETWORK][CORE] Make sure timed out RPCs are
 cleaned up.

This change does a couple of different things to make sure that the RpcEnv-level
code and the network library agree about the status of outstanding RPCs.

For RPCs that do not expect a reply ("RpcEnv.send"), support for one-way
messages (hello CORBA!) was added to the network layer. This is a
"fire-and-forget" message that requires no state to be kept by the
TransportClient; as a result, the RpcEnv-level 'Ack' message is no
longer needed.
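
A rough sketch of the new client-side path (condensed from the patch,
ignoring private[netty] visibility; `fireAndForget` is an illustrative
helper, not part of the change):

    import org.apache.spark.network.client.TransportClient
    import org.apache.spark.rpc.netty.{NettyRpcEnv, RequestMessage}

    // Serialize and fire; unlike sendRpc(), the client keeps no
    // outstanding-request state and expects no response frame.
    def fireAndForget(
        nettyEnv: NettyRpcEnv,
        client: TransportClient,
        message: RequestMessage): Unit = {
      client.send(nettyEnv.serialize(message))
    }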

For RPCs that do expect a reply ("RpcEnv.ask"), the network library now
returns the internal RPC id; if the RpcEnv layer times out the RPC
before the network layer does, it asks the TransportClient to forget
about that RPC, so that a later network-level timeout does not kill
the client.
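
Condensed, the ask-side interaction looks roughly like this (in the
patch it is split between NettyRpcEnv.ask and RpcOutboxMessage; the
helper below is illustrative only):

    import java.util.concurrent.TimeoutException
    import scala.concurrent.Promise
    import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
    import org.apache.spark.util.ThreadUtils

    def askWithCleanup(
        client: TransportClient,
        payload: Array[Byte],
        callback: RpcResponseCallback,
        promise: Promise[Any]): Unit = {
      // sendRpc() now returns the request id the client tracks internally.
      val requestId = client.sendRpc(payload, callback)
      // If the RpcEnv-level timeout wins the race, remove the client-side
      // state for this RPC so the network-level timeout finds nothing
      // outstanding and leaves the connection alone.
      promise.future.onFailure {
        case _: TimeoutException => client.removeRpcRequest(requestId)
      }(ThreadUtils.sameThread)
    }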

As part of implementing the above, I cleaned up some of the code in the
Netty RPC backend, removing types that were no longer necessary and
factoring out some common code. Of note is a slight change in the
exceptions thrown when posting messages to a stopped RpcEnv; that is
mostly to avoid nasty error messages from the local-cluster backend
during shutdown, which pollute the terminal output.
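
For example, ExecutorRunner (first hunk below) now treats the new
IllegalStateException as a benign sign of shutdown instead of letting
it propagate:

    try {
      worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
    } catch {
      // A stopped RpcEnv now surfaces as IllegalStateException, which is
      // expected while the local-cluster backend is shutting down.
      case e: IllegalStateException => logWarning(e.getMessage(), e)
    }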

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #9917 from vanzin/SPARK-11866.
---
 .../spark/deploy/worker/ExecutorRunner.scala  |   6 +-
 .../apache/spark/rpc/netty/Dispatcher.scala   |  55 +++----
 .../org/apache/spark/rpc/netty/Inbox.scala    |  28 ++--
 .../spark/rpc/netty/NettyRpcCallContext.scala |  35 +---
 .../apache/spark/rpc/netty/NettyRpcEnv.scala  | 153 +++++++-----------
 .../org/apache/spark/rpc/netty/Outbox.scala   |  64 ++++++--
 .../apache/spark/rpc/netty/InboxSuite.scala   |   6 +-
 .../rpc/netty/NettyRpcHandlerSuite.scala      |   2 +-
 .../spark/network/client/TransportClient.java |  34 +++-
 .../spark/network/protocol/Message.java       |   4 +-
 .../network/protocol/MessageDecoder.java      |   3 +
 .../spark/network/protocol/OneWayMessage.java |  75 +++++++++
 .../spark/network/sasl/SaslRpcHandler.java    |   5 +
 .../spark/network/server/RpcHandler.java      |  36 +++++
 .../server/TransportRequestHandler.java       |  18 ++-
 .../apache/spark/network/ProtocolSuite.java   |   2 +
 .../spark/network/RpcIntegrationSuite.java    |  31 ++++
 .../spark/network/sasl/SparkSaslSuite.java    |   9 ++
 18 files changed, 374 insertions(+), 192 deletions(-)
 create mode 100644 network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java

diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 3aef0515cb..25a17473e4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -92,7 +92,11 @@ private[deploy] class ExecutorRunner(
       process.destroy()
       exitCode = Some(process.waitFor())
     }
-    worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+    try {
+      worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode))
+    } catch {
+      case e: IllegalStateException => logWarning(e.getMessage(), e)
+    }
   }
 
   /** Stop this executor runner, including killing the process it launched */
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index eb25d6c7b7..533c984766 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -106,44 +106,30 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
     val iter = endpoints.keySet().iterator()
     while (iter.hasNext) {
       val name = iter.next
-      postMessage(
-        name,
-        _ => message,
-        () => { logWarning(s"Drop $message because $name has been stopped") })
+      postMessage(name, message, (e) => logWarning(s"Message $message dropped.", e))
     }
   }
 
   /** Posts a message sent by a remote endpoint. */
   def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = {
-    def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
-      val rpcCallContext =
-        new RemoteNettyRpcCallContext(
-          nettyEnv, sender, callback, message.senderAddress, message.needReply)
-      ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
-    }
-
-    def onEndpointStopped(): Unit = {
-      callback.onFailure(
-        new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
-    }
-
-    postMessage(message.receiver.name, createMessage, onEndpointStopped)
+    val rpcCallContext =
+      new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress)
+    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+    postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e))
   }
 
   /** Posts a message sent by a local endpoint. */
   def postLocalMessage(message: RequestMessage, p: Promise[Any]): Unit = {
-    def createMessage(sender: NettyRpcEndpointRef): InboxMessage = {
-      val rpcCallContext =
-        new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p)
-      ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)
-    }
-
-    def onEndpointStopped(): Unit = {
-      p.tryFailure(
-        new SparkException(s"Could not find ${message.receiver.name} or it has been stopped"))
-    }
+    val rpcCallContext =
+      new LocalNettyRpcCallContext(message.senderAddress, p)
+    val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext)
+    postMessage(message.receiver.name, rpcMessage, (e) => p.tryFailure(e))
+  }
 
-    postMessage(message.receiver.name, createMessage, onEndpointStopped)
+  /** Posts a one-way message. */
+  def postOneWayMessage(message: RequestMessage): Unit = {
+    postMessage(message.receiver.name, OneWayMessage(message.senderAddress, message.content),
+      (e) => throw e)
   }
 
   /**
@@ -155,21 +141,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
    */
   private def postMessage(
       endpointName: String,
-      createMessageFn: NettyRpcEndpointRef => InboxMessage,
-      callbackIfStopped: () => Unit): Unit = {
+      message: InboxMessage,
+      callbackIfStopped: (Exception) => Unit): Unit = {
     val shouldCallOnStop = synchronized {
       val data = endpoints.get(endpointName)
       if (stopped || data == null) {
         true
       } else {
-        data.inbox.post(createMessageFn(data.ref))
+        data.inbox.post(message)
         receivers.offer(data)
         false
       }
     }
     if (shouldCallOnStop) {
       // We don't need to call `onStop` in the `synchronized` block
-      callbackIfStopped()
+      val error = if (stopped) {
+          new IllegalStateException("RpcEnv already stopped.")
+        } else {
+          new SparkException(s"Could not find $endpointName or it has been stopped.")
+        }
+      callbackIfStopped(error)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index 464027f07c..175463cc10 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -27,10 +27,13 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint}
 
 private[netty] sealed trait InboxMessage
 
-private[netty] case class ContentMessage(
+private[netty] case class OneWayMessage(
+    senderAddress: RpcAddress,
+    content: Any) extends InboxMessage
+
+private[netty] case class RpcMessage(
     senderAddress: RpcAddress,
     content: Any,
-    needReply: Boolean,
     context: NettyRpcCallContext) extends InboxMessage
 
 private[netty] case object OnStart extends InboxMessage
@@ -96,29 +99,24 @@ private[netty] class Inbox(
     while (true) {
       safelyCall(endpoint) {
         message match {
-          case ContentMessage(_sender, content, needReply, context) =>
-            // The partial function to call
-            val pf = if (needReply) endpoint.receiveAndReply(context) else endpoint.receive
+          case RpcMessage(_sender, content, context) =>
             try {
-              pf.applyOrElse[Any, Unit](content, { msg =>
+              endpoint.receiveAndReply(context).applyOrElse[Any, Unit](content, { msg =>
                 throw new SparkException(s"Unsupported message $message from ${_sender}")
               })
-              if (!needReply) {
-                context.finish()
-              }
             } catch {
               case NonFatal(e) =>
-                if (needReply) {
-                  // If the sender asks a reply, we should send the error back to the sender
-                  context.sendFailure(e)
-                } else {
-                  context.finish()
-                }
+                context.sendFailure(e)
                 // Throw the exception -- this exception will be caught by the safelyCall function.
                 // The endpoint's onError function will be called.
                 throw e
             }
 
+          case OneWayMessage(_sender, content) =>
+            endpoint.receive.applyOrElse[Any, Unit](content, { msg =>
+              throw new SparkException(s"Unsupported message $message from ${_sender}")
+            })
+
           case OnStart =>
             endpoint.onStart()
             if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
index 21d5bb4923..6637e2321f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala
@@ -23,49 +23,28 @@ import org.apache.spark.Logging
 import org.apache.spark.network.client.RpcResponseCallback
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext}
 
-private[netty] abstract class NettyRpcCallContext(
-    endpointRef: NettyRpcEndpointRef,
-    override val senderAddress: RpcAddress,
-    needReply: Boolean)
+private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress)
   extends RpcCallContext with Logging {
 
   protected def send(message: Any): Unit
 
   override def reply(response: Any): Unit = {
-    if (needReply) {
-      send(AskResponse(endpointRef, response))
-    } else {
-      throw new IllegalStateException(
-        s"Cannot send $response to the sender because the sender does not expect a reply")
-    }
+    send(response)
   }
 
   override def sendFailure(e: Throwable): Unit = {
-    if (needReply) {
-      send(AskResponse(endpointRef, RpcFailure(e)))
-    } else {
-      logError(e.getMessage, e)
-      throw new IllegalStateException(
-        "Cannot send reply to the sender because the sender won't handle it")
-    }
+    send(RpcFailure(e))
   }
 
-  def finish(): Unit = {
-    if (!needReply) {
-      send(Ack(endpointRef))
-    }
-  }
 }
 
 /**
  * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`.
  */
 private[netty] class LocalNettyRpcCallContext(
-    endpointRef: NettyRpcEndpointRef,
     senderAddress: RpcAddress,
-    needReply: Boolean,
     p: Promise[Any])
-  extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+  extends NettyRpcCallContext(senderAddress) {
 
   override protected def send(message: Any): Unit = {
     p.success(message)
@@ -77,11 +56,9 @@ private[netty] class LocalNettyRpcCallContext(
  */
 private[netty] class RemoteNettyRpcCallContext(
     nettyEnv: NettyRpcEnv,
-    endpointRef: NettyRpcEndpointRef,
     callback: RpcResponseCallback,
-    senderAddress: RpcAddress,
-    needReply: Boolean)
-  extends NettyRpcCallContext(endpointRef, senderAddress, needReply) {
+    senderAddress: RpcAddress)
+  extends NettyRpcCallContext(senderAddress) {
 
   override protected def send(message: Any): Unit = {
     val reply = nettyEnv.serialize(message)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index c8fa870f50..c7d74fa1d9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -150,7 +150,7 @@ private[netty] class NettyRpcEnv(
 
   private def postToOutbox(receiver: NettyRpcEndpointRef, message: OutboxMessage): Unit = {
     if (receiver.client != null) {
-      receiver.client.sendRpc(message.content, message.createCallback(receiver.client));
+      message.sendWith(receiver.client)
     } else {
       require(receiver.address != null,
         "Cannot send message to client endpoint with no listen address.")
@@ -182,25 +182,10 @@ private[netty] class NettyRpcEnv(
     val remoteAddr = message.receiver.address
     if (remoteAddr == address) {
       // Message to a local RPC endpoint.
-      val promise = Promise[Any]()
-      dispatcher.postLocalMessage(message, promise)
-      promise.future.onComplete {
-        case Success(response) =>
-          val ack = response.asInstanceOf[Ack]
-          logTrace(s"Received ack from ${ack.sender}")
-        case Failure(e) =>
-          logWarning(s"Exception when sending $message", e)
-      }(ThreadUtils.sameThread)
+      dispatcher.postOneWayMessage(message)
     } else {
       // Message to a remote RPC endpoint.
-      postToOutbox(message.receiver, OutboxMessage(serialize(message),
-        (e) => {
-          logWarning(s"Exception when sending $message", e)
-        },
-        (client, response) => {
-          val ack = deserialize[Ack](client, response)
-          logDebug(s"Receive ack from ${ack.sender}")
-        }))
+      postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message)))
     }
   }
 
@@ -208,46 +193,52 @@ private[netty] class NettyRpcEnv(
     clientFactory.createClient(address.host, address.port)
   }
 
-  private[netty] def ask(message: RequestMessage): Future[Any] = {
+  private[netty] def ask[T: ClassTag](message: RequestMessage, timeout: RpcTimeout): Future[T] = {
     val promise = Promise[Any]()
     val remoteAddr = message.receiver.address
+
+    def onFailure(e: Throwable): Unit = {
+      if (!promise.tryFailure(e)) {
+        logWarning(s"Ignored failure: $e")
+      }
+    }
+
+    def onSuccess(reply: Any): Unit = reply match {
+      case RpcFailure(e) => onFailure(e)
+      case rpcReply =>
+        if (!promise.trySuccess(rpcReply)) {
+          logWarning(s"Ignored message: $reply")
+        }
+    }
+
     if (remoteAddr == address) {
       val p = Promise[Any]()
-      dispatcher.postLocalMessage(message, p)
       p.future.onComplete {
-        case Success(response) =>
-          val reply = response.asInstanceOf[AskResponse]
-          if (reply.reply.isInstanceOf[RpcFailure]) {
-            if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
-              logWarning(s"Ignore failure: ${reply.reply}")
-            }
-          } else if (!promise.trySuccess(reply.reply)) {
-            logWarning(s"Ignore message: ${reply}")
-          }
-        case Failure(e) =>
-          if (!promise.tryFailure(e)) {
-            logWarning("Ignore Exception", e)
-          }
+        case Success(response) => onSuccess(response)
+        case Failure(e) => onFailure(e)
       }(ThreadUtils.sameThread)
+      dispatcher.postLocalMessage(message, p)
     } else {
-      postToOutbox(message.receiver, OutboxMessage(serialize(message),
-        (e) => {
-          if (!promise.tryFailure(e)) {
-            logWarning("Ignore Exception", e)
-          }
-        },
-        (client, response) => {
-          val reply = deserialize[AskResponse](client, response)
-          if (reply.reply.isInstanceOf[RpcFailure]) {
-            if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) {
-              logWarning(s"Ignore failure: ${reply.reply}")
-            }
-          } else if (!promise.trySuccess(reply.reply)) {
-            logWarning(s"Ignore message: ${reply}")
-          }
-        }))
+      val rpcMessage = RpcOutboxMessage(serialize(message),
+        onFailure,
+        (client, response) => onSuccess(deserialize[Any](client, response)))
+      postToOutbox(message.receiver, rpcMessage)
+      promise.future.onFailure {
+        case _: TimeoutException => rpcMessage.onTimeout()
+        case _ =>
+      }(ThreadUtils.sameThread)
     }
-    promise.future
+
+    val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
+      override def run(): Unit = {
+        promise.tryFailure(
+          new TimeoutException("Cannot receive any reply in ${timeout.duration}"))
+      }
+    }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+    promise.future.onComplete { v =>
+      timeoutCancelable.cancel(true)
+    }(ThreadUtils.sameThread)
+    promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
   }
 
   private[netty] def serialize(content: Any): Array[Byte] = {
@@ -512,25 +503,12 @@ private[netty] class NettyRpcEndpointRef(
   override def name: String = _name
 
   override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
-    val promise = Promise[Any]()
-    val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
-      override def run(): Unit = {
-        promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout.duration))
-      }
-    }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
-    val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
-    f.onComplete { v =>
-      timeoutCancelable.cancel(true)
-      if (!promise.tryComplete(v)) {
-        logWarning(s"Ignore message $v")
-      }
-    }(ThreadUtils.sameThread)
-    promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
+    nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout)
   }
 
   override def send(message: Any): Unit = {
     require(message != null, "Message is null")
-    nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+    nettyEnv.send(RequestMessage(nettyEnv.address, this, message))
   }
 
   override def toString: String = s"NettyRpcEndpointRef(${_address})"
@@ -549,24 +527,7 @@ private[netty] class NettyRpcEndpointRef(
  * The message that is sent from the sender to the receiver.
  */
 private[netty] case class RequestMessage(
-    senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
-
-/**
- * The base trait for all messages that are sent back from the receiver to the sender.
- */
-private[netty] trait ResponseMessage
-
-/**
- * The reply for `ask` from the receiver side.
- */
-private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
-  extends ResponseMessage
-
-/**
- * A message to send back to the receiver side. It's necessary because [[TransportClient]] only
- * clean the resources when it receives a reply.
- */
-private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage
+    senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any)
 
 /**
  * A response that indicates some failure happens in the receiver side.
@@ -598,6 +559,18 @@ private[netty] class NettyRpcHandler(
       client: TransportClient,
       message: Array[Byte],
       callback: RpcResponseCallback): Unit = {
+    val messageToDispatch = internalReceive(client, message)
+    dispatcher.postRemoteMessage(messageToDispatch, callback)
+  }
+
+  override def receive(
+      client: TransportClient,
+      message: Array[Byte]): Unit = {
+    val messageToDispatch = internalReceive(client, message)
+    dispatcher.postOneWayMessage(messageToDispatch)
+  }
+
+  private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = {
     val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
@@ -605,14 +578,12 @@ private[netty] class NettyRpcHandler(
       dispatcher.postToAll(RemoteProcessConnected(clientAddr))
     }
     val requestMessage = nettyEnv.deserialize[RequestMessage](client, message)
-    val messageToDispatch = if (requestMessage.senderAddress == null) {
-        // Create a new message with the socket address of the client as the sender.
-        RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content,
-          requestMessage.needReply)
-      } else {
-        requestMessage
-      }
-    dispatcher.postRemoteMessage(messageToDispatch, callback)
+    if (requestMessage.senderAddress == null) {
+      // Create a new message with the socket address of the client as the sender.
+      RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content)
+    } else {
+      requestMessage
+    }
   }
 
   override def getStreamManager: StreamManager = streamManager
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
index 2f6817f2eb..36fdd00bbc 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala
@@ -22,22 +22,56 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.SparkException
+import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.rpc.RpcAddress
 
-private[netty] case class OutboxMessage(content: Array[Byte],
-  _onFailure: (Throwable) => Unit,
-  _onSuccess: (TransportClient, Array[Byte]) => Unit) {
+private[netty] sealed trait OutboxMessage {
 
-  def createCallback(client: TransportClient): RpcResponseCallback = new RpcResponseCallback() {
-    override def onFailure(e: Throwable): Unit = {
-      _onFailure(e)
-    }
+  def sendWith(client: TransportClient): Unit
 
-    override def onSuccess(response: Array[Byte]): Unit = {
-      _onSuccess(client, response)
-    }
+  def onFailure(e: Throwable): Unit
+
+}
+
+private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage
+  with Logging {
+
+  override def sendWith(client: TransportClient): Unit = {
+    client.send(content)
+  }
+
+  override def onFailure(e: Throwable): Unit = {
+    logWarning(s"Failed to send one-way RPC.", e)
+  }
+
+}
+
+private[netty] case class RpcOutboxMessage(
+    content: Array[Byte],
+    _onFailure: (Throwable) => Unit,
+    _onSuccess: (TransportClient, Array[Byte]) => Unit)
+  extends OutboxMessage with RpcResponseCallback {
+
+  private var client: TransportClient = _
+  private var requestId: Long = _
+
+  override def sendWith(client: TransportClient): Unit = {
+    this.client = client
+    this.requestId = client.sendRpc(content, this)
+  }
+
+  def onTimeout(): Unit = {
+    require(client != null, "TransportClient has not yet been set.")
+    client.removeRpcRequest(requestId)
+  }
+
+  override def onFailure(e: Throwable): Unit = {
+    _onFailure(e)
+  }
+
+  override def onSuccess(response: Array[Byte]): Unit = {
+    _onSuccess(client, response)
   }
 
 }
@@ -82,7 +116,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
       }
     }
     if (dropped) {
-      message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
     } else {
       drainOutbox()
     }
@@ -122,7 +156,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
       try {
         val _client = synchronized { client }
         if (_client != null) {
-          _client.sendRpc(message.content, message.createCallback(_client))
+          message.sendWith(_client)
         } else {
           assert(stopped == true)
         }
@@ -195,7 +229,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
     // update messages and it's safe to just drain the queue.
     var message = messages.poll()
     while (message != null) {
-      message._onFailure(e)
+      message.onFailure(e)
       message = messages.poll()
     }
     assert(messages.isEmpty)
@@ -229,7 +263,7 @@ private[netty] class Outbox(nettyEnv: NettyRpcEnv, val address: RpcAddress) {
     // update messages and it's safe to just drain the queue.
     var message = messages.poll()
     while (message != null) {
-      message._onFailure(new SparkException("Message is dropped because Outbox is stopped"))
+      message.onFailure(new SparkException("Message is dropped because Outbox is stopped"))
       message = messages.poll()
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
index 276c077b3d..2136795b18 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala
@@ -35,7 +35,7 @@ class InboxSuite extends SparkFunSuite {
     val dispatcher = mock(classOf[Dispatcher])
 
     val inbox = new Inbox(endpointRef, endpoint)
-    val message = ContentMessage(null, "hi", false, null)
+    val message = OneWayMessage(null, "hi")
     inbox.post(message)
     inbox.process(dispatcher)
     assert(inbox.isEmpty)
@@ -55,7 +55,7 @@ class InboxSuite extends SparkFunSuite {
     val dispatcher = mock(classOf[Dispatcher])
 
     val inbox = new Inbox(endpointRef, endpoint)
-    val message = ContentMessage(null, "hi", true, null)
+    val message = RpcMessage(null, "hi", null)
     inbox.post(message)
     inbox.process(dispatcher)
     assert(inbox.isEmpty)
@@ -83,7 +83,7 @@ class InboxSuite extends SparkFunSuite {
       new Thread {
         override def run(): Unit = {
           for (_ <- 0 until 100) {
-            val message = ContentMessage(null, "hi", false, null)
+            val message = OneWayMessage(null, "hi")
             inbox.post(message)
           }
           exitLatch.countDown()
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
index ccca795683..323184cdd9 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala
@@ -33,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {
   val env = mock(classOf[NettyRpcEnv])
   val sm = mock(classOf[StreamManager])
   when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
-    .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+    .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))
 
   test("receive") {
     val dispatcher = mock(classOf[Dispatcher])
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 876fcd8467..8a58e7b245 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -25,6 +25,7 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import javax.annotation.Nullable;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Objects;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
@@ -36,6 +37,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.OneWayMessage;
 import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.StreamChunkId;
 import org.apache.spark.network.protocol.StreamRequest;
@@ -205,8 +207,12 @@ public class TransportClient implements Closeable {
   /**
    * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
    * with the server's response or upon any failure.
+   *
+   * @param message The message to send.
+   * @param callback Callback to handle the RPC's reply.
+   * @return The RPC's id.
    */
-  public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+  public long sendRpc(byte[] message, final RpcResponseCallback callback) {
     final String serverAddr = NettyUtils.getRemoteAddress(channel);
     final long startTime = System.currentTimeMillis();
     logger.trace("Sending RPC to {}", serverAddr);
@@ -235,6 +241,8 @@ public class TransportClient implements Closeable {
           }
         }
       });
+
+    return requestId;
   }
 
   /**
@@ -265,11 +273,35 @@ public class TransportClient implements Closeable {
     }
   }
 
+  /**
+   * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the
+   * message, and no delivery guarantees are made.
+   *
+   * @param message The message to send.
+   */
+  public void send(byte[] message) {
+    channel.writeAndFlush(new OneWayMessage(message));
+  }
+
+  /**
+   * Removes any state associated with the given RPC.
+   *
+   * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}.
+   */
+  public void removeRpcRequest(long requestId) {
+    handler.removeRpcRequest(requestId);
+  }
+
   /** Mark this channel as having timed out. */
   public void timeOut() {
     this.timedOut = true;
   }
 
+  @VisibleForTesting
+  public TransportResponseHandler getHandler() {
+    return handler;
+  }
+
   @Override
   public void close() {
     // close is a local operation and should finish with milliseconds; timeout just to be safe
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index d01598c20f..39afd03db6 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -28,7 +28,8 @@ public interface Message extends Encodable {
   public static enum Type implements Encodable {
     ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
     RpcRequest(3), RpcResponse(4), RpcFailure(5),
-    StreamRequest(6), StreamResponse(7), StreamFailure(8);
+    StreamRequest(6), StreamResponse(7), StreamFailure(8),
+    OneWayMessage(9);
 
     private final byte id;
 
@@ -55,6 +56,7 @@ public interface Message extends Encodable {
         case 6: return StreamRequest;
         case 7: return StreamResponse;
         case 8: return StreamFailure;
+        case 9: return OneWayMessage;
         default: throw new IllegalArgumentException("Unknown message type: " + id);
       }
     }
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index 3c04048f38..074780f2b9 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -63,6 +63,9 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
       case RpcFailure:
         return RpcFailure.decode(in);
 
+      case OneWayMessage:
+        return OneWayMessage.decode(in);
+
       case StreamRequest:
         return StreamRequest.decode(in);
 
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
new file mode 100644
index 0000000000..95a0270be3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/OneWayMessage.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * An RPC that does not expect a reply, which is handled by a remote
+ * {@link org.apache.spark.network.server.RpcHandler}.
+ */
+public final class OneWayMessage implements RequestMessage {
+  /** Serialized message to send to remote RpcHandler. */
+  public final byte[] message;
+
+  public OneWayMessage(byte[] message) {
+    this.message = message;
+  }
+
+  @Override
+  public Type type() { return Type.OneWayMessage; }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.ByteArrays.encodedLength(message);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.ByteArrays.encode(buf, message);
+  }
+
+  public static OneWayMessage decode(ByteBuf buf) {
+    byte[] message = Encoders.ByteArrays.decode(buf);
+    return new OneWayMessage(message);
+  }
+
+  @Override
+  public int hashCode() {
+    return Arrays.hashCode(message);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof OneWayMessage) {
+      OneWayMessage o = (OneWayMessage) other;
+      return Arrays.equals(message, o.message);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("message", message)
+      .toString();
+  }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index 7033adb9ca..830db94b89 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -108,6 +108,11 @@ class SaslRpcHandler extends RpcHandler {
     }
   }
 
+  @Override
+  public void receive(TransportClient client, byte[] message) {
+    delegate.receive(client, message);
+  }
+
   @Override
   public StreamManager getStreamManager() {
     return delegate.getStreamManager();
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index dbb7f95f55..65109ddfe1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -17,6 +17,9 @@
 
 package org.apache.spark.network.server;
 
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 
@@ -24,6 +27,9 @@ import org.apache.spark.network.client.TransportClient;
  * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s.
  */
 public abstract class RpcHandler {
+
+  private static final RpcResponseCallback ONE_WAY_CALLBACK = new OneWayRpcCallback();
+
   /**
    * Receive a single RPC message. Any exception thrown while in this method will be sent back to
    * the client in string form as a standard RPC failure.
@@ -47,6 +53,19 @@ public abstract class RpcHandler {
    */
   public abstract StreamManager getStreamManager();
 
+  /**
+   * Receives an RPC message that does not expect a reply. The default implementation will
+   * call "{@link receive(TransportClient, byte[], RpcResponseCallback}" and log a warning if
+   * any of the callback methods are called.
+   *
+   * @param client A channel client which enables the handler to make requests back to the sender
+   *               of this RPC. This will always be the exact same object for a particular channel.
+   * @param message The serialized bytes of the RPC.
+   */
+  public void receive(TransportClient client, byte[] message) {
+    receive(client, message, ONE_WAY_CALLBACK);
+  }
+
   /**
    * Invoked when the connection associated with the given client has been invalidated.
    * No further requests will come from this client.
@@ -54,4 +73,21 @@ public abstract class RpcHandler {
   public void connectionTerminated(TransportClient client) { }
 
   public void exceptionCaught(Throwable cause, TransportClient client) { }
+
+  private static class OneWayRpcCallback implements RpcResponseCallback {
+
+    private final Logger logger = LoggerFactory.getLogger(OneWayRpcCallback.class);
+
+    @Override
+    public void onSuccess(byte[] response) {
+      logger.warn("Response provided for one-way RPC.");
+    }
+
+    @Override
+    public void onFailure(Throwable e) {
+      logger.error("Error response provided for one-way RPC.", e);
+    }
+
+  }
+
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 4f67bd573b..db18ea77d1 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.server;
 
+import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import io.netty.channel.Channel;
 import io.netty.channel.ChannelFuture;
@@ -27,13 +28,14 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.protocol.Encodable;
-import org.apache.spark.network.protocol.RequestMessage;
 import org.apache.spark.network.protocol.ChunkFetchRequest;
-import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.ChunkFetchFailure;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.OneWayMessage;
+import org.apache.spark.network.protocol.RequestMessage;
 import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.RpcResponse;
 import org.apache.spark.network.protocol.StreamFailure;
 import org.apache.spark.network.protocol.StreamRequest;
@@ -95,6 +97,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
       processFetchRequest((ChunkFetchRequest) request);
     } else if (request instanceof RpcRequest) {
       processRpcRequest((RpcRequest) request);
+    } else if (request instanceof OneWayMessage) {
+      processOneWayMessage((OneWayMessage) request);
     } else if (request instanceof StreamRequest) {
       processStreamRequest((StreamRequest) request);
     } else {
@@ -156,6 +160,14 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     }
   }
 
+  private void processOneWayMessage(OneWayMessage req) {
+    try {
+      rpcHandler.receive(reverseClient, req.message);
+    } catch (Exception e) {
+      logger.error("Error while invoking RpcHandler#receive() for one-way message.", e);
+    }
+  }
+
   /**
    * Responds to a single message with some Encodable object. If a failure occurs while sending,
    * it will be logged and the channel closed.
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index 22b451fc0e..1aa20900ff 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -35,6 +35,7 @@ import org.apache.spark.network.protocol.ChunkFetchSuccess;
 import org.apache.spark.network.protocol.Message;
 import org.apache.spark.network.protocol.MessageDecoder;
 import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.protocol.OneWayMessage;
 import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.RpcResponse;
@@ -84,6 +85,7 @@ public class ProtocolSuite {
     testClientToServer(new RpcRequest(12345, new byte[0]));
     testClientToServer(new RpcRequest(12345, new byte[100]));
     testClientToServer(new StreamRequest("abcde"));
+    testClientToServer(new OneWayMessage(new byte[100]));
   }
 
   @Test
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 8eb56bdd98..88fa2258bb 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -17,9 +17,11 @@
 
 package org.apache.spark.network;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.Iterator;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
@@ -46,6 +48,7 @@ public class RpcIntegrationSuite {
   static TransportServer server;
   static TransportClientFactory clientFactory;
   static RpcHandler rpcHandler;
+  static List<String> oneWayMsgs;
 
   @BeforeClass
   public static void setUp() throws Exception {
@@ -64,12 +67,19 @@ public class RpcIntegrationSuite {
         }
       }
 
+      @Override
+      public void receive(TransportClient client, byte[] message) {
+        String msg = new String(message, Charsets.UTF_8);
+        oneWayMsgs.add(msg);
+      }
+
       @Override
       public StreamManager getStreamManager() { return new OneForOneStreamManager(); }
     };
     TransportContext context = new TransportContext(conf, rpcHandler);
     server = context.createServer();
     clientFactory = context.createClientFactory();
+    oneWayMsgs = new ArrayList<>();
   }
 
   @AfterClass
@@ -158,6 +168,27 @@ public class RpcIntegrationSuite {
     assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !"));
   }
 
+  @Test
+  public void sendOneWayMessage() throws Exception {
+    final String message = "no reply";
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+    try {
+      client.send(message.getBytes(Charsets.UTF_8));
+      assertEquals(0, client.getHandler().numOutstandingRequests());
+
+      // Make sure the message arrives.
+      long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+      while (System.nanoTime() < deadline && oneWayMsgs.size() == 0) {
+        TimeUnit.MILLISECONDS.sleep(10);
+      }
+
+      assertEquals(1, oneWayMsgs.size());
+      assertEquals(message, oneWayMsgs.get(0));
+    } finally {
+      client.close();
+    }
+  }
+
   private void assertErrorsContain(Set<String> errors, Set<String> contains) {
     assertEquals(contains.size(), errors.size());
 
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index b146899670..a6f180bc40 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -21,6 +21,7 @@ import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
 import java.io.File;
+import java.lang.reflect.Method;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.List;
@@ -353,6 +354,14 @@ public class SparkSaslSuite {
     verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class));
   }
 
+  @Test
+  public void testDelegates() throws Exception {
+    Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods();
+    for (Method m : rpcHandlerMethods) {
+      SaslRpcHandler.class.getDeclaredMethod(m.getName(), m.getParameterTypes());
+    }
+  }
+
   private static class SaslTestCtx {
 
     final TransportClient client;
-- 
GitLab