Commit 00d70cd6 authored by Josh Rosen

Clean up setup code in PySpark checkpointing tests

parent 5b6ea9e9
@@ -691,7 +691,7 @@ class PipelinedRDD(RDD):
     20
     """
     def __init__(self, prev, func, preservesPartitioning=False):
-        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable:
+        if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
             prev_func = prev.func
             def pipeline_func(split, iterator):
                 return func(split, prev_func(split, iterator))
@@ -737,7 +737,6 @@ class PipelinedRDD(RDD):
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
-    @property
     def _is_pipelinable(self):
         return not (self.is_cached or self.is_checkpointed)
 
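Why the call site gains parentheses: with the @property decorator removed, a bare reference to prev._is_pipelinable no longer runs the check but returns the bound method object, which is always truthy, so a cached or checkpointed RDD would silently be pipelined. A minimal sketch of the pitfall (toy class for illustration, not the actual PySpark code):

    class Toy:
        def __init__(self):
            self.is_cached = True   # should make the object non-pipelinable

        def _is_pipelinable(self):  # plain method, as in the patched code
            return not self.is_cached

    t = Toy()
    print(bool(t._is_pipelinable))   # True  -- a bound method object is always truthy (the trap)
    print(t._is_pipelinable())       # False -- the intended check, as called in __init__ above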
@@ -2,7 +2,6 @@
 Unit tests for PySpark; additional tests are implemented as doctests in
 individual modules.
 """
-import atexit
 import os
 import shutil
 from tempfile import NamedTemporaryFile
@@ -16,18 +15,18 @@ class TestCheckpoint(unittest.TestCase):
 
     def setUp(self):
         self.sc = SparkContext('local[4]', 'TestPartitioning', batchSize=2)
+        self.checkpointDir = NamedTemporaryFile(delete=False)
+        os.unlink(self.checkpointDir.name)
+        self.sc.setCheckpointDir(self.checkpointDir.name)
 
     def tearDown(self):
         self.sc.stop()
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")
+        shutil.rmtree(self.checkpointDir.name)
 
     def test_basic_checkpointing(self):
-        checkpointDir = NamedTemporaryFile(delete=False)
-        os.unlink(checkpointDir.name)
-        self.sc.setCheckpointDir(checkpointDir.name)
-
         parCollection = self.sc.parallelize([1, 2, 3, 4])
         flatMappedRDD = parCollection.flatMap(lambda x: range(1, x + 1))
@@ -39,16 +38,10 @@ class TestCheckpoint(unittest.TestCase):
         time.sleep(1)  # 1 second
         self.assertTrue(flatMappedRDD.isCheckpointed())
         self.assertEqual(flatMappedRDD.collect(), result)
-        self.assertEqual(checkpointDir.name,
+        self.assertEqual(self.checkpointDir.name,
                          os.path.dirname(flatMappedRDD.getCheckpointFile()))
-
-        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
 
     def test_checkpoint_and_restore(self):
-        checkpointDir = NamedTemporaryFile(delete=False)
-        os.unlink(checkpointDir.name)
-        self.sc.setCheckpointDir(checkpointDir.name)
-
         parCollection = self.sc.parallelize([1, 2, 3, 4])
         flatMappedRDD = parCollection.flatMap(lambda x: [x])
@@ -63,8 +56,6 @@ class TestCheckpoint(unittest.TestCase):
         recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
         self.assertEquals([1, 2, 3, 4], recovered.collect())
 
-        atexit.register(lambda: shutil.rmtree(checkpointDir.name))
-
 
 if __name__ == "__main__":
     unittest.main()
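The tests.py half of the commit replaces per-test directory handling (each test created its own checkpoint directory and deferred cleanup to an atexit hook) with a single fixture: setUp creates the directory before every test and tearDown removes it as soon as the test finishes. A standalone sketch of that fixture pattern under plain unittest (no SparkContext; the DirFixtureTest name and the os.mkdir stand-in are illustrative only):

    import os
    import shutil
    import unittest
    from tempfile import NamedTemporaryFile


    class DirFixtureTest(unittest.TestCase):
        def setUp(self):
            # Reserve a unique path, then delete the file so only the name is kept;
            # in the real tests the name is handed to sc.setCheckpointDir().
            self.checkpointDir = NamedTemporaryFile(delete=False)
            os.unlink(self.checkpointDir.name)
            os.mkdir(self.checkpointDir.name)  # stand-in for the directory Spark would create

        def tearDown(self):
            # Remove the directory immediately rather than deferring cleanup to atexit.
            shutil.rmtree(self.checkpointDir.name)

        def test_directory_exists(self):
            self.assertTrue(os.path.isdir(self.checkpointDir.name))


    if __name__ == "__main__":
        unittest.main()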