Commit 9314c083 authored by Burak Yavuz, committed by Shixiong Zhu

[SPARK-19774] StreamExecution should call stop() on sources when a stream fails

## What changes were proposed in this pull request?

We currently call stop() on a Structured Streaming Source only when the stream is shut down by a user calling streamingQuery.stop(). We should stop all sources when the stream fails as well; otherwise we may leak resources, e.g. connections to Kafka.
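
The fix, in essence, moves source cleanup into the code path that runs on every termination, not only the user-initiated one. A minimal standalone sketch of the idea (not the actual Spark code; `FakeSource` and `runStream` are illustrative names):

```scala
import scala.util.control.NonFatal

// Illustrative stand-in for a streaming source that owns an external
// resource, e.g. a Kafka connection.
class FakeSource(val name: String) {
  def stop(): Unit = println(s"$name: connection closed")
}

// Whether the body finishes normally or throws, every source is stopped,
// and a failure in one stop() does not prevent stopping the rest.
def runStream(sources: Seq[FakeSource])(body: => Unit): Unit =
  try body finally {
    sources.foreach { s =>
      try s.stop()
      catch { case NonFatal(e) => println(s"failed to stop ${s.name}: $e") }
    }
  }

runStream(Seq(new FakeSource("a"), new FakeSource("b"))) { () } // clean run: both closed
try runStream(Seq(new FakeSource("c"))) { throw new RuntimeException("boom") }
catch { case NonFatal(_) => } // the failure still propagates, but "c" was closed first
```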

## How was this patch tested?

Unit tests in `StreamingQuerySuite`.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #17107 from brkyvz/close-source.
parent 37a1c0e4
StreamExecution.scala:

@@ -321,6 +321,7 @@ class StreamExecution(
     initializationLatch.countDown()
     try {
+      stopSources()
       state.set(TERMINATED)
       currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
@@ -558,6 +559,18 @@ class StreamExecution(
     sparkSession.streams.postListenerEvent(event)
   }

+  /** Stops all streaming sources safely. */
+  private def stopSources(): Unit = {
+    uniqueSources.foreach { source =>
+      try {
+        source.stop()
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e)
+      }
+    }
+  }
+
   /**
    * Signals to the thread executing micro-batches that it should stop running after the next
    * batch. This method blocks until the thread stops running.
@@ -570,7 +583,6 @@ class StreamExecution(
       microBatchThread.interrupt()
       microBatchThread.join()
     }
-    uniqueSources.foreach(_.stop())
     logInfo(s"Query $prettyIdString was stopped")
   }
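
Note the `NonFatal` guard in `stopSources()`: it swallows ordinary failures from an individual `source.stop()` so the remaining sources are still stopped, while fatal errors (`OutOfMemoryError`, `InterruptedException`, and other throwables that `scala.util.control.NonFatal` deliberately refuses to match) continue to propagate. A quick standalone demonstration of that matching behavior:

```scala
import scala.util.control.NonFatal

def classify(t: Throwable): String = t match {
  case NonFatal(_) => "non-fatal: logged, cleanup continues"
  case _           => "fatal: rethrown by a NonFatal-only handler"
}

println(classify(new RuntimeException("stop() failed")))     // non-fatal
println(classify(new InterruptedException("shutting down"))) // fatal
println(classify(new OutOfMemoryError()))                    // fatal
```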
...
StreamingQuerySuite.scala:

@@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming
 import java.util.concurrent.CountDownLatch

 import org.apache.commons.lang3.RandomStringUtils
+import org.mockito.Mockito._
 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
+import org.scalatest.mock.MockitoSugar

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
@@ -32,11 +34,11 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.util.BlockingSource
+import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider}
 import org.apache.spark.util.ManualClock

-class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
+class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {

   import AwaitTerminationTester._
   import testImplicits._
@@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
     }
   }

test("StreamExecution should call stop() on sources when a stream is stopped") {
var calledStop = false
val source = new Source {
override def stop(): Unit = {
calledStop = true
}
override def getOffset: Option[Offset] = None
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
spark.emptyDataFrame
}
override def schema: StructType = MockSourceProvider.fakeSchema
}
MockSourceProvider.withMockSources(source) {
val df = spark.readStream
.format("org.apache.spark.sql.streaming.util.MockSourceProvider")
.load()
testStream(df)(StopStream)
assert(calledStop, "Did not call stop on source for stopped stream")
}
}
testQuietly("SPARK-19774: StreamExecution should call stop() on sources when a stream fails") {
var calledStop = false
val source1 = new Source {
override def stop(): Unit = {
throw new RuntimeException("Oh no!")
}
override def getOffset: Option[Offset] = Some(LongOffset(1))
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*)
}
override def schema: StructType = MockSourceProvider.fakeSchema
}
val source2 = new Source {
override def stop(): Unit = {
calledStop = true
}
override def getOffset: Option[Offset] = None
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
spark.emptyDataFrame
}
override def schema: StructType = MockSourceProvider.fakeSchema
}
MockSourceProvider.withMockSources(source1, source2) {
val df1 = spark.readStream
.format("org.apache.spark.sql.streaming.util.MockSourceProvider")
.load()
.as[Int]
val df2 = spark.readStream
.format("org.apache.spark.sql.streaming.util.MockSourceProvider")
.load()
.as[Int]
testStream(df1.union(df2).map(i => i / 0))(
AssertOnQuery { sq =>
intercept[StreamingQueryException](sq.processAllAvailable())
sq.exception.isDefined && !sq.isActive
}
)
assert(calledStop, "Did not call stop on source for stopped stream")
}
}
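
The `map(i => i / 0)` projection guarantees that the first batch fails with an `ArithmeticException`, so the query terminates with a `StreamingQueryException`; and because `source1.stop()` itself throws, the test also verifies that `stopSources()` keeps going and still stops `source2`. The failure trigger in miniature (plain Scala, nothing Spark-specific):

```scala
// Integer division by zero is what makes every batch of the test query fail.
def process(i: Int): Int = i / 0

val result: Either[Throwable, Int] =
  try Right(process(1))
  catch { case e: ArithmeticException => Left(e) }

assert(result.isLeft) // the stream surfaces this as sq.exception
```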
   /** Create a streaming DF that only execute one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)
...

MockSourceProvider.scala (new file):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
* A StreamSourceProvider that provides mocked Sources for unit testing. Example usage:
*
* {{{
* MockSourceProvider.withMockSources(source1, source2) {
* val df1 = spark.readStream
* .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
* .load()
*
* val df2 = spark.readStream
* .format("org.apache.spark.sql.streaming.util.MockSourceProvider")
* .load()
*
* df1.union(df2)
* ...
* }
* }}}
*/
class MockSourceProvider extends StreamSourceProvider {
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    MockSourceProvider.sourceProviderFunction()
  }
}
object MockSourceProvider {
  // Function to generate sources. May provide multiple sources if the user implements such a
  // function.
  private var sourceProviderFunction: () => Source = _

  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      val source = sources(i % sources.length)
      i += 1
      source
    }
    try {
      f
    } finally {
      sourceProviderFunction = null
    }
  }
}
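
For reference, the provider hands sources out round-robin: every `createSource()` call invokes the closure installed by `withMockSources`, which returns the next source in the list and wraps around. The same mechanism in isolation (standalone sketch, illustrative names):

```scala
// Round-robin hand-out, as used above: the two readStream.load() calls in the
// test receive source1 and source2 in order.
def roundRobin[A](items: Seq[A]): () => A = {
  var i = 0
  () => {
    val item = items(i % items.length)
    i += 1
    item
  }
}

val next = roundRobin(Seq("source1", "source2"))
assert(next() == "source1")
assert(next() == "source2")
assert(next() == "source1") // wraps around on further calls
```

The mutable counter is not synchronized, which appears acceptable for this test-only helper since sources are created sequentially while the query starts up.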