diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index ee60d697d8799261cca27a943fce0f5db254d997..1f1f0b75de5f1505dd5fc2d9b65198c89720434d 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -20,6 +20,7 @@ package org.apache.spark import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable +import scala.concurrent.Future import org.apache.spark.executor.TaskMetrics import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} @@ -147,11 +148,31 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) } } + /** + * Send ExecutorRegistered to the event loop to add a new executor. Only for test. + * + * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that + * indicate if this operation is successful. + */ + def addExecutor(executorId: String): Option[Future[Boolean]] = { + Option(self).map(_.ask[Boolean](ExecutorRegistered(executorId))) + } + /** * If the heartbeat receiver is not stopped, notify it of executor registrations. */ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { - Option(self).foreach(_.ask[Boolean](ExecutorRegistered(executorAdded.executorId))) + addExecutor(executorAdded.executorId) + } + + /** + * Send ExecutorRemoved to the event loop to remove a executor. Only for test. + * + * @return if HeartbeatReceiver is stopped, return None. Otherwise, return a Some(Future) that + * indicate if this operation is successful. + */ + def removeExecutor(executorId: String): Option[Future[Boolean]] = { + Option(self).map(_.ask[Boolean](ExecutorRemoved(executorId))) } /** @@ -165,7 +186,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) * and expire it with loud error messages. */ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { - Option(self).foreach(_.ask[Boolean](ExecutorRemoved(executorRemoved.executorId))) + removeExecutor(executorRemoved.executorId) } private def expireDeadHosts(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 139b8dc25f4b4b61fffad82e904ed99cc4a1023d..18f2229fea39b6d0ac50faf6f8f4339f847be730 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -19,7 +19,10 @@ package org.apache.spark import java.util.concurrent.{ExecutorService, TimeUnit} +import scala.collection.Map import scala.collection.mutable +import scala.concurrent.Await +import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} @@ -96,18 +99,18 @@ class HeartbeatReceiverSuite test("normal heartbeat") { heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) - val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 2) assert(trackedExecutors.contains(executorId1)) assert(trackedExecutors.contains(executorId2)) } test("reregister if scheduler is not ready yet") { - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) + addExecutorAndVerify(executorId1) // Task scheduler is not set yet in HeartbeatReceiver, so executors should reregister triggerHeartbeat(executorId1, executorShouldReregister = true) } @@ -116,20 +119,20 @@ class HeartbeatReceiverSuite heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) // Received heartbeat from unknown executor, so we ask it to re-register triggerHeartbeat(executorId1, executorShouldReregister = true) - assert(heartbeatReceiver.invokePrivate(_executorLastSeen()).isEmpty) + assert(getTrackedExecutors.isEmpty) } test("reregister if heartbeat from removed executor") { heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) // Remove the second executor but not the first - heartbeatReceiver.onExecutorRemoved(SparkListenerExecutorRemoved(0, executorId2, "bad boy")) + removeExecutorAndVerify(executorId2) // Now trigger the heartbeats // A heartbeat from the second executor should require reregistering triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = true) - val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) @@ -138,8 +141,8 @@ class HeartbeatReceiverSuite test("expire dead hosts") { val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) // Advance the clock and only trigger a heartbeat for the first executor @@ -149,7 +152,7 @@ class HeartbeatReceiverSuite heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) // Only the second executor should be expired as a dead host verify(scheduler).executorLost(Matchers.eq(executorId2), any()) - val trackedExecutors = heartbeatReceiver.invokePrivate(_executorLastSeen()) + val trackedExecutors = getTrackedExecutors assert(trackedExecutors.size === 1) assert(trackedExecutors.contains(executorId1)) assert(!trackedExecutors.contains(executorId2)) @@ -175,8 +178,8 @@ class HeartbeatReceiverSuite fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisteredExecutor.type]( RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "dummy:4040", 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId1, null)) - heartbeatReceiver.onExecutorAdded(SparkListenerExecutorAdded(0, executorId2, null)) + addExecutorAndVerify(executorId1) + addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) triggerHeartbeat(executorId2, executorShouldReregister = false) @@ -222,6 +225,26 @@ class HeartbeatReceiverSuite } } + private def addExecutorAndVerify(executorId: String): Unit = { + assert( + heartbeatReceiver.addExecutor(executorId).map { f => + Await.result(f, 10.seconds) + } === Some(true)) + } + + private def removeExecutorAndVerify(executorId: String): Unit = { + assert( + heartbeatReceiver.removeExecutor(executorId).map { f => + Await.result(f, 10.seconds) + } === Some(true)) + } + + private def getTrackedExecutors: Map[String, Long] = { + // We may receive undesired SparkListenerExecutorAdded from LocalBackend, so exclude it from + // the map. See SPARK-10800. + heartbeatReceiver.invokePrivate(_executorLastSeen()). + filterKeys(_ != SparkContext.DRIVER_IDENTIFIER) + } } // TODO: use these classes to add end-to-end tests for dynamic allocation!