Commit 073bf9d4 authored by Josh Rosen, committed by Shixiong Zhu

[SPARK-13807] De-duplicate `Python*Helper` instantiation code in PySpark streaming

This patch de-duplicates the code in PySpark streaming that loads the `Python*Helper` classes. I also changed a few `raise e` statements to a bare `raise` in order to preserve the full exception stack trace when re-raising.
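
As a minimal standalone sketch (not part of the patch) of why the bare `raise` matters: under Python 2, `raise e` starts a new traceback at the re-raise site, discarding the frames that show where the error actually occurred, while a bare `raise` keeps them.

    import traceback

    def fail():
        raise ValueError("boom")  # the frame worth keeping in the traceback

    def rethrow_bare():
        try:
            fail()
        except ValueError:
            raise  # re-raises with the original traceback, including fail()

    def rethrow_bound():
        try:
            fail()
        except ValueError as e:
            raise e  # on Python 2 the traceback restarts here; fail() is lost

    for f in (rethrow_bare, rethrow_bound):
        try:
            f()
        except ValueError:
            traceback.print_exc()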

Here's a link to the whitespace-change-free diff: https://github.com/apache/spark/compare/master...JoshRosen:pyspark-reflection-deduplication?w=0

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11641 from JoshRosen/pyspark-reflection-deduplication.
parent ff776b2f
@@ -55,17 +55,8 @@ class FlumeUtils(object):
         :return: A DStream object
         """
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -95,18 +86,9 @@ class FlumeUtils(object):
         for (host, port) in addresses:
             hosts.append(host)
             ports.append(port)
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createPollingStream(
-                ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                FlumeUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = FlumeUtils._get_helper(ssc._sc)
+        jstream = helper.createPollingStream(
+            ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism)
         return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder)
 
     @staticmethod
@@ -126,6 +108,18 @@ class FlumeUtils(object):
             return (headers, body)
         return stream.map(func)
 
+    @staticmethod
+    def _get_helper(sc):
+        try:
+            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper")
+            return helperClass.newInstance()
+        except Py4JJavaError as e:
+            # TODO: use --jar once it also works on driver
+            if 'ClassNotFoundException' in str(e.java_exception):
+                FlumeUtils._printErrorMsg(sc)
+            raise
+
     @staticmethod
     def _printErrorMsg(sc):
         print("""
......
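
A note on the pattern that `_get_helper` now centralizes: the `Python*Helper` classes live in external jars (added via `spark-submit --jars` or `--packages`), so they are loaded reflectively through the JVM thread's context class loader, which can see jars added at submit time. A generic sketch of the shape, where `load_python_helper` and the class name are hypothetical illustrations, not part of the patch:

    from py4j.protocol import Py4JJavaError

    def load_python_helper(sc, class_name):
        # class_name is hypothetical, e.g. "org.example.MyUtilsPythonHelper";
        # the real helpers hard-code their own class names.
        try:
            helper_class = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
                .loadClass(class_name)
            return helper_class.newInstance()
        except Py4JJavaError as e:
            # A ClassNotFoundException here usually means the helper jar is
            # missing from the classpath.
            if 'ClassNotFoundException' in str(e.java_exception):
                print("Helper class %s not found on the classpath" % class_name)
            raise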
@@ -66,18 +66,8 @@ class KafkaUtils(object):
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
-        try:
-            # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
-        except Py4JJavaError as e:
-            # TODO: use --jar once it also work on driver
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = KafkaUtils._get_helper(ssc._sc)
+        jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
         ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
         stream = DStream(jstream, ssc, ser)
         return stream.map(lambda k_v: (keyDecoder(k_v[0]), valueDecoder(k_v[1])))
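
The public entry point above is unchanged for callers; a usage sketch, where the ZooKeeper quorum, group id, and topic map are illustrative values and `ssc` is an existing StreamingContext:

    from pyspark.streaming.kafka import KafkaUtils

    # Receiver-based Kafka stream: one consumer thread for the "events" topic.
    stream = KafkaUtils.createStream(
        ssc, "localhost:2181", "example-consumer-group", {"events": 1})
    lines = stream.map(lambda kv: kv[1])  # values only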
@@ -129,27 +119,20 @@ class KafkaUtils(object):
             m._set_value_decoder(valueDecoder)
             return messageHandler(m)
 
-        try:
-            helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            jfromOffsets = dict([(k._jTopicAndPartition(helper),
-                                  v) for (k, v) in fromOffsets.items()])
-            if messageHandler is None:
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                func = funcWithoutMessageHandler
-                jstream = helper.createDirectStreamWithoutMessageHandler(
-                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
-            else:
-                ser = AutoBatchedSerializer(PickleSerializer())
-                func = funcWithMessageHandler
-                jstream = helper.createDirectStreamWithMessageHandler(
-                    ssc._jssc, kafkaParams, set(topics), jfromOffsets)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+        helper = KafkaUtils._get_helper(ssc._sc)
+        jfromOffsets = dict([(k._jTopicAndPartition(helper),
+                              v) for (k, v) in fromOffsets.items()])
+        if messageHandler is None:
+            ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+            func = funcWithoutMessageHandler
+            jstream = helper.createDirectStreamWithoutMessageHandler(
+                ssc._jssc, kafkaParams, set(topics), jfromOffsets)
+        else:
+            ser = AutoBatchedSerializer(PickleSerializer())
+            func = funcWithMessageHandler
+            jstream = helper.createDirectStreamWithMessageHandler(
+                ssc._jssc, kafkaParams, set(topics), jfromOffsets)
 
         stream = DStream(jstream, ssc, ser).map(func)
         return KafkaDStream(stream._jdstream, ssc, stream._jrdd_deserializer)
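
The `_jTopicAndPartition(helper)` conversion above is what lets callers pass explicit starting offsets; a caller-side sketch, with the broker address, topic, and offset as illustrative assumptions:

    from pyspark.streaming.kafka import KafkaUtils, TopicAndPartition

    # Resume the direct (receiver-less) stream from offset 1000 of
    # partition 0 of the "events" topic.
    fromOffsets = {TopicAndPartition("events", 0): 1000}
    stream = KafkaUtils.createDirectStream(
        ssc, ["events"], {"metadata.broker.list": "localhost:9092"},
        fromOffsets=fromOffsets)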
@@ -189,28 +172,35 @@ class KafkaUtils(object):
             m._set_value_decoder(valueDecoder)
             return messageHandler(m)
 
-        try:
-            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
-            jleaders = dict([(k._jTopicAndPartition(helper),
-                              v._jBroker(helper)) for (k, v) in leaders.items()])
-            if messageHandler is None:
-                jrdd = helper.createRDDWithoutMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
-                rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
-            else:
-                jrdd = helper.createRDDWithMessageHandler(
-                    sc._jsc, kafkaParams, joffsetRanges, jleaders)
-                rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(sc)
-            raise e
-        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+        helper = KafkaUtils._get_helper(sc)
+        joffsetRanges = [o._jOffsetRange(helper) for o in offsetRanges]
+        jleaders = dict([(k._jTopicAndPartition(helper),
+                          v._jBroker(helper)) for (k, v) in leaders.items()])
+        if messageHandler is None:
+            jrdd = helper.createRDDWithoutMessageHandler(
+                sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+            rdd = RDD(jrdd, sc, ser).map(funcWithoutMessageHandler)
+        else:
+            jrdd = helper.createRDDWithMessageHandler(
+                sc._jsc, kafkaParams, joffsetRanges, jleaders)
+            rdd = RDD(jrdd, sc).map(funcWithMessageHandler)
+        return KafkaRDD(rdd._jrdd, sc, rdd._jrdd_deserializer)
+
+    @staticmethod
+    def _get_helper(sc):
+        try:
+            # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
+            helperClass = sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
+            return helperClass.newInstance()
+        except Py4JJavaError as e:
+            # TODO: use --jar once it also works on driver
+            if 'ClassNotFoundException' in str(e.java_exception):
+                KafkaUtils._printErrorMsg(sc)
+            raise
 
     @staticmethod
     def _printErrorMsg(sc):
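
`createRDD` goes through the same shared helper to build the JVM-side offset ranges and leader map; a batch-read sketch, with broker, topic, and offsets as illustrative assumptions:

    from pyspark.streaming.kafka import KafkaUtils, OffsetRange

    # Read messages 0..100 of partition 0 of "events" as a one-off RDD.
    ranges = [OffsetRange("events", 0, 0, 100)]
    rdd = KafkaUtils.createRDD(
        sc, {"metadata.broker.list": "localhost:9092"}, ranges)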
@@ -333,16 +323,8 @@ class KafkaRDD(RDD):
         Get the OffsetRange of specific KafkaRDD.
         :return: A list of OffsetRange
         """
-        try:
-            helperClass = self.ctx._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
-                .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
-            helper = helperClass.newInstance()
-            joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
-        except Py4JJavaError as e:
-            if 'ClassNotFoundException' in str(e.java_exception):
-                KafkaUtils._printErrorMsg(self.ctx)
-            raise e
+        helper = KafkaUtils._get_helper(self.ctx)
+        joffsetRanges = helper.offsetRangesOfKafkaRDD(self._jrdd.rdd())
         ranges = [OffsetRange(o.topic(), o.partition(), o.fromOffset(), o.untilOffset())
                   for o in joffsetRanges]
         return ranges
......
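
`offsetRanges()` is typically captured from the underlying KafkaRDD before any shuffling transformation, as in this sketch of the usual offset-tracking pattern (assuming `directKafkaStream` came from `createDirectStream`):

    offsetRanges = []

    def storeOffsetRanges(rdd):
        # Still a KafkaRDD at this point, so offsetRanges() is available.
        global offsetRanges
        offsetRanges = rdd.offsetRanges()
        return rdd

    def printOffsetRanges(rdd):
        for o in offsetRanges:
            print("%s %s %s %s" % (o.topic, o.partition, o.fromOffset, o.untilOffset))

    directKafkaStream \
        .transform(storeOffsetRanges) \
        .foreachRDD(printOffsetRanges)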
@@ -83,7 +83,7 @@ class KinesisUtils(object):
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 KinesisUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
         stream = DStream(jstream, ssc, NoOpSerializer())
         return stream.map(lambda v: decoder(v))
......
@@ -48,7 +48,7 @@ class MQTTUtils(object):
         except Py4JJavaError as e:
             if 'ClassNotFoundException' in str(e.java_exception):
                 MQTTUtils._printErrorMsg(ssc.sparkContext)
-            raise e
+            raise
         return DStream(jstream, ssc, UTF8Deserializer())
......