From 446738e51fcda50cf2dc44123ff6bf12a1611dc0 Mon Sep 17 00:00:00 2001 From: tedyu <yuzhihong@gmail.com> Date: Tue, 17 Nov 2015 22:47:53 -0800 Subject: [PATCH] [SPARK-11761] Prevent the call to StreamingContext#stop() in the listener bus's thread See discussion toward the tail of https://github.com/apache/spark/pull/9723 From zsxwing : ``` The user should not call stop or other long-time work in a listener since it will block the listener thread, and prevent from stopping SparkContext/StreamingContext. I cannot see an approach since we need to stop the listener bus's thread before stopping SparkContext/StreamingContext totally. ``` Proposed solution is to prevent the call to StreamingContext#stop() in the listener bus's thread. Author: tedyu <yuzhihong@gmail.com> Closes #9741 from tedyu/master. --- .../spark/util/AsynchronousListenerBus.scala | 46 +++++++++++-------- .../spark/streaming/StreamingContext.scala | 6 ++- .../streaming/StreamingListenerSuite.scala | 34 ++++++++++++++ 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index c20627b056..6c1fca71f2 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.util.DynamicVariable import org.apache.spark.SparkContext @@ -60,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri private val listenerThread = new Thread(name) { setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { + AsynchronousListenerBus.withinListenerThread.withValue(true) { + while (true) { + eventLock.acquire() self.synchronized { - processingEvent = false + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } } } } @@ -177,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri */ def onDropEvent(event: E): Unit } + +private[spark] object AsynchronousListenerBus { + /* Allows for Context to check whether stop() call is made within listener thread + */ + val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 97113835f3..aee172a4f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} +import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -693,6 +693,10 @@ class StreamingContext private[streaming] ( */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop StreamingContext within listener thread of" + + " AsynchronousListenerBus") + } synchronized { try { state match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 5dc0472c77..df4575ab25 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark.SparkException import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver @@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } + test("don't call ssc.stop in listener") { + ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + startStreamingContextAndCallStop(ssc) + } + test("onBatchCompleted with successful batch") { ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) @@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { assert(failureReasons(1).contains("This is another failed job")) } + private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = { + val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc) + _ssc.addStreamingListener(contextStoppingCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + _ssc.stop() + assert(contextStoppingCollector.sparkExSeen) + } + private def startStreamingContextAndCollectFailureReasons( _ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = { val failureReasonsCollector = new FailureReasonsCollector() @@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener { } } } +/** + * A StreamingListener that calls StreamingContext.stop(). + */ +class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { + @volatile var sparkExSeen = false + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +} -- GitLab