diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index c20627b056bef01882c44028e8aaa2c5ec089a5a..6c1fca71f228144d5822fc22c882eb1432a82f03 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean +import scala.util.DynamicVariable import org.apache.spark.SparkContext @@ -60,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri private val listenerThread = new Thread(name) { setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - postToAll(event) - } finally { + AsynchronousListenerBus.withinListenerThread.withValue(true) { + while (true) { + eventLock.acquire() self.synchronized { - processingEvent = false + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } } } } @@ -177,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri */ def onDropEvent(event: E): Unit } + +private[spark] object AsynchronousListenerBus { + /* Allows for Context to check whether stop() call is made within listener thread + */ + val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 97113835f3bd0e78797faad4c1a3e3b19bc61872..aee172a4f549ad4ce3af4b5909336e9df1c71b2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} -import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} +import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -693,6 +693,10 @@ class StreamingContext private[streaming] ( */ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null + if (AsynchronousListenerBus.withinListenerThread.value) { + throw new SparkException("Cannot stop StreamingContext within listener thread of" + + " AsynchronousListenerBus") + } synchronized { try { state match { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 5dc0472c7770cf89aaba5a95d6502f5c03da9c01..df4575ab25aadebaa5688c56e2b9cd988c570a95 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import org.apache.spark.SparkException import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver @@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } } + test("don't call ssc.stop in listener") { + ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000)) + val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) + inputStream.foreachRDD(_.count) + + startStreamingContextAndCallStop(ssc) + } + test("onBatchCompleted with successful batch") { ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) @@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { assert(failureReasons(1).contains("This is another failed job")) } + private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = { + val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc) + _ssc.addStreamingListener(contextStoppingCollector) + val batchCounter = new BatchCounter(_ssc) + _ssc.start() + // Make sure running at least one batch + batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000) + _ssc.stop() + assert(contextStoppingCollector.sparkExSeen) + } + private def startStreamingContextAndCollectFailureReasons( _ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = { val failureReasonsCollector = new FailureReasonsCollector() @@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener { } } } +/** + * A StreamingListener that calls StreamingContext.stop(). + */ +class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener { + @volatile var sparkExSeen = false + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + try { + ssc.stop() + } catch { + case se: SparkException => + sparkExSeen = true + } + } +}