diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index b959444b492981e3f69a846644f7426c3c348cd7..daed1dcb7737026a2617c054abd8fc2861acd24b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -222,14 +222,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val sink = new MemorySink(df.schema, outputMode) val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink)) + val chkpointLoc = extraOptions.get("checkpointLocation") + val recoverFromChkpoint = chkpointLoc.isDefined && outputMode == OutputMode.Complete() val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), - extraOptions.get("checkpointLocation"), + chkpointLoc, df, sink, outputMode, useTempCheckpointLocation = true, - recoverFromCheckpointLocation = false, + recoverFromCheckpointLocation = recoverFromChkpoint, trigger = trigger) resultDf.createOrReplaceTempView(query.name) query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index f0994395813e46a1fb91674b9529e10b1889e3c0..5630464f4080395f4c99ad03b69559e5d510c8e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.streaming.test +import java.io.File import java.util.concurrent.TimeUnit import scala.concurrent.duration._ @@ -467,4 +468,68 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { val sq = df.writeStream.format("console").start() sq.stop() } + + test("MemorySink can recover from a checkpoint in Complete Mode") { + import testImplicits._ + val ms = new MemoryStream[Int](0, sqlContext) + val df = ms.toDF().toDF("a") + val checkpointLoc = newMetadataDir + val checkpointDir = new File(checkpointLoc, "offsets") + checkpointDir.mkdirs() + assert(checkpointDir.exists()) + val tableName = "test" + def startQuery: StreamingQuery = { + df.groupBy("a") + .count() + .writeStream + .format("memory") + .queryName(tableName) + .option("checkpointLocation", checkpointLoc) + .outputMode("complete") + .start() + } + // no exception here + val q = startQuery + ms.addData(0, 1) + q.processAllAvailable() + q.stop() + + checkAnswer( + spark.table(tableName), + Seq(Row(0, 1), Row(1, 1)) + ) + spark.sql(s"drop table $tableName") + // verify table is dropped + intercept[AnalysisException](spark.table(tableName).collect()) + val q2 = startQuery + ms.addData(0) + q2.processAllAvailable() + checkAnswer( + spark.table(tableName), + Seq(Row(0, 2), Row(1, 1)) + ) + + q2.stop() + } + + test("append mode memory sink's do not support checkpoint recovery") { + import testImplicits._ + val ms = new MemoryStream[Int](0, sqlContext) + val df = ms.toDF().toDF("a") + val checkpointLoc = newMetadataDir + val checkpointDir = new File(checkpointLoc, "offsets") + checkpointDir.mkdirs() + assert(checkpointDir.exists()) + + val e = intercept[AnalysisException] { + df.writeStream + .format("memory") + .queryName("test") + .option("checkpointLocation", checkpointLoc) + .outputMode("append") + .start() + } + assert(e.getMessage.contains("does not support recovering")) + assert(e.getMessage.contains("checkpoint location")) + } }