From 22662b241629b56205719ede2f801a476e10a3cd Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Tue, 26 Jan 2016 17:24:40 -0800
Subject: [PATCH] [SPARK-12614][CORE] Don't throw non fatal exception from ask

Right now RpcEndpointRef.ask may throw exception in some corner cases, such as calling ask after stopping RpcEnv. It's better to avoid throwing exception from RpcEndpointRef.ask. We can send the exception to the future for `ask`.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #10568 from zsxwing/send-ask-fail.
---
 .../apache/spark/rpc/netty/NettyRpcEnv.scala  | 54 ++++++++++---------
 1 file changed, 29 insertions(+), 25 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index ef876b1d8c..9ae74d9d7b 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -211,33 +211,37 @@ private[netty] class NettyRpcEnv(
         }
     }
 
-    if (remoteAddr == address) {
-      val p = Promise[Any]()
-      p.future.onComplete {
-        case Success(response) => onSuccess(response)
-        case Failure(e) => onFailure(e)
-      }(ThreadUtils.sameThread)
-      dispatcher.postLocalMessage(message, p)
-    } else {
-      val rpcMessage = RpcOutboxMessage(serialize(message),
-        onFailure,
-        (client, response) => onSuccess(deserialize[Any](client, response)))
-      postToOutbox(message.receiver, rpcMessage)
-      promise.future.onFailure {
-        case _: TimeoutException => rpcMessage.onTimeout()
-        case _ =>
+    try {
+      if (remoteAddr == address) {
+        val p = Promise[Any]()
+        p.future.onComplete {
+          case Success(response) => onSuccess(response)
+          case Failure(e) => onFailure(e)
+        }(ThreadUtils.sameThread)
+        dispatcher.postLocalMessage(message, p)
+      } else {
+        val rpcMessage = RpcOutboxMessage(serialize(message),
+          onFailure,
+          (client, response) => onSuccess(deserialize[Any](client, response)))
+        postToOutbox(message.receiver, rpcMessage)
+        promise.future.onFailure {
+          case _: TimeoutException => rpcMessage.onTimeout()
+          case _ =>
+        }(ThreadUtils.sameThread)
+      }
+
+      val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
+        override def run(): Unit = {
+          onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}"))
+        }
+      }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
+      promise.future.onComplete { v =>
+        timeoutCancelable.cancel(true)
       }(ThreadUtils.sameThread)
+    } catch {
+      case NonFatal(e) =>
+        onFailure(e)
     }
-
-    val timeoutCancelable = timeoutScheduler.schedule(new Runnable {
-      override def run(): Unit = {
-        promise.tryFailure(
-          new TimeoutException(s"Cannot receive any reply in ${timeout.duration}"))
-      }
-    }, timeout.duration.toNanos, TimeUnit.NANOSECONDS)
-    promise.future.onComplete { v =>
-      timeoutCancelable.cancel(true)
-    }(ThreadUtils.sameThread)
     promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
   }
 
-- 
GitLab