Skip to content
Snippets Groups Projects
Commit d0ba80dc authored by Josh Rosen's avatar Josh Rosen
Browse files

Add checkpointFile() and more tests to PySpark.

parent 7ed1bf4b
No related branches found
No related tags found
No related merge requests found
...@@ -123,6 +123,10 @@ class SparkContext(object): ...@@ -123,6 +123,10 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits) jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self) return RDD(jrdd, self)
def _checkpointFile(self, name):
    """Reconstruct an RDD from data previously written by checkpointing.

    :param name: path of the checkpoint file/directory to load.
    :return: an :class:`RDD` backed by the checkpointed data.
    """
    java_rdd = self._jsc.checkpointFile(name)
    return RDD(java_rdd, self)
def union(self, rdds): def union(self, rdds):
""" """
Build the union of a list of RDDs. Build the union of a list of RDDs.
...@@ -145,7 +149,7 @@ class SparkContext(object): ...@@ -145,7 +149,7 @@ class SparkContext(object):
def accumulator(self, value, accum_param=None): def accumulator(self, value, accum_param=None):
""" """
Create an C{Accumulator} with the given initial value, using a given Create an C{Accumulator} with the given initial value, using a given
AccumulatorParam helper object to define how to add values of the data AccumulatorParam helper object to define how to add values of the data
type if provided. Default AccumulatorParams are used for integers and type if provided. Default AccumulatorParams are used for integers and
floating-point numbers if you do not provide one. For other types, the floating-point numbers if you do not provide one. For other types, the
AccumulatorParam must implement two methods: AccumulatorParam must implement two methods:
......
...@@ -32,6 +32,7 @@ class RDD(object): ...@@ -32,6 +32,7 @@ class RDD(object):
def __init__(self, jrdd, ctx): def __init__(self, jrdd, ctx):
self._jrdd = jrdd self._jrdd = jrdd
self.is_cached = False self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx self.ctx = ctx
@property @property
...@@ -65,6 +66,7 @@ class RDD(object): ...@@ -65,6 +66,7 @@ class RDD(object):
(ii) This RDD has been made to persist in memory. Otherwise saving it (ii) This RDD has been made to persist in memory. Otherwise saving it
on a file will require recomputation. on a file will require recomputation.
""" """
self.is_checkpointed = True
self._jrdd.rdd().checkpoint() self._jrdd.rdd().checkpoint()
def isCheckpointed(self): def isCheckpointed(self):
...@@ -696,7 +698,7 @@ class PipelinedRDD(RDD): ...@@ -696,7 +698,7 @@ class PipelinedRDD(RDD):
20 20
""" """
def __init__(self, prev, func, preservesPartitioning=False): def __init__(self, prev, func, preservesPartitioning=False):
if isinstance(prev, PipelinedRDD) and not prev.is_cached: if isinstance(prev, PipelinedRDD) and prev._is_pipelinable:
prev_func = prev.func prev_func = prev.func
def pipeline_func(split, iterator): def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator)) return func(split, prev_func(split, iterator))
...@@ -709,6 +711,7 @@ class PipelinedRDD(RDD): ...@@ -709,6 +711,7 @@ class PipelinedRDD(RDD):
self.preservesPartitioning = preservesPartitioning self.preservesPartitioning = preservesPartitioning
self._prev_jrdd = prev._jrdd self._prev_jrdd = prev._jrdd
self.is_cached = False self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx self.ctx = prev.ctx
self.prev = prev self.prev = prev
self._jrdd_val = None self._jrdd_val = None
...@@ -741,6 +744,10 @@ class PipelinedRDD(RDD): ...@@ -741,6 +744,10 @@ class PipelinedRDD(RDD):
self._jrdd_val = python_rdd.asJavaRDD() self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val return self._jrdd_val
@property
def _is_pipelinable(self):
    """True when this RDD may be fused into a downstream pipeline.

    An RDD that has been cached or checkpointed pins its computed
    results, so it must not be collapsed into a child's pipeline.
    """
    return not self.is_cached and not self.is_checkpointed
def _test(): def _test():
import doctest import doctest
......
...@@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase): ...@@ -19,6 +19,9 @@ class TestCheckpoint(unittest.TestCase):
def tearDown(self):
    """Stop the SparkContext and free its Akka port after each test."""
    self.sc.stop()
    # Akka does not unbind its port immediately on shutdown; clear the
    # property so the next test's context does not try to rebind it.
    self.sc.jvm.System.clearProperty("spark.master.port")
def test_basic_checkpointing(self): def test_basic_checkpointing(self):
checkpointDir = NamedTemporaryFile(delete=False) checkpointDir = NamedTemporaryFile(delete=False)
...@@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase): ...@@ -41,6 +44,27 @@ class TestCheckpoint(unittest.TestCase):
atexit.register(lambda: shutil.rmtree(checkpointDir.name)) atexit.register(lambda: shutil.rmtree(checkpointDir.name))
def test_checkpoint_and_restore(self):
    """Checkpoint an RDD, then rebuild it from its checkpoint file."""
    # Reserve a unique path: create a temp file, then unlink it so
    # Spark can create a checkpoint directory under that name.
    checkpointDir = NamedTemporaryFile(delete=False)
    os.unlink(checkpointDir.name)
    self.sc.setCheckpointDir(checkpointDir.name)

    parCollection = self.sc.parallelize([1, 2, 3, 4])
    flatMappedRDD = parCollection.flatMap(lambda x: [x])

    # Nothing is checkpointed until checkpoint() is called and an
    # action forces the computation.
    self.assertFalse(flatMappedRDD.isCheckpointed())
    self.assertIsNone(flatMappedRDD.getCheckpointFile())

    flatMappedRDD.checkpoint()
    flatMappedRDD.count()  # forces a checkpoint to be computed
    time.sleep(1)  # give the checkpoint write a moment to complete
    self.assertIsNotNone(flatMappedRDD.getCheckpointFile())

    # Reload the RDD from its checkpoint file and verify its contents.
    recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
    # assertEqual, not the deprecated assertEquals alias (removed in
    # Python 3.12).
    self.assertEqual([1, 2, 3, 4], recovered.collect())
    atexit.register(lambda: shutil.rmtree(checkpointDir.name))
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment