diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 79fc2e94599c7a0d7704e616e4ec028d2067eceb..fa5ad4e8d81e1db9e338da6dc4bff626c6ac2021 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -52,7 +52,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. */ - final def postToAll(event: E): Unit = { + def postToAll(event: E): Unit = { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index a2153d27e9fefd79bc035e36207c8625d54f2139..4207013c3f75d8edcc86832a71f72c4f08afa5cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -75,6 +75,19 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) } } + /** + * Override the parent `postToAll` to remove the query id from `activeQueryRunIds` after all + * the listeners process `QueryTerminatedEvent`. (SPARK-19594) + */ + override def postToAll(event: Event): Unit = { + super.postToAll(event) + event match { + case t: QueryTerminatedEvent => + activeQueryRunIds.synchronized { activeQueryRunIds -= t.runId } + case _ => + } + } + override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: StreamingQueryListener.Event => @@ -112,7 +125,6 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) case queryTerminated: QueryTerminatedEvent => if (shouldReport(queryTerminated.runId)) { listener.onQueryTerminated(queryTerminated) - activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId } } case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 4596aa1d348e3612f77e9c3d70ec5a78439d8805..eb09b9ffcfc5de672b451b38243d06246429d992 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -133,6 +133,31 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } } + test("SPARK-19594: all of listeners should receive QueryTerminatedEvent") { + val df = MemoryStream[Int].toDS().as[Long] + val listeners = (1 to 5).map(_ => new EventCollector) + try { + listeners.foreach(listener => spark.streams.addListener(listener)) + testStream(df, OutputMode.Append)( + StartStream(), + StopStream, + AssertOnQuery { query => + eventually(Timeout(streamingTimeout)) { + listeners.foreach(listener => assert(listener.terminationEvent !== null)) + listeners.foreach(listener => assert(listener.terminationEvent.id === query.id)) + listeners.foreach(listener => assert(listener.terminationEvent.runId === query.runId)) + listeners.foreach(listener => assert(listener.terminationEvent.exception === None)) + } + listeners.foreach(listener => listener.checkAsyncErrors()) + listeners.foreach(listener => listener.reset()) + true + } + ) + } finally { + listeners.foreach(spark.streams.removeListener) + } + } + test("adding and removing listener") { def isListenerActive(listener: EventCollector): Boolean = { listener.reset()