Commit 7ed1bf4b authored by Josh Rosen

Add RDD checkpointing to Python API.

parent fe85a075
@@ -135,8 +135,6 @@ private[spark] class PythonRDD[T: ClassManifest](
    }
  }
-  override def checkpoint() { }
  val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
@@ -152,7 +150,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
    case Seq(a, b) => (a, b)
    case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
  }
-  override def checkpoint() { }
  val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
...
@@ -16,4 +16,4 @@ target: docs/
private: no
exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
-         pyspark.java_gateway pyspark.examples pyspark.shell
+         pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests
@@ -195,3 +195,12 @@ class SparkContext(object):
        filename = path.split("/")[-1]
        os.environ["PYTHONPATH"] = \
            "%s:%s" % (filename, os.environ["PYTHONPATH"])
+
+    def setCheckpointDir(self, dirName, useExisting=False):
+        """
+        Set the directory under which RDDs are going to be checkpointed. This
+        method will create the directory and will throw an exception if the
+        path already exists (to avoid existing files being overwritten). The
+        directory will be deleted on exit if indicated.
+        """
+        self._jsc.sc().setCheckpointDir(dirName, useExisting)
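
A rough usage sketch of the new SparkContext.setCheckpointDir() call (the application name and directory path below are made up for illustration; the behaviour noted in the comment follows the docstring above):

from pyspark.context import SparkContext

sc = SparkContext('local[2]', 'CheckpointDirExample')
# With useExisting=False (the default), Spark creates the directory and
# raises an error if the path already exists.
sc.setCheckpointDir('/tmp/pyspark-checkpoints')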
@@ -49,6 +49,40 @@ class RDD(object):
        self._jrdd.cache()
        return self

+    def checkpoint(self):
+        """
+        Mark this RDD for checkpointing. The RDD will be saved to a file
+        inside `checkpointDir` (set using setCheckpointDir()) and all
+        references to its parent RDDs will be removed. This is used to
+        truncate very long lineages. In the current implementation, Spark
+        will save this RDD to a file (using saveAsObjectFile()) after the
+        first job using this RDD is done. Hence, it is strongly recommended
+        to use checkpoint() on RDDs only when
+        (i) checkpoint() is called before any job has been executed on this
+        RDD, and
+        (ii) this RDD has been persisted in memory, since otherwise saving it
+        to a file will require recomputation.
+        """
+        self._jrdd.rdd().checkpoint()
+
+    def isCheckpointed(self):
+        """
+        Return whether this RDD has been checkpointed or not.
+        """
+        return self._jrdd.rdd().isCheckpointed()
+
+    def getCheckpointFile(self):
+        """
+        Get the name of the file to which this RDD was checkpointed.
+        """
+        checkpointFile = self._jrdd.rdd().getCheckpointFile()
+        if checkpointFile.isDefined():
+            return checkpointFile.get()
+        else:
+            return None
+
    # TODO persist(self, storageLevel)

    def map(self, f, preservesPartitioning=False):
...
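
Taken together, the three new RDD methods support a flow like the following rough sketch (the data, path, and application name are illustrative; per the docstring above, checkpoint() is called before any job runs and the RDD is cached so the checkpoint write does not recompute it):

from pyspark.context import SparkContext

sc = SparkContext('local[2]', 'CheckpointExample')
sc.setCheckpointDir('/tmp/pyspark-checkpoints')

rdd = sc.parallelize(range(100)).map(lambda x: x * x)
rdd.cache()        # persist in memory so the checkpoint write avoids recomputation
rdd.checkpoint()   # mark for checkpointing before any job has run on this RDD

print(rdd.isCheckpointed())     # False: nothing has been written yet
rdd.count()                     # the first job triggers the checkpoint write
print(rdd.isCheckpointed())     # True once the checkpoint has been written
print(rdd.getCheckpointFile())  # checkpoint file path, or None if not checkpointed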
"""
Unit tests for PySpark; additional tests are implemented as doctests in
individual modules.
"""
import atexit
import os
import shutil
from tempfile import NamedTemporaryFile
import time
import unittest
from pyspark.context import SparkContext
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
def tearDown(self):
self.sc.stop()
def test_basic_checkpointing(self):
checkpointDir = NamedTemporaryFile(delete=False)
os.unlink(checkpointDir.name)
self.sc.setCheckpointDir(checkpointDir.name)
parCollection = self.sc.parallelize([1, 2, 3, 4])
flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
self.assertFalse(flatMappedRDD.isCheckpointed())
self.assertIsNone(flatMappedRDD.getCheckpointFile())
flatMappedRDD.checkpoint()
result = flatMappedRDD.collect()
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.isCheckpointed())
self.assertEqual(flatMappedRDD.collect(), result)
self.assertEqual(checkpointDir.name,
os.path.dirname(flatMappedRDD.getCheckpointFile()))
atexit.register(lambda: shutil.rmtree(checkpointDir.name))
if __name__ == "__main__":
unittest.main()
@@ -14,6 +14,9 @@ FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/accumulators.py
FAILED=$(($?||$FAILED))

+$FWDIR/pyspark -m unittest pyspark.tests
+FAILED=$(($?||$FAILED))
+
if [[ $FAILED != 0 ]]; then
    echo -en "\033[31m" # Red
    echo "Had test failures; see logs."
...