diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 3403f6d20d78998a544ce20e4e11d764e0872967..a0e0267cafa58ef511682967f15cc045951090ee 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -403,6 +403,25 @@ class BasicOperationTests(PySparkStreamingTestCase):
         expected = [[('k', v)] for v in expected]
         self._test_func(input, func, expected)
 
+    def test_failed_func(self):
+        """A mapped function that raises must make awaitTerminationOrTimeout raise."""
+        input = [self.sc.parallelize([d], 1) for d in range(4)]
+        input_stream = self.ssc.queueStream(input)
+
+        def failed_func(i):
+            raise ValueError("failed")
+
+        input_stream.map(failed_func).pprint()
+        self.ssc.start()
+        try:
+            self.ssc.awaitTerminationOrTimeout(10)
+        except Exception:
+            # Expected path: the failure reached the caller. Catch Exception rather than
+            # everything so KeyboardInterrupt/SystemExit are not counted as a pass.
+            return
+
+        self.fail("a failed func should throw an error")
+
 
 class StreamingListenerTests(PySparkStreamingTestCase):
 
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index b20613b1283bd3f17816392486e6fbc6ac00213e..767c732eb90b49e8c980e9061710f642af58bf92 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -64,6 +64,7 @@ class TransformFunction(object):
                 return r._jrdd
         except Exception:
             traceback.print_exc()
+            raise
 
     def __repr__(self):
         return "TransformFunction(%s)" % self.func
@@ -95,6 +96,7 @@ class TransformFunctionSerializer(object):
             return bytearray(self.serializer.dumps((func.func, func.deserializers)))
         except Exception:
             traceback.print_exc()
+            raise
 
     def loads(self, data):
         try:
@@ -102,6 +104,7 @@ class TransformFunctionSerializer(object):
             return TransformFunction(self.ctx, f, *deserializers)
         except Exception:
             traceback.print_exc()
+            raise
 
     def __repr__(self):
         return "TransformFunctionSerializer(%s)" % self.serializer