From 073bf9d4d91e0242a813f3d227e52e76c26a2200 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@databricks.com>
Date: Fri, 11 Mar 2016 11:18:51 -0800
Subject: [PATCH] [SPARK-13807] De-duplicate `Python*Helper` instantiation code
 in PySpark streaming

This patch de-duplicates the code in PySpark streaming that loads the `Python*Helper` classes. I also changed a few `raise e` statements to bare `raise` in order to preserve the full exception traceback when re-raising.
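
For background on the `raise` change (an illustrative sketch, not code from
this patch): under Python 2, `raise e` starts a new traceback at the re-raise
site and discards the frames where the error actually occurred, whereas a bare
`raise` re-raises the in-flight exception with its original traceback intact:

    import traceback

    def fail():
        raise ValueError("boom")

    def reraise_with_e():
        try:
            fail()
        except ValueError as e:
            raise e  # Python 2: traceback restarts here; the fail() frame is lost

    def reraise_bare():
        try:
            fail()
        except ValueError:
            raise  # original traceback, including the fail() frame, is preserved

    for f in (reraise_with_e, reraise_bare):
        try:
            f()
        except ValueError:
            # under Python 2, only the bare-raise variant still shows fail()
            traceback.print_exc()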

Here's a link to the whitespace-change-free diff: https://github.com/apache/spark/compare/master...JoshRosen:pyspark-reflection-deduplication?w=0
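
For readers skimming the diff: the consolidated helper loading that each module
now shares looks roughly like the sketch below (Kafka shown; the Flume variant
differs only in the loaded class name and the error message):

    @staticmethod
    def _get_helper(sc):
        try:
            # Reflectively load the JVM-side helper (shipped in the external
            # streaming-kafka artifact) via the thread context class loader.
            helperClass = sc._jvm.java.lang.Thread.currentThread() \
                .getContextClassLoader() \
                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
            return helperClass.newInstance()
        except Py4JJavaError as e:
            if 'ClassNotFoundException' in str(e.java_exception):
                KafkaUtils._printErrorMsg(sc)
            raise  # bare raise keeps the original traceback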

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11641 from JoshRosen/pyspark-reflection-deduplication.
---
 python/pyspark/streaming/flume.py   |  40 +++++------
 python/pyspark/streaming/kafka.py   | 100 ++++++++++++----------------
 python/pyspark/streaming/kinesis.py |   2 +-
 python/pyspark/streaming/mqtt.py    |   2 +-
 4 files changed, 60 insertions(+), 84 deletions(-)

diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py
index b1fff0a5c7..edd5886a85 100644
--- a/python/pyspark/streaming/flume.py
+++ b/python/pyspark/streaming/flume.py
@@ -55,17 +55,8 @@ class FlumeUtils(object):
         :return: A DStream object
         """
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
-
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -95,18 +86,9 @@ class FlumeUtils(object):
         for (host, port) in addresses:
             hosts.append(host)
             ports.append(port)
-
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createPollingStream(
-                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
-
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createPollingStream(
+            ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -126,6 +108,18 @@ class FlumeUtils(object):
             return (headers, body)
         return stream.map(func)
 
+    @staticmethod
+    def _get_helper(sc):
+        try:
+            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+            return helperClass.newInstance()
+        except Py4JJavaError as e:
+            # TODO: use --jar once it also works on the driver
+            if 'ClassNotFoundException' in str(e.java_exception):
+                FlumeUtils._printErrorMsg(sc)
+            raise
+
     @staticmethod
     def _printErrorMsg(sc):
         print("""
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 13f8f9578e..a70b99249d 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -66,18 +66,8 @@ class KafkaUtils(object):
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-
-        try:
-            # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
-        except Py4JJavaError as e:
-            # TODO: use --jar once it also work on driver
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = KafkaUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
         stream = DStream(jstream, ssc, ser)
         return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
@@ -129,27 +119,20 @@ class KafkaUtils(object):
             m._set_value_decoder(valueDecoder)
             return messageHandler(m)
 
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-
-            jfromOffsets = dict([(k._jTopicAndPartition(helper),
-                                  v) for (k, v) in fromOffsets.items()])
-            if messageHandler is None:
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                func = funcWithoutMessageHandler
-                jstream = helper.createDirectStreamWithoutMessageHandler(
-                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
-            else:
-                ser = AutoBatchedSerializer(PickleSerializer())
-                func = funcWithMessageHandler
-                jstream = helper.createDirectStreamWithMessageHandler(
-                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = KafkaUtils._get_helper(ssc._sc)
+
+        jfromOffsets = dict([(k._jTopicAndPartition(helper),
+                              v) for (k, v) in fromOffsets.items()])
+        if messageHandler is None:
+            ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+            func = funcWithoutMessageHandler
+            jstream = helper.createDirectStreamWithoutMessageHandler(
+                ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+        else:
+            ser = AutoBatchedSerializer(PickleSerializer())
+            func = funcWithMessageHandler
+            jstream = helper.createDirectStreamWithMessageHandler(
+                ssc._jssc, kafkaParams, set(topics), jfromOffsets)
 
         stream = DStream(jstream, ssc, ser).map(func)
         return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
@@ -189,28 +172,35 @@ class KafkaUtils(object):
             m._set_value_decoder(valueDecoder)
             return messageHandler(m)
 
+        helper = KafkaUtils._get_helper(sc)
+
+        joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
+        jleaders = dict([(k._jTopicAndPartition(helper),
+                          v._jBroker(helper)) for (k, v) in leaders.items()])
+        if messageHandler is None:
+            jrdd = helper.createRDDWithoutMessageHandler(
+                sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+            rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+        else:
+            jrdd = helper.createRDDWithMessageHandler(
+                sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+
+        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+
+    @staticmethod
+    def _get_helper(sc):
         try:
+            # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
             helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
-            jleaders = dict([(k._jTopicAndPartition(helper),
-                              v._jBroker(helper)) for (k, v) in leaders.items()])
-            if messageHandler is None:
-                jrdd = helper.createRDDWithoutMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
-            else:
-                jrdd = helper.createRDDWithMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+            return helperClass.newInstance()
         except Py4JJavaError as e:
+            # TODO: use --jar once it also works on the driver
             if 'ClassNotFoundException' in str(e.java_exception):
                 KafkaUtils._printErrorMsg(sc)
-            raise e
-
-        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+            raise
 
     @staticmethod
     def _printErrorMsg(sc):
@@ -333,16 +323,8 @@ class KafkaRDD(RDD):
         Get the OffsetRange of specific KafkaRDD.
         :return: A list of OffsetRange
         """
-        try:
-            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(self.ctx)
-            raise e
-
+        helper = KafkaUtils._get_helper(self.ctx)
+        joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
         ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
                   for o in joffsetRanges]
         return ranges
diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py
index af72c3d690..e681301681 100644
--- a/python/pyspark/streaming/kinesis.py
+++ b/python/pyspark/streaming/kinesis.py
@@ -83,7 +83,7 @@ class KinesisUtils(object):
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KinesisUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
         stream = DStream(jstream, ssc, NoOpSerializer())
         return stream.map(lambda v: decoder(v))
 
diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py
index 3a515ea499..388e9526ba 100644
--- a/python/pyspark/streaming/mqtt.py
+++ b/python/pyspark/streaming/mqtt.py
@@ -48,7 +48,7 @@ class MQTTUtils(object):
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 MQTTUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
 
         return DStream(jstream, ssc, UTF8Deserializer())
 
-- 
GitLab