diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 7e75e7083acb51b912cdb1f2e26e67626f679d50..4b90fbdf0ce7e46fbcdffbee0f2f866400886c2d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -142,8 +142,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
     this.k = k
     this.vocabSize = docs.take(1).head._2.size
     this.checkpointInterval = lda.getCheckpointInterval
-    this.graphCheckpointer = new
-      PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
+    this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+      checkpointInterval, graph.vertices.sparkContext)
+    this.graphCheckpointer.update(this.graph)
     this.globalTopicTotals = computeGlobalTopicTotals()
     this
   }
@@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
     // Update the vertex descriptors with the new counts.
     val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
     graph = newGraph
-    graphCheckpointer.updateGraph(newGraph)
+    graphCheckpointer.update(newGraph)
     globalTopicTotals = computeGlobalTopicTotals()
     this
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
new file mode 100644
index 0000000000000000000000000000000000000000..72d3aabc9b1f4a454f798864dbf5845e80771442
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
+ * (such as Graphs and DataFrames).  In documentation, we use the phrase "Dataset" to refer to
+ * the distributed data type (RDD, Graph, etc.).
+ *
+ * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
+ * as well as unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created,
+ * before the Dataset has been materialized.  After updating [[PeriodicCheckpointer]], users are
+ * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which Datasets should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later Datasets have been checkpointed.
+ *    However, references to the older Datasets will still return isCheckpointed = true.
+ *
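+ * Example of the intended call pattern, shown here with the [[PeriodicRDDCheckpointer]]
+ * subclass (a sketch only; `initialRDD`, `nextRDD`, and `numIterations` are illustrative,
+ * not part of this API):
+ * {{{
+ *  val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)
+ *  var data = initialRDD
+ *  checkpointer.update(data)
+ *  data.count()                 // materialize so that persisting/checkpointing take effect
+ *  for (_ <- 1 until numIterations) {
+ *    data = nextRDD(data)       // new Dataset derived from the previous one in the lineage
+ *    checkpointer.update(data)  // call before materializing the new Dataset
+ *    data.count()
+ *  }
+ *  checkpointer.deleteAllCheckpoints()  // remove any remaining checkpoint files at the end
+ * }}}
+ *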
+ * @param checkpointInterval  Datasets will be checkpointed at this interval
+ * @param sc  SparkContext for the Datasets given to this checkpointer
+ * @tparam T  Dataset type, such as RDD[Double]
+ */
+private[mllib] abstract class PeriodicCheckpointer[T](
+    val checkpointInterval: Int,
+    val sc: SparkContext) extends Logging {
+
+  /** FIFO queue of past checkpointed Datasets */
+  private val checkpointQueue = mutable.Queue[T]()
+
+  /** FIFO queue of past persisted Datasets */
+  private val persistedQueue = mutable.Queue[T]()
+
+  /** Number of times [[update()]] has been called */
+  private var updateCount = 0
+
+  /**
+   * Update with a new Dataset. Handle persistence and checkpointing as needed.
+   * Since this handles persistence and checkpointing, this should be called before the Dataset
+   * has been materialized.
+   *
+   * @param newData  New Dataset created from previous Datasets in the lineage.
+   */
+  def update(newData: T): Unit = {
+    persist(newData)
+    persistedQueue.enqueue(newData)
+    // We try to maintain at most 3 Datasets in persistedQueue (the new Dataset plus up to
+    // 2 older ones) to support the semantics of this class: Users should call [[update()]]
+    // when a new Dataset has been created, before the Dataset has been materialized.
+    while (persistedQueue.size > 3) {
+      val dataToUnpersist = persistedQueue.dequeue()
+      unpersist(dataToUnpersist)
+    }
+    updateCount += 1
+
+    // Handle checkpointing (after persisting)
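+    // Checkpointing is silently skipped unless the checkpoint interval has been reached and
+    // a checkpoint directory has been set on the SparkContext (e.g. via sc.setCheckpointDir).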
+    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+      // Add new checkpoint before removing old checkpoints.
+      checkpoint(newData)
+      checkpointQueue.enqueue(newData)
+      // Remove checkpoints before the latest one.
+      var canDelete = true
+      while (checkpointQueue.size > 1 && canDelete) {
+        // Delete the oldest checkpoint only if the next checkpoint exists.
+        if (isCheckpointed(checkpointQueue.head)) {
+          removeCheckpointFile()
+        } else {
+          canDelete = false
+        }
+      }
+    }
+  }
+
+  /** Checkpoint the Dataset */
+  protected def checkpoint(data: T): Unit
+
+  /** Return true iff the Dataset is checkpointed */
+  protected def isCheckpointed(data: T): Boolean
+
+  /**
+   * Persist the Dataset.
+   * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
+   */
+  protected def persist(data: T): Unit
+
+  /** Unpersist the Dataset */
+  protected def unpersist(data: T): Unit
+
+  /** Get the list of checkpoint files for the given Dataset */
+  protected def getCheckpointFiles(data: T): Iterable[String]
+
+  /**
+   * Call this at the end to delete any remaining checkpoint files.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.nonEmpty) {
+      removeCheckpointFile()
+    }
+  }
+
+  /**
+   * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
+   * This prints a warning but does not fail if the files cannot be removed.
+   */
+  private def removeCheckpointFile(): Unit = {
+    val old = checkpointQueue.dequeue()
+    // Since the old checkpoint is not deleted by Spark, we manually delete it.
+    val fs = FileSystem.get(sc.hadoopConfiguration)
+    getCheckpointFiles(old).foreach { checkpointFile =>
+      try {
+        fs.delete(new Path(checkpointFile), true)
+      } catch {
+        case e: Exception =>
+          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+            checkpointFile)
+      }
+    }
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 6e5dd119dd653a8ff94ab9951974a56e26379a9b..11a059536c50cb21966f77a2cae6a4dea0963730 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -17,11 +17,7 @@
 
 package org.apache.spark.mllib.impl
 
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{Path, FileSystem}
-
-import org.apache.spark.Logging
+import org.apache.spark.SparkContext
 import org.apache.spark.graphx.Graph
 import org.apache.spark.storage.StorageLevel
 
@@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel
  * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
  * unpersisting and removing checkpoint files.
  *
- * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * Users should call update() when a new graph has been created,
  * before the graph has been materialized.  After updating [[PeriodicGraphCheckpointer]], users are
  * responsible for materializing the graph to ensure that persisting and checkpointing actually
  * occur.
  *
- * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ * When update() is called, this does the following:
  *  - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
  *  - Unpersist graphs from queue until there are at most 3 persisted graphs.
  *  - If using checkpointing and the checkpoint interval has been reached,
@@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
  * Example usage:
  * {{{
  *  val (graph1, graph2, graph3, ...) = ...
- *  val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
+ *  val cp = new PeriodicGraphCheckpointer(2, sc)
+ *  cp.update(graph1)
  *  graph1.vertices.count(); graph1.edges.count()
  *  // persisted: graph1
- *  cp.updateGraph(graph2)
+ *  cp.update(graph2)
@@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel
  *  // checkpointed: graph4
  * }}}
  *
- * @param currentGraph  Initial graph
  * @param checkpointInterval Graphs will be checkpointed at this interval
  * @tparam VD  Vertex descriptor type
  * @tparam ED  Edge descriptor type
  *
- * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ * TODO: Move this out of MLlib?
  */
 private[mllib] class PeriodicGraphCheckpointer[VD, ED](
-    var currentGraph: Graph[VD, ED],
-    val checkpointInterval: Int) extends Logging {
-
-  /** FIFO queue of past checkpointed RDDs */
-  private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
-
-  /** FIFO queue of past persisted RDDs */
-  private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
-
-  /** Number of times [[updateGraph()]] has been called */
-  private var updateCount = 0
-
-  /**
-   * Spark Context for the Graphs given to this checkpointer.
-   * NOTE: This code assumes that only one SparkContext is used for the given graphs.
-   */
-  private val sc = currentGraph.vertices.sparkContext
+    checkpointInterval: Int,
+    sc: SparkContext)
+  extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
 
-  updateGraph(currentGraph)
+  override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
 
-  /**
-   * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
-   * Since this handles persistence and checkpointing, this should be called before the graph
-   * has been materialized.
-   *
-   * @param newGraph  New graph created from previous graphs in the lineage.
-   */
-  def updateGraph(newGraph: Graph[VD, ED]): Unit = {
-    if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
-      newGraph.persist()
-    }
-    persistedQueue.enqueue(newGraph)
-    // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
-    // Users should call [[updateGraph()]] when a new graph has been created,
-    // before the graph has been materialized.
-    while (persistedQueue.size > 3) {
-      val graphToUnpersist = persistedQueue.dequeue()
-      graphToUnpersist.unpersist(blocking = false)
-    }
-    updateCount += 1
+  override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
 
-    // Handle checkpointing (after persisting)
-    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
-      // Add new checkpoint before removing old checkpoints.
-      newGraph.checkpoint()
-      checkpointQueue.enqueue(newGraph)
-      // Remove checkpoints before the latest one.
-      var canDelete = true
-      while (checkpointQueue.size > 1 && canDelete) {
-        // Delete the oldest checkpoint only if the next checkpoint exists.
-        if (checkpointQueue.get(1).get.isCheckpointed) {
-          removeCheckpointFile()
-        } else {
-          canDelete = false
-        }
-      }
+  override protected def persist(data: Graph[VD, ED]): Unit = {
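+    // Graph.persist() persists both the vertices and edges RDDs; the vertices' storage level
+    // is used here as a proxy for whether the graph has already been persisted.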
+    if (data.vertices.getStorageLevel == StorageLevel.NONE) {
+      data.persist()
     }
   }
 
-  /**
-   * Call this at the end to delete any remaining checkpoint files.
-   */
-  def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.size > 0) {
-      removeCheckpointFile()
-    }
-  }
+  override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
 
-  /**
-   * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
-   * This prints a warning but does not fail if the files cannot be removed.
-   */
-  private def removeCheckpointFile(): Unit = {
-    val old = checkpointQueue.dequeue()
-    // Since the old checkpoint is not deleted by Spark, we manually delete it.
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    old.getCheckpointFiles.foreach { checkpointFile =>
-      try {
-        fs.delete(new Path(checkpointFile), true)
-      } catch {
-        case e: Exception =>
-          logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
-            checkpointFile)
-      }
-    }
+  override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
+    data.getCheckpointFiles
   }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f31ed2aa90a6420bc079e2b4a0c4c53d91f34e47
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing RDDs.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new RDD has been created,
+ * before the RDD has been materialized.  After updating [[PeriodicRDDCheckpointer]], users are
+ * responsible for materializing the RDD to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
+ *  - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which RDDs should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later RDDs have been checkpointed.
+ *    However, references to the older RDDs will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ *  val (rdd1, rdd2, rdd3, ...) = ...
+ *  val cp = new PeriodicRDDCheckpointer(2, sc)
+ *  cp.update(rdd1)
+ *  rdd1.count();
+ *  // persisted: rdd1
+ *  cp.update(rdd2)
+ *  rdd2.count();
+ *  // persisted: rdd1, rdd2
+ *  // checkpointed: rdd2
+ *  cp.update(rdd3)
+ *  rdd3.count();
+ *  // persisted: rdd1, rdd2, rdd3
+ *  // checkpointed: rdd2
+ *  cp.update(rdd4)
+ *  rdd4.count();
+ *  // persisted: rdd2, rdd3, rdd4
+ *  // checkpointed: rdd4
+ *  cp.update(rdd5)
+ *  rdd5.count();
+ *  // persisted: rdd3, rdd4, rdd5
+ *  // checkpointed: rdd4
+ * }}}
+ *
+ * @param checkpointInterval  RDDs will be checkpointed at this interval
+ * @tparam T  RDD element type
+ *
+ * TODO: Move this out of MLlib?
+ */
+private[mllib] class PeriodicRDDCheckpointer[T](
+    checkpointInterval: Int,
+    sc: SparkContext)
+  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
+
+  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
+
+  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
+
+  override protected def persist(data: RDD[T]): Unit = {
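+    // Persist only if the RDD is not already persisted, so we do not try to change a storage
+    // level the caller has already chosen.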
+    if (data.getStorageLevel == StorageLevel.NONE) {
+      data.persist()
+    }
+  }
+
+  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
+
+  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
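+    // getCheckpointFile returns an Option[String], which is implicitly converted to an
+    // Iterable containing zero or one checkpoint file paths.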
+    data.getCheckpointFile.map(x => x)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index d34888af2d73b3314e69075396f377324d19f9ca..e331c7598918751e0300f140e3ab3a4c1c5654d3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
 
   import PeriodicGraphCheckpointerSuite._
 
-  // TODO: Do I need to call count() on the graphs' RDDs?
-
   test("Persisting") {
     var graphsToCheck = Seq.empty[GraphToCheck]
 
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
+    val checkpointer =
+      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+    checkpointer.update(graph1)
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
     checkPersistence(graphsToCheck, 1)
 
     var iteration = 2
     while (iteration < 9) {
       val graph = createGraph(sc)
-      checkpointer.updateGraph(graph)
+      checkpointer.update(graph)
       graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
       checkPersistence(graphsToCheck, iteration)
       iteration += 1
@@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     var graphsToCheck = Seq.empty[GraphToCheck]
     sc.setCheckpointDir(path)
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
+    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+      checkpointInterval, graph1.vertices.sparkContext)
+    checkpointer.update(graph1)
     graph1.edges.count()
     graph1.vertices.count()
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
@@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     var iteration = 2
     while (iteration < 9) {
       val graph = createGraph(sc)
-      checkpointer.updateGraph(graph)
+      checkpointer.update(graph)
       graph.vertices.count()
       graph.edges.count()
       graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
@@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite {
       } else {
         // Graph should never be checkpointed
         assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
-        assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+        assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
       }
     } catch {
       case e: AssertionError =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b2a459a68b5fa72b421d82bb7941a8a3b8201cdc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import PeriodicRDDCheckpointerSuite._
+
+  test("Persisting") {
+    var rddsToCheck = Seq.empty[RDDToCheck]
+
+    val rdd1 = createRDD(sc)
+    val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
+    checkpointer.update(rdd1)
+    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+    checkPersistence(rddsToCheck, 1)
+
+    var iteration = 2
+    while (iteration < 9) {
+      val rdd = createRDD(sc)
+      checkpointer.update(rdd)
+      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+      checkPersistence(rddsToCheck, iteration)
+      iteration += 1
+    }
+  }
+
+  test("Checkpointing") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    val checkpointInterval = 2
+    var rddsToCheck = Seq.empty[RDDToCheck]
+    sc.setCheckpointDir(path)
+    val rdd1 = createRDD(sc)
+    val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
+    checkpointer.update(rdd1)
+    rdd1.count()
+    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+    checkCheckpoint(rddsToCheck, 1, checkpointInterval)
+
+    var iteration = 2
+    while (iteration < 9) {
+      val rdd = createRDD(sc)
+      checkpointer.update(rdd)
+      rdd.count()
+      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+      checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
+      iteration += 1
+    }
+
+    checkpointer.deleteAllCheckpoints()
+    rddsToCheck.foreach { rdd =>
+      confirmCheckpointRemoved(rdd.rdd)
+    }
+
+    Utils.deleteRecursively(tempDir)
+  }
+}
+
+private object PeriodicRDDCheckpointerSuite {
+
+  case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
+
+  def createRDD(sc: SparkContext): RDD[Double] = {
+    sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
+  }
+
+  def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
+    rdds.foreach { g =>
+      checkPersistence(g.rdd, g.gIndex, iteration)
+    }
+  }
+
+  /**
+   * Check storage level of rdd.
+   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
+   * @param iteration  Total number of rdds inserted into checkpointer.
+   */
+  def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
+    try {
+      if (gIndex + 2 < iteration) {
+        assert(rdd.getStorageLevel == StorageLevel.NONE)
+      } else {
+        assert(rdd.getStorageLevel != StorageLevel.NONE)
+      }
+    } catch {
+      case _: AssertionError =>
+        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
+          s"\t gIndex = $gIndex\n" +
+          s"\t iteration = $iteration\n" +
+          s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
+    }
+  }
+
+  def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+    rdds.reverse.foreach { g =>
+      checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
+    }
+  }
+
+  def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
+    // Note: We cannot check rdd.isCheckpointed since that value is never updated.
+    //       Instead, we check for the presence of the checkpoint files.
+    //       This test should continue to work even after this rdd.isCheckpointed issue
+    //       is fixed (though it can then be simplified and not look for the files).
+    val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
+    rdd.getCheckpointFile.foreach { checkpointFile =>
+      assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
+    }
+  }
+
+  /**
+   * Check checkpointed status of rdd.
+   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
+   * @param iteration  Total number of rdds inserted into checkpointer.
+   */
+  def checkCheckpoint(
+      rdd: RDD[_],
+      gIndex: Int,
+      iteration: Int,
+      checkpointInterval: Int): Unit = {
+    try {
+      if (gIndex % checkpointInterval == 0) {
+        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
+        // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
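+        // E.g., with checkpointInterval = 2 and iteration = 7: rdd2, rdd4, and rdd6 were
+        // checkpointed; gIndex = 4 and 6 fall within the window (gIndex > 7 - 2 * 2), while
+        // rdd2's checkpoint files must already have been removed.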
+        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+          assert(rdd.isCheckpointed, "RDD should be checkpointed")
+          assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
+        } else {
+          confirmCheckpointRemoved(rdd)
+        }
+      } else {
+        // RDD should never be checkpointed
+        assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
+        assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
+      }
+    } catch {
+      case e: AssertionError =>
+        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
+          s"\t gIndex = $gIndex\n" +
+          s"\t iteration = $iteration\n" +
+          s"\t checkpointInterval = $checkpointInterval\n" +
+          s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
+          s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
+          s"  AssertionError message: ${e.getMessage}")
+    }
+  }
+
+}