diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index b959444b492981e3f69a846644f7426c3c348cd7..daed1dcb7737026a2617c054abd8fc2861acd24b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -222,14 +222,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
 
       val sink = new MemorySink(df.schema, outputMode)
       val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink))
+      val chkpointLoc = extraOptions.get("checkpointLocation")
+      val recoverFromChkpoint = chkpointLoc.isDefined && outputMode == OutputMode.Complete()
       val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
         extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
+        chkpointLoc,
         df,
         sink,
         outputMode,
         useTempCheckpointLocation = true,
-        recoverFromCheckpointLocation = false,
+        recoverFromCheckpointLocation = recoverFromChkpoint,
         trigger = trigger)
       resultDf.createOrReplaceTempView(query.name)
       query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index f0994395813e46a1fb91674b9529e10b1889e3c0..5630464f4080395f4c99ad03b69559e5d510c8e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.streaming.test
 
+import java.io.File
 import java.util.concurrent.TimeUnit
 
 import scala.concurrent.duration._
@@ -467,4 +468,68 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
     val sq = df.writeStream.format("console").start()
     sq.stop()
   }
+
+  test("MemorySink can recover from a checkpoint in Complete Mode") {
+    import testImplicits._
+    val ms = new MemoryStream[Int](0, sqlContext)
+    val df = ms.toDF().toDF("a")
+    val checkpointLoc = newMetadataDir
+    val checkpointDir = new File(checkpointLoc, "offsets")
+    checkpointDir.mkdirs()
+    assert(checkpointDir.exists())
+    val tableName = "test"
+    def startQuery: StreamingQuery = {
+      df.groupBy("a")
+        .count()
+        .writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option("checkpointLocation", checkpointLoc)
+        .outputMode("complete")
+        .start()
+    }
+    // no exception here
+    val q = startQuery
+    ms.addData(0, 1)
+    q.processAllAvailable()
+    q.stop()
+
+    checkAnswer(
+      spark.table(tableName),
+      Seq(Row(0, 1), Row(1, 1))
+    )
+    spark.sql(s"drop table $tableName")
+    // verify table is dropped
+    intercept[AnalysisException](spark.table(tableName).collect())
+    val q2 = startQuery
+    ms.addData(0)
+    q2.processAllAvailable()
+    checkAnswer(
+      spark.table(tableName),
+      Seq(Row(0, 2), Row(1, 1))
+    )
+
+    q2.stop()
+  }
+
+  test("append mode memory sink's do not support checkpoint recovery") {
+    import testImplicits._
+    val ms = new MemoryStream[Int](0, sqlContext)
+    val df = ms.toDF().toDF("a")
+    val checkpointLoc = newMetadataDir
+    val checkpointDir = new File(checkpointLoc, "offsets")
+    checkpointDir.mkdirs()
+    assert(checkpointDir.exists())
+
+    val e = intercept[AnalysisException] {
+      df.writeStream
+        .format("memory")
+        .queryName("test")
+        .option("checkpointLocation", checkpointLoc)
+        .outputMode("append")
+        .start()
+    }
+    assert(e.getMessage.contains("does not support recovering"))
+    assert(e.getMessage.contains("checkpoint location"))
+  }
 }