From be7a2cfd978143f6f265eca63e9e24f755bc9f22 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Fri, 20 Nov 2015 14:23:01 -0800
Subject: [PATCH] [SPARK-11870][STREAMING][PYSPARK] Rethrow the exceptions in
 TransformFunction and TransformFunctionSerializer

TransformFunction and TransformFunctionSerializer don't rethrow the exception, so when any exception happens, it just return None. This will cause some weird NPE and confuse people.

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #9847 from zsxwing/pyspark-streaming-exception.
---
 python/pyspark/streaming/tests.py | 16 ++++++++++++++++
 python/pyspark/streaming/util.py  |  3 +++
 2 files changed, 19 insertions(+)

diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 3403f6d20d..a0e0267caf 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -403,6 +403,22 @@ class BasicOperationTests(PySparkStreamingTestCase):
         expected = [[('k', v)] for v in expected]
         self._test_func(input, func, expected)
 
+    def test_failed_func(self):
+        input = [self.sc.parallelize([d], 1) for d in range(4)]
+        input_stream = self.ssc.queueStream(input)
+
+        def failed_func(i):
+            raise ValueError("failed")
+
+        input_stream.map(failed_func).pprint()
+        self.ssc.start()
+        try:
+            self.ssc.awaitTerminationOrTimeout(10)
+        except:
+            return
+
+        self.fail("a failed func should throw an error")
+
 
 class StreamingListenerTests(PySparkStreamingTestCase):
 
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index b20613b128..767c732eb9 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -64,6 +64,7 @@ class TransformFunction(object):
                 return r._jrdd
         except Exception:
             traceback.print_exc()
+            raise
 
     def __repr__(self):
         return "TransformFunction(%s)" % self.func
@@ -95,6 +96,7 @@ class TransformFunctionSerializer(object):
             return bytearray(self.serializer.dumps((func.func, func.deserializers)))
         except Exception:
             traceback.print_exc()
+            raise
 
     def loads(self, data):
         try:
@@ -102,6 +104,7 @@ class TransformFunctionSerializer(object):
             return TransformFunction(self.ctx, f, *deserializers)
         except Exception:
             traceback.print_exc()
+            raise
 
     def __repr__(self):
         return "TransformFunctionSerializer(%s)" % self.serializer
-- 
GitLab