From 7ed1bf4b485131d58ea6728e7247b79320aca9e6 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Wed, 16 Jan 2013 19:15:14 -0800
Subject: [PATCH] Add RDD checkpointing to Python API.

---
 .../scala/spark/api/python/PythonRDD.scala    |  3 --
 python/epydoc.conf                            |  2 +-
 python/pyspark/context.py                     |  9 ++++
 python/pyspark/rdd.py                         | 34 ++++++++++++++
 python/pyspark/tests.py                       | 46 +++++++++++++++++++
 python/run-tests                              |  3 ++
 6 files changed, 93 insertions(+), 4 deletions(-)
 create mode 100644 python/pyspark/tests.py

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 89f7c316dc..8c38262dd8 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest](
     }
   }
 
-  override def checkpoint() { }
-
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
@@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
       case Seq(a, b) => (a, b)
       case x          => throw new Exception("PairwiseRDD: unexpected value: " + x)
     }
-  override def checkpoint() { }
   val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 91ac984ba2..45102cd9fe 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -16,4 +16,4 @@ target: docs/
 private: no
 
 exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
-         pyspark.java_gateway pyspark.examples pyspark.shell
+         pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1e2f845f9c..a438b43fdc 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -195,3 +195,12 @@ class SparkContext(object):
         filename = path.split("/")[-1]
         os.environ["PYTHONPATH"] = \
             "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+    def setCheckpointDir(self, dirName, useExisting=False):
+        """
+        Set the directory under which RDDs are going to be checkpointed. This
+        method will create the directory and throw an exception if the path
+        already exists, to avoid accidentally overwriting existing checkpoint
+        files. Pass useExisting=True to reuse an existing directory instead.
+        """
+        self._jsc.sc().setCheckpointDir(dirName, useExisting)
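A minimal usage sketch for the new setCheckpointDir() call. The application
name and directory path below are illustrative only; per the docstring above,
the directory must not already exist unless useExisting=True is passed.

    from pyspark.context import SparkContext

    sc = SparkContext('local', 'CheckpointDirExample')
    # Creates the checkpoint directory; throws if the path already exists,
    # unless useExisting=True is passed.
    sc.setCheckpointDir('/tmp/pyspark-checkpoints')
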
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d705f0f9e1..9b676cae4a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -49,6 +49,40 @@ class RDD(object):
         self._jrdd.cache()
         return self
 
+    def checkpoint(self):
+        """
+        Mark this RDD for checkpointing. The RDD will be saved to a file inside
+        `checkpointDir` (set using setCheckpointDir()) and all references to
+        its parent RDDs will be removed.  This is used to truncate very long
+        lineages.  In the current implementation, Spark will save this RDD to
+        a file (using saveAsObjectFile()) after the first job using this RDD is
+        done.  Hence, it is strongly recommended that:
+
+        (i) checkpoint() is called before any job has been executed on this
+        RDD, and
+
+        (ii) this RDD has been persisted in memory, since saving an
+        unpersisted RDD to a file will require recomputing it from its
+        lineage.
+        """
+        self._jrdd.rdd().checkpoint()
+
+    def isCheckpointed(self):
+        """
+        Return whether this RDD has been checkpointed or not
+        """
+        return self._jrdd.rdd().isCheckpointed()
+
+    def getCheckpointFile(self):
+        """
+        Gets the name of the file to which this RDD was checkpointed
+        """
+        checkpointFile = self._jrdd.rdd().getCheckpointFile()
+        if checkpointFile.isDefined():
+            return checkpointFile.get()
+        else:
+            return None
+
     # TODO persist(self, storageLevel)
 
     def map(self, f, preservesPartitioning=False):
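The intended flow from Python, sketched under the assumption of an existing
SparkContext named sc whose checkpoint directory has already been set with
setCheckpointDir() (the RDD contents below are illustrative):

    rdd = sc.parallelize(range(100)).flatMap(lambda x: range(x))
    rdd.cache()        # persist first so the checkpoint does not recompute the RDD
    rdd.checkpoint()   # mark for checkpointing before running any job
    rdd.count()        # the first job triggers the checkpoint write
    # The write happens after the first job finishes, so isCheckpointed() may
    # briefly return False (the unit test below waits one second).
    print(rdd.isCheckpointed())
    print(rdd.getCheckpointFile())   # file path under the checkpoint directory
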
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
new file mode 100644
index 0000000000..c959d5dec7
--- /dev/null
+++ b/python/pyspark/tests.py
@@ -0,0 +1,46 @@
+"""
+Unit tests for PySpark; additional tests are implemented as doctests in
+individual modules.
+"""
+import atexit
+import os
+import shutil
+from tempfile import NamedTemporaryFile
+import time
+import unittest
+
+from pyspark.context import SparkContext
+
+
+class TestCheckpoint(unittest.TestCase):
+
+    def setUp(self):
+        self.sc = SparkContext('local[4]', 'TestCheckpoint', batchSize=2)
+
+    def tearDown(self):
+        self.sc.stop()
+
+    def test_basic_checkpointing(self):
+        checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(checkpointDir.name)
+        self.sc.setCheckpointDir(checkpointDir.name)
+
+        parCollection = self.sc.parallelize([1, 2, 3, 4])
+        flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
+
+        self.assertFalse(flatMappedRDD.isCheckpointed())
+        self.assertIsNone(flatMappedRDD.getCheckpointFile())
+
+        flatMappedRDD.checkpoint()
+        result = flatMappedRDD.collect()
+        time.sleep(1)  # allow time for the checkpoint file to be written
+        self.assertTrue(flatMappedRDD.isCheckpointed())
+        self.assertEqual(flatMappedRDD.collect(), result)
+        self.assertEqual(checkpointDir.name,
+                         os.path.dirname(flatMappedRDD.getCheckpointFile()))
+
+        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/run-tests b/python/run-tests
index 32470911f9..ce214e98a8 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
 $FWDIR/pyspark -m doctest pyspark/accumulators.py
 FAILED=$(($?||$FAILED))
 
+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
 if [[ $FAILED != 0 ]]; then
     echo -en "\033[31m"  # Red
     echo "Had test failures; see logs."
-- 
GitLab