diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index e40135fdd7a55e6736f51a192d7d18a71635ad6c..2386f33f8ad414cc592228f90a590503add3b4eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -159,8 +159,8 @@ trait ProgressReporter extends Logging {
       name = name,
       timestamp = formatTimestamp(currentTriggerStartTimestamp),
       batchId = currentBatchId,
-      durationMs = currentDurationsMs.toMap.mapValues(long2Long).asJava,
-      eventTime = executionStats.eventTimeStats.asJava,
+      durationMs = new java.util.HashMap(currentDurationsMs.toMap.mapValues(long2Long).asJava),
+      eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava),
       stateOperators = executionStats.stateOperators.toArray,
       sources = sourceProgress.toArray,
       sink = sinkProgress)
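The `durationMs`/`eventTime` change above is about serializability: `mapValues(...).asJava` yields a thin Java wrapper over a lazy Scala view, which may not survive Java serialization, so the values are copied into a plain `java.util.HashMap` before being handed to `StreamingQueryProgress`. A minimal sketch of the same idea; the `DurationMaps.toJavaDurations` helper is illustrative and not part of the patch:

```scala
import java.{util => ju}

import scala.collection.JavaConverters._

object DurationMaps {
  // `mapValues` produces a lazy view and `asJava` only wraps it; copying into an
  // eager java.util.HashMap yields a self-contained map that can be Java-serialized
  // without dragging the Scala wrapper classes along.
  def toJavaDurations(durations: Map[String, Long]): ju.Map[String, java.lang.Long] = {
    new ju.HashMap[String, java.lang.Long](durations.mapValues(long2Long).asJava)
  }
}
```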
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala
new file mode 100644
index 0000000000000000000000000000000000000000..020c9cb4a7304c913afdd3de197e5f3596b76fab
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryWrapper.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import java.util.UUID
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamingQueryProgress, StreamingQueryStatus}
+
+/**
+ * Wraps the non-serializable StreamExecution to make the query serializable, as it's easy for it
+ * to get captured with normal usage. It's safe to capture the query but not use it in executors.
+ * However, if the user tries to call its methods, it will throw `IllegalStateException`.
+ */
+class StreamingQueryWrapper(@transient private val _streamingQuery: StreamExecution)
+  extends StreamingQuery with Serializable {
+
+  def streamingQuery: StreamExecution = {
+    /** Assert that this code runs in the driver. */
+    if (_streamingQuery == null) {
+      throw new IllegalStateException("StreamingQuery cannot be used in executors")
+    }
+    _streamingQuery
+  }
+
+  override def name: String = {
+    streamingQuery.name
+  }
+
+  override def id: UUID = {
+    streamingQuery.id
+  }
+
+  override def runId: UUID = {
+    streamingQuery.runId
+  }
+
+  override def awaitTermination(): Unit = {
+    streamingQuery.awaitTermination()
+  }
+
+  override def awaitTermination(timeoutMs: Long): Boolean = {
+    streamingQuery.awaitTermination(timeoutMs)
+  }
+
+  override def stop(): Unit = {
+    streamingQuery.stop()
+  }
+
+  override def processAllAvailable(): Unit = {
+    streamingQuery.processAllAvailable()
+  }
+
+  override def isActive: Boolean = {
+    streamingQuery.isActive
+  }
+
+  override def lastProgress: StreamingQueryProgress = {
+    streamingQuery.lastProgress
+  }
+
+  override def explain(): Unit = {
+    streamingQuery.explain()
+  }
+
+  override def explain(extended: Boolean): Unit = {
+    streamingQuery.explain(extended)
+  }
+
+  /**
+   * This method is called from Python. Python cannot call "explain" directly, as it prints in the
+   * JVM process, which may not be visible in the Python process.
+   */
+  def explainInternal(extended: Boolean): String = {
+    streamingQuery.explainInternal(extended)
+  }
+
+  override def sparkSession: SparkSession = {
+    streamingQuery.sparkSession
+  }
+
+  override def recentProgress: Array[StreamingQueryProgress] = {
+    streamingQuery.recentProgress
+  }
+
+  override def status: StreamingQueryStatus = {
+    streamingQuery.status
+  }
+
+  override def exception: Option[StreamingQueryException] = {
+    streamingQuery.exception
+  }
+}
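The wrapper above relies on standard Java-serialization semantics for `@transient` fields: the non-serializable `StreamExecution` is simply dropped when the wrapper gets serialized into a task closure, and the null check in `streamingQuery` turns any executor-side use into a clear `IllegalStateException`. A self-contained sketch of that mechanism outside Spark; the `DriverOnlyHandle` and `roundTrip` names are made up for illustration:

```scala
import java.io._

// A @transient field is skipped during Java serialization, so the wrapped value does not
// even need to be serializable; after a round trip the reference is null and any access
// fails fast with a clear error, mirroring StreamingQueryWrapper.
class DriverOnlyHandle(@transient private val resource: AnyRef) extends Serializable {
  def get: AnyRef = {
    if (resource == null) {
      throw new IllegalStateException("handle cannot be used after serialization (e.g. on executors)")
    }
    resource
  }
}

object DriverOnlyHandleDemo {
  // Serialize and deserialize with plain Java serialization, as Spark does for task closures.
  private def roundTrip[T](value: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(value)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    val handle = new DriverOnlyHandle(new Object)  // the wrapped value itself is not serializable
    handle.get                                     // fine before serialization ("driver" side)
    val copied = roundTrip(handle)                 // what happens when a closure captures the handle
    copied.get                                     // throws IllegalStateException
  }
}
```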
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 6ebd70685effc01a36aa6c4c4e3051cf36aa11e3..8c26ee2bd3fcded5cad4eeb28e2a56caaa1220a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -193,7 +193,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       useTempCheckpointLocation: Boolean,
       recoverFromCheckpointLocation: Boolean,
       trigger: Trigger,
-      triggerClock: Clock): StreamExecution = {
+      triggerClock: Clock): StreamingQueryWrapper = {
     val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified =>
       new Path(userSpecified).toUri.toString
     }.orElse {
@@ -229,7 +229,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode)
     }
 
-    new StreamExecution(
+    new StreamingQueryWrapper(new StreamExecution(
       sparkSession,
       userSpecifiedName.orNull,
       checkpointLocation,
@@ -237,7 +237,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
       sink,
       trigger,
       triggerClock,
-      outputMode)
+      outputMode))
   }
 
   /**
@@ -301,7 +301,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) {
      // As it's provided by the user and can run arbitrary codes, we must not hold any lock here.
      // Otherwise, it's easy to cause dead-lock, or block too long if the user codes take a long
      // time to finish.
-      query.start()
+      query.streamingQuery.start()
    } catch {
      case e: Throwable =>
        activeQueriesLock.synchronized {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
index 44befa0d2ff76a657dc0088fdab4517ddc218210..c2befa6343ba91bf3cb92ff3d539b1079c9ad377 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
@@ -22,7 +22,10 @@ import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
+import org.apache.spark.annotation.Experimental
+
 /**
+ * :: Experimental ::
  * Reports information about the instantaneous status of a streaming query.
  *
  * @param message A human readable description of what the stream is currently doing.
@@ -32,10 +35,11 @@ import org.json4s.jackson.JsonMethods._
  *
  * @since 2.1.0
  */
+@Experimental
 class StreamingQueryStatus protected[sql](
     val message: String,
     val isDataAvailable: Boolean,
-    val isTriggerActive: Boolean) {
+    val isTriggerActive: Boolean) extends Serializable {
 
   /** The compact JSON representation of this status. */
   def json: String = compact(render(jsonValue))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index e219cfde1265639f43e8ec7b07a5f6fd888ccbef..bea0b9e29784105dda83cad259e38236ac25a8c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -38,7 +38,7 @@ import org.apache.spark.annotation.Experimental
 @Experimental
 class StateOperatorProgress private[sql](
     val numRowsTotal: Long,
-    val numRowsUpdated: Long) {
+    val numRowsUpdated: Long) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -90,7 +90,7 @@ class StreamingQueryProgress private[sql](
     val eventTime: ju.Map[String, String],
     val stateOperators: Array[StateOperatorProgress],
     val sources: Array[SourceProgress],
-    val sink: SinkProgress) {
+    val sink: SinkProgress) extends Serializable {
 
   /** The aggregate (across all sources) number of records processed in a trigger. */
   def numInputRows: Long = sources.map(_.numInputRows).sum
@@ -157,7 +157,7 @@ class SourceProgress protected[sql](
     val endOffset: String,
     val numInputRows: Long,
     val inputRowsPerSecond: Double,
-    val processedRowsPerSecond: Double) {
+    val processedRowsPerSecond: Double) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
@@ -197,7 +197,7 @@ class SourceProgress protected[sql](
  */
 @Experimental
 class SinkProgress protected[sql](
-    val description: String) {
+    val description: String) extends Serializable {
 
   /** The compact JSON representation of this progress. */
   def json: String = compact(render(jsonValue))
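With `StreamingQueryStatus` and the progress classes marked `Serializable` (and their maps backed by `java.util.HashMap`), a progress snapshot taken on the driver can be captured by a task closure without triggering a "Task not serializable" error. A hedged usage sketch; `tagWithBatchId` and the toy job are illustrative, and `lastProgress` can still be null before the first trigger completes:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQuery

object ProgressInClosureExample {
  // Tags a few records with the batch id of the most recent progress. The progress object
  // is serialized into the task closure, which works only because the progress classes
  // (and their nested maps) are now Serializable.
  def tagWithBatchId(spark: SparkSession, query: StreamingQuery): Array[(Int, Long)] = {
    val progress = query.lastProgress  // driver-side snapshot; may be null before the first trigger
    spark.sparkContext.parallelize(1 to 3).map { i =>
      (i, progress.batchId)            // `progress` travels inside the serialized closure
    }.collect()
  }
}
```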
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index b96ccb4e6cbf54c818f2f22c9dfe08b3d86d7cf4..cbcc98316b6d33a90c9ee5c38a45d79f4884cc00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -746,7 +746,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
         .format("memory")
         .queryName("file_data")
         .start()
-        .asInstanceOf[StreamExecution]
+        .asInstanceOf[StreamingQueryWrapper]
+        .streamingQuery
       q.processAllAvailable()
       val memorySink = q.sink.asInstanceOf[MemorySink]
       val fileSource = q.logicalPlan.collect {
@@ -836,7 +837,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
     df.explain()
 
     val q = df.writeStream.queryName("file_explain").format("memory").start()
-      .asInstanceOf[StreamExecution]
+      .asInstanceOf[StreamingQueryWrapper]
+      .streamingQuery
     try {
       assert("No physical plan. Waiting for data." === q.explainInternal(false))
       assert("No physical plan. Waiting for data." === q.explainInternal(true))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 6bdf47901ae68040710406272ff6b74386cdf98f..4a64054f63db86f5cf10dff046008b59f29211e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
-import org.apache.spark.util.ManualClock
 
 class StreamSuite extends StreamTest {
 
@@ -278,7 +277,8 @@ class StreamSuite extends StreamTest {
     // Test `explain` not throwing errors
     df.explain()
     val q = df.writeStream.queryName("memory_explain").format("memory").start()
-      .asInstanceOf[StreamExecution]
+      .asInstanceOf[StreamingQueryWrapper]
+      .streamingQuery
     try {
       assert("No physical plan. Waiting for data." === q.explainInternal(false))
       assert("No physical plan. Waiting for data." === q.explainInternal(true))
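The test suites above (and several below) now unwrap the returned query with `.asInstanceOf[StreamingQueryWrapper].streamingQuery` wherever they need `StreamExecution` internals. If that pattern keeps spreading, a small helper along these lines could centralize it; `StreamingQueryTestUtils` and `underlyingExecution` are hypothetical and not part of the patch:

```scala
import org.apache.spark.sql.execution.streaming.{StreamExecution, StreamingQueryWrapper}
import org.apache.spark.sql.streaming.StreamingQuery

// Hypothetical test-side helper: extracts the StreamExecution behind whatever
// DataStreamWriter.start() returned, so suites don't repeat the cast inline.
object StreamingQueryTestUtils {
  def underlyingExecution(query: StreamingQuery): StreamExecution = query match {
    case wrapper: StreamingQueryWrapper => wrapper.streamingQuery
    case exec: StreamExecution => exec  // in case a test already holds the execution directly
    case other =>
      throw new IllegalArgumentException(s"Unexpected query type: ${other.getClass}")
  }
}

// Example usage in a test:
//   val q = StreamingQueryTestUtils.underlyingExecution(
//     df.writeStream.queryName("t").format("memory").start())
```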
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 10f267e11532087a8a778a2922dacf22fd6481fc..6fbbbb1f8e0380723e45adc2268ea21421de071d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -355,7 +355,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
                 outputMode,
                 trigger = trigger,
                 triggerClock = triggerClock)
-              .asInstanceOf[StreamExecution]
+              .asInstanceOf[StreamingQueryWrapper]
+              .streamingQuery
             currentStream.microBatchThread.setUncaughtExceptionHandler(
               new UncaughtExceptionHandler {
                 override def uncaughtException(t: Thread, e: Throwable): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
index 1742a5474cfd32e04e58a0618f613bc22bd8c219..8e16fd418a37c81af2a80388aad129e6a97df879 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala
@@ -244,7 +244,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
     failAfter(streamingTimeout) {
       val queries = withClue("Error starting queries") {
         datasets.zipWithIndex.map { case (ds, i) =>
-          @volatile var query: StreamExecution = null
+          var query: StreamingQuery = null
           try {
             val df = ds.toDF
             val metadataRoot =
@@ -256,7 +256,6 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
               .option("checkpointLocation", metadataRoot)
               .outputMode("append")
               .start()
-              .asInstanceOf[StreamExecution]
           } catch {
             case NonFatal(e) =>
               if (query != null) query.stop()
@@ -304,7 +303,7 @@ class StreamingQueryManagerSuite extends StreamTest with BeforeAndAfter {
       Thread.sleep(stopAfter.toMillis)
       if (withError) {
         logDebug(s"Terminating query ${queryToStop.name} with error")
-        queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
+        queryToStop.asInstanceOf[StreamingQueryWrapper].streamingQuery.logicalPlan.collect {
           case StreamingExecutionRelation(source, _) =>
             source.asInstanceOf[MemoryStream[Int]].addData(0)
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index c970743a31ad6c7ea7e3e9770f7018b308895442..34bf3985bad2c92f0d826bbece60eb3f038cbb36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -24,11 +24,12 @@ import scala.collection.JavaConverters._
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.StreamingQueryStatusAndProgressSuite._
 
-class StreamingQueryStatusAndProgressSuite extends SparkFunSuite {
+class StreamingQueryStatusAndProgressSuite extends StreamTest {
 
   test("StreamingQueryProgress - prettyJson") {
     val json1 = testProgress1.prettyJson
@@ -128,6 +129,42 @@
test("StreamingQueryStatus - toString") { assert(testStatus.toString === testStatus.prettyJson) } + + test("progress classes should be Serializable") { + import testImplicits._ + + val inputData = MemoryStream[Int] + + val query = inputData.toDS() + .groupBy($"value") + .agg(count("*")) + .writeStream + .queryName("progress_serializable_test") + .format("memory") + .outputMode("complete") + .start() + try { + inputData.addData(1, 2, 3) + query.processAllAvailable() + + val progress = query.recentProgress + + // Make sure it generates the progress objects we want to test + assert(progress.exists { p => + p.sources.size >= 1 && p.stateOperators.size >= 1 && p.sink != null + }) + + val array = spark.sparkContext.parallelize(progress).collect() + assert(array.length === progress.length) + array.zip(progress).foreach { case (p1, p2) => + // Make sure we did serialize and deserialize the object + assert(p1 ne p2) + assert(p1.json === p2.json) + } + } finally { + query.stop() + } + } } object StreamingQueryStatusAndProgressSuite { @@ -137,12 +174,12 @@ object StreamingQueryStatusAndProgressSuite { name = "myName", timestamp = "2016-12-05T20:54:20.827Z", batchId = 2L, - durationMs = Map("total" -> 0L).mapValues(long2Long).asJava, - eventTime = Map( + durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), + eventTime = new java.util.HashMap(Map( "max" -> "2016-12-05T20:54:20.827Z", "min" -> "2016-12-05T20:54:20.827Z", "avg" -> "2016-12-05T20:54:20.827Z", - "watermark" -> "2016-12-05T20:54:20.827Z").asJava, + "watermark" -> "2016-12-05T20:54:20.827Z").asJava), stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), sources = Array( new SourceProgress( @@ -163,8 +200,9 @@ object StreamingQueryStatusAndProgressSuite { name = null, // should not be present in the json timestamp = "2016-12-05T20:54:20.827Z", batchId = 2L, - durationMs = Map("total" -> 0L).mapValues(long2Long).asJava, - eventTime = Map.empty[String, String].asJava, // empty maps should be handled correctly + durationMs = new java.util.HashMap(Map("total" -> 0L).mapValues(long2Long).asJava), + // empty maps should be handled correctly + eventTime = new java.util.HashMap(Map.empty[String, String].asJava), stateOperators = Array(new StateOperatorProgress(numRowsTotal = 0, numRowsUpdated = 1)), sources = Array( new SourceProgress( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index b052bd9e6a53b129807c6c484310e87c33e19817..6c4bb35ccb2a6772eece55d9f98bb4ab558e0c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.apache.spark.internal.Logging -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ @@ -439,6 +439,48 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { } } + test("StreamingQuery should be Serializable but cannot be used in executors") { + def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = { + ds.writeStream + .queryName(queryName) + .format("memory") + .start() + } + + val 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index b052bd9e6a53b129807c6c484310e87c33e19817..6c4bb35ccb2a6772eece55d9f98bb4ab558e0c08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.streaming._
@@ -439,6 +439,48 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging {
     }
   }
 
+  test("StreamingQuery should be Serializable but cannot be used in executors") {
+    def startQuery(ds: Dataset[Int], queryName: String): StreamingQuery = {
+      ds.writeStream
+        .queryName(queryName)
+        .format("memory")
+        .start()
+    }
+
+    val input = MemoryStream[Int]
+    val q1 = startQuery(input.toDS, "stream_serializable_test_1")
+    val q2 = startQuery(input.toDS.map { i =>
+      // Emulate that `StreamingQuery` gets captured with normal usage unintentionally.
+      // It should not fail the query.
+      q1
+      i
+    }, "stream_serializable_test_2")
+    val q3 = startQuery(input.toDS.map { i =>
+      // Emulate that `StreamingQuery` is used in executors. We should fail the query with a clear
+      // error message.
+      q1.explain()
+      i
+    }, "stream_serializable_test_3")
+    try {
+      input.addData(1)
+
+      // q2 should not fail since it doesn't use `q1` in the closure
+      q2.processAllAvailable()
+
+      // The user calls `StreamingQuery` in the closure and it should fail
+      val e = intercept[StreamingQueryException] {
+        q3.processAllAvailable()
+      }
+      assert(e.getCause.isInstanceOf[SparkException])
+      assert(e.getCause.getCause.isInstanceOf[IllegalStateException])
+      assert(e.getMessage.contains("StreamingQuery cannot be used in executors"))
+    } finally {
+      q1.stop()
+      q2.stop()
+      q3.stop()
+    }
+  }
+
   /** Create a streaming DF that only execute one batch in which it returns the given static DF */
   private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
     require(!triggerDF.isStreaming)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index f4a62903ebeb16d67b87ab34a89ac442095e757f..acac0bfb0e253ed1eb461b00052f4ae9a8aa7eda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -339,7 +339,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
       .start()
     q.stop()
 
-    assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000))
+    assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.trigger == ProcessingTime(10000))
 
     q = df.writeStream
       .format("org.apache.spark.sql.streaming.test")
@@ -348,7 +348,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
       .start()
     q.stop()
 
-    assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000))
+    assert(q.asInstanceOf[StreamingQueryWrapper].streamingQuery.trigger == ProcessingTime(100000))
   }
 
   test("source metadataPath") {