diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 432b2d4925ae28f48cd46ccd63557d545f4b1dbf..c224f2f9f1404809213dd730050427e956aabb43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.io.{InterruptedIOException, IOException}
+import java.io.{InterruptedIOException, IOException, UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
 import java.util.UUID
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
 import java.util.concurrent.locks.ReentrantLock
 
@@ -27,6 +28,7 @@ import scala.collection.mutable.{Map => MutableMap}
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
+import com.google.common.util.concurrent.UncheckedExecutionException
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
@@ -335,7 +337,7 @@ class StreamExecution(
           // `stop()` is already called. Let `finally` finish the cleanup.
         }
       } catch {
-        case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED =>
+        case e if isInterruptedByStop(e) =>
          // interrupted by stop()
          updateStatusMessage("Stopped")
        case e: IOException if e.getMessage != null
@@ -407,6 +409,32 @@ class StreamExecution(
     }
   }
 
+  private def isInterruptedByStop(e: Throwable): Boolean = {
+    if (state.get == TERMINATED) {
+      e match {
+        // InterruptedIOException - thrown when an I/O operation is interrupted
+        // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted
+        case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
+          true
+        // The cause of the following exceptions may be one of the above exceptions:
+        //
+        // UncheckedIOException - thrown by codes that cannot throw a checked IOException, such as
+        //                        BiFunction.apply
+        // ExecutionException - thrown by codes running in a thread pool and these codes throw an
+        //                      exception
+        // UncheckedExecutionException - thrown by codes that cannot throw a checked
+        //                               ExecutionException, such as BiFunction.apply
+        case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException)
+            if e2.getCause != null =>
+          isInterruptedByStop(e2.getCause)
+        case _ =>
+          false
+      }
+    } else {
+      false
+    }
+  }
+
   /**
    * Populate the start offsets to start the execution at the current offsets stored in the sink
    * (i.e. avoid reprocessing data that we have already processed). This function must be called
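Note (not part of the patch): the heart of the change above is that isInterruptedByStop recursively unwraps getCause on the wrapper exceptions, so an interrupt buried inside an UncheckedIOException, ExecutionException, or UncheckedExecutionException is still treated as a normal stop. Below is a minimal standalone sketch of that unwrapping pattern, assuming a plain queryStopped flag in place of StreamExecution's `state.get == TERMINATED` check; Guava's UncheckedExecutionException is omitted so the snippet depends only on the JDK, and the names InterruptClassifier/isInterrupt are hypothetical.

import java.io.{InterruptedIOException, UncheckedIOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.ExecutionException

import scala.annotation.tailrec

// Hypothetical sketch of the cause-unwrapping pattern used by the patch;
// `queryStopped` stands in for StreamExecution's TERMINATED state check.
object InterruptClassifier {
  @tailrec
  def isInterrupt(e: Throwable, queryStopped: Boolean): Boolean = {
    if (!queryStopped) {
      false
    } else {
      e match {
        // Interrupts surfaced directly by blocking calls and NIO channels.
        case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
          true
        // Wrappers whose cause may itself be (or wrap) an interrupt.
        case e2 @ (_: UncheckedIOException | _: ExecutionException) if e2.getCause != null =>
          isInterrupt(e2.getCause, queryStopped)
        case _ =>
          false
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // An interrupt buried under two layers of wrapping is still recognized,
    // but only when the query has already been asked to stop.
    val wrapped = new ExecutionException("task failed",
      new UncheckedIOException("io", new ClosedByInterruptException))
    assert(isInterrupt(wrapped, queryStopped = true))
    assert(!isInterrupt(wrapped, queryStopped = false))
  }
}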
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 012cccfdd9166e93e12fd4e74fb115c8b01d924a..d0b2041a8644f770672bd11ad00dc087caae4b2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,12 +17,14 @@
 
 package org.apache.spark.sql.streaming
 
-import java.io.{File, InterruptedIOException, IOException}
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.io.{File, InterruptedIOException, IOException, UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit}
 
 import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
+import com.google.common.util.concurrent.UncheckedExecutionException
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 
@@ -691,6 +693,31 @@ class StreamSuite extends StreamTest {
       }
     }
   }
+
+  for (e <- Seq(
+      new InterruptedException,
+      new InterruptedIOException,
+      new ClosedByInterruptException,
+      new UncheckedIOException("test", new ClosedByInterruptException),
+      new ExecutionException("test", new InterruptedException),
+      new UncheckedExecutionException("test", new InterruptedException))) {
+    test(s"view ${e.getClass.getSimpleName} as a normal query stop") {
+      ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1)
+      ThrowingExceptionInCreateSource.exception = e
+      val query = spark
+        .readStream
+        .format(classOf[ThrowingExceptionInCreateSource].getName)
+        .load()
+        .writeStream
+        .format("console")
+        .start()
+      assert(ThrowingExceptionInCreateSource.createSourceLatch
+        .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+        "ThrowingExceptionInCreateSource.createSource wasn't called before timeout")
+      query.stop()
+      assert(query.exception.isEmpty)
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
@@ -824,3 +851,32 @@ class TestStateStoreProvider extends StateStoreProvider {
 
   override def getStore(version: Long): StateStore = null
 }
+
+/** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */
+class ThrowingExceptionInCreateSource extends FakeSource {
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    ThrowingExceptionInCreateSource.createSourceLatch.countDown()
+    try {
+      Thread.sleep(30000)
+      throw new TimeoutException("sleep was not interrupted in 30 seconds")
+    } catch {
+      case _: InterruptedException =>
+        throw ThrowingExceptionInCreateSource.exception
+    }
+  }
+}
+
+object ThrowingExceptionInCreateSource {
+  /**
+   * A latch to allow the user to wait until `ThrowingExceptionInCreateSource.createSource` is
+   * called.
+   */
+  @volatile var createSourceLatch: CountDownLatch = null
+  @volatile var exception: Exception = null
+}
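Note (not part of the patch): the test harness above depends on query.stop() interrupting the stream thread while it is blocked in Thread.sleep inside createSource; the resulting InterruptedException is then swapped for the exception under test. A minimal Spark-independent sketch of that latch-and-interrupt handshake follows; all names here are hypothetical.

import java.util.concurrent.CountDownLatch

// Hypothetical sketch of the pattern used by ThrowingExceptionInCreateSource:
// a worker blocks, the caller interrupts it (as StreamExecution.stop()
// interrupts the stream thread), and the worker surfaces an exception.
object InterruptHandshakeDemo {
  def main(args: Array[String]): Unit = {
    val entered = new CountDownLatch(1)
    var surfaced: Throwable = null // visibility guaranteed by worker.join() below

    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        entered.countDown() // signal that the "blocking call" has started
        try {
          Thread.sleep(30000) // stands in for the blocking createSource body
        } catch {
          case e: InterruptedException => surfaced = e
        }
      }
    })
    worker.start()

    entered.await()    // mirrors createSourceLatch.await(...) in the test
    worker.interrupt() // mirrors what query.stop() does to the stream thread
    worker.join()
    assert(surfaced.isInstanceOf[InterruptedException])
  }
}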