diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 1a6edf9473d84abf0f1f6237cd2f8dbe72020627..91a43e14a8b1bad7bba47a762ac798ccb843d6e6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -97,6 +97,10 @@ abstract class DStream[T: ClassTag] (
   private[streaming] val mustCheckpoint = false
   private[streaming] var checkpointDuration: Duration = null
   private[streaming] val checkpointData = new DStreamCheckpointData(this)
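+  // Whether checkpoint data has already been restored; @transient so the flag
+  // resets to false on deserialization, allowing restoration after each recovery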
+  @transient
+  private var restoredFromCheckpointData = false
 
   // Reference to whole DStream graph
   private[streaming] var graph: DStreamGraph = null
@@ -507,11 +511,15 @@ abstract class DStream[T: ClassTag] (
    * override the updateCheckpointData() method would also need to override this method.
    */
   private[streaming] def restoreCheckpointData() {
-    // Create RDDs from the checkpoint data
-    logInfo("Restoring checkpoint data")
-    checkpointData.restore()
-    dependencies.foreach(_.restoreCheckpointData())
-    logInfo("Restored checkpoint data")
+    if (!restoredFromCheckpointData) {
+      // Create RDDs from the checkpoint data
+      logInfo("Restoring checkpoint data")
+      checkpointData.restore()
+      dependencies.foreach(_.restoreCheckpointData())
+      restoredFromCheckpointData = true
+      logInfo("Restored checkpoint data")
+    }
   }
 
   @throws(classOf[IOException])
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index cd28d3cf408d5b517ba4a968c595d1bcdaa6141a..f5f446f14a0daa4c383d30f159c3d5bc732d9d5e 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.streaming
 
-import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}
 
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
 import scala.reflect.ClassTag
@@ -34,9 +34,32 @@ import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
-import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils}
+
+/**
+ * An input stream whose checkpoint data records how many times restore() has been invoked.
+ */
+private[streaming]
+class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) {
+  protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
+  override def start(): Unit = { }
+  override def stop(): Unit = { }
+  override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1)))
+  private[streaming]
+  class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
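+    // Transient so the counter restarts from 0 once the checkpoint data is deserialized,
+    // letting the test count only the restore() calls made after recovery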
+    @transient
+    var restoredTimes = 0
+    override def restore() {
+      restoredTimes += 1
+      super.restore()
+    }
+  }
+}
 
 /**
  * A trait of that can be mixed in to get methods for testing DStream operations under
@@ -110,7 +133,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
     new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
   }
 
-  private def generateOutput[V: ClassTag](
+  protected def generateOutput[V: ClassTag](
       ssc: StreamingContext,
       targetBatchTime: Time,
       checkpointDir: String,
@@ -715,6 +738,36 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
     }
   }
 
+  test("DStreamCheckpointData.restore invoking times") {
+    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+      ssc.checkpoint(checkpointDir)
+      val inputDStream = new CheckpointInputDStream(ssc)
+      val checkpointData = inputDStream.checkpointData
+      val mappedDStream = inputDStream.map(_ + 100)
+      val outputStream = new TestOutputStreamWithPartitions(mappedDStream)
+      outputStream.register()
+      // Register two more outputs on the same DStream, so restoreCheckpointData() reaches it via multiple paths
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      mappedDStream.foreachRDD(rdd => rdd.count())
+      assert(checkpointData.restoredTimes === 0)
+      val batchDurationMillis = ssc.progressListener.batchDuration
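+      // Run batches up to 3 * batchDuration with checkpointing enabled, then stop the context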
+      generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true)
+      assert(checkpointData.restoredTimes === 0)
+    }
+    logInfo("*********** RESTARTING ************")
+    withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
+      val checkpointData =
+        ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData
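+      // Recovering from the checkpoint should have invoked restore() exactly once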
+      assert(checkpointData.restoredTimes === 1)
+      ssc.start()
+      ssc.stop()
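+      // Starting and stopping the recovered context must not trigger another restore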
+      assert(checkpointData.restoredTimes === 1)
+    }
+  }
+
   // This tests whether spark can deserialize array object
   // refer to SPARK-5569
   test("recovery from checkpoint contains array object") {