diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 5ac007cd598b9e22086bbc169c8ef7c9178b3738..080aa3b55d268885ec0cd970385c476ae872b064 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -903,11 +903,11 @@ class CheckpointTests(unittest.TestCase): def setup(): conf = SparkConf().set("spark.default.parallelism", 1) sc = SparkContext(conf=conf) - ssc = StreamingContext(sc, 0.5) + ssc = StreamingContext(sc, 2) dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) wc = dstream.updateStateByKey(updater) wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") - wc.checkpoint(.5) + wc.checkpoint(2) self.setupCalled = True return ssc @@ -921,21 +921,22 @@ class CheckpointTests(unittest.TestCase): def check_output(n): while not os.listdir(outputd): - time.sleep(0.01) + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") time.sleep(1) # make sure mtime is larger than the previous one with open(os.path.join(inputd, str(n)), 'w') as f: f.writelines(["%d\n" % i for i in range(10)]) while True: + if self.ssc.awaitTerminationOrTimeout(0.5): + raise Exception("ssc stopped") p = os.path.join(outputd, max(os.listdir(outputd))) if '_SUCCESS' not in os.listdir(p): # not finished - time.sleep(0.01) continue ordd = self.ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) d = ordd.values().map(int).collect() if not d: - time.sleep(0.01) continue self.assertEqual(10, len(d)) s = set(d)