Commit 200d248d authored by Josh Rosen

Simplify Python worker; pipeline the map step of partitionBy().

parent 6904cb77
[PythonRDD.scala]
@@ -151,38 +151,18 @@ class PythonRDD[T: ClassManifest](
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
-class PythonPairRDD[T: ClassManifest] (
-    parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-    preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
-  extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase {
-
-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
-      pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
-    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)
-
-  // Similar to Runtime.exec(), if we are given a single string, split it into words
-  // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String,
-      broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)
-
-  override def splits = parent.splits
-
-  override val dependencies = List(new OneToOneDependency(parent))
-
-  override val partitioner = if (preservePartitoning) parent.partitioner else None
-
-  override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = {
-    compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map {
-      case Seq(a, b) => (a, b)
-      case x => throw new Exception("PythonPairRDD: unexpected value: " + x)
-    }
-  }
+private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
+  RDD[(Array[Byte], Array[Byte])](prev.context) {
+  override def splits = prev.splits
+  override val dependencies = List(new OneToOneDependency(prev))
+  override def compute(split: Split) =
+    prev.iterator(split).grouped(2).map {
+      case Seq(a, b) => (a, b)
+      case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+    }
 
   val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
 object PythonRDD {
   /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
...
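For readers following the Scala side: PairwiseRDD no longer launches its own Python worker; it just re-pairs the byte records already emitted by the upstream PythonRDD. A rough Python sketch of what `prev.iterator(split).grouped(2)` does with that stream (names and values here are illustrative, not part of the patch):

    # Illustrative only: mirrors the grouped(2) pairing in PairwiseRDD.compute.
    def pair_up(records):
        """Turn a flat stream [k1, v1, k2, v2, ...] into [(k1, v1), (k2, v2), ...]."""
        it = iter(records)
        return zip(it, it)  # each tuple consumes two consecutive records

    print pair_up(["hash1", "pickle1", "hash2", "pickle2"])
    # [('hash1', 'pickle1'), ('hash2', 'pickle2')]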
[context.py]
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
 
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import PickleSerializer, dumps
+from pyspark.serializers import dump_pickle, write_with_length
 from pyspark.rdd import RDD
@@ -16,9 +16,8 @@ class SparkContext(object):
     asPickle = jvm.spark.api.python.PythonRDD.asPickle
     arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
 
     def __init__(self, master, name, defaultParallelism=None,
         pythonExec='python'):
         self.master = master
         self.name = name
         self._jsc = self.jvm.JavaSparkContext(master, name)
@@ -52,7 +51,7 @@ class SparkContext(object):
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
         for x in c:
-            dumps(PickleSerializer.dumps(x), tempFile)
+            write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
         atexit.register(lambda: os.unlink(tempFile.name))
         jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
@@ -64,6 +63,6 @@ class SparkContext(object):
         return RDD(jrdd, self)
 
     def broadcast(self, value):
-        jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+        jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
         return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
                          self._pickled_broadcast_vars)
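A note on the serialization change threaded through this file: dump_pickle pins pickle protocol 2, where the old PickleSerializer used -1; under CPython 2 the highest protocol is 2, so the on-the-wire bytes are unchanged. A small sketch of that equivalence and of the broadcast path, where the bytearray() wrapper lets Py4J hand the pickle to the JVM as a byte[] (the value is made up):

    # Sketch: protocol -1 and protocol 2 coincide under Python 2.
    import cPickle

    value = {"weights": [0.1, 0.2, 0.3]}
    old_bytes = cPickle.dumps(value, -1)   # old PickleSerializer.dumps
    new_bytes = cPickle.dumps(value, 2)    # new dump_pickle
    assert old_bytes == new_bytes

    payload = bytearray(new_bytes)         # what _jsc.broadcast() receives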
[rdd.py]
@@ -3,7 +3,7 @@ from collections import Counter
 from itertools import chain, ifilter, imap
 
 from pyspark import cloudpickle
-from pyspark.serializers import PickleSerializer
+from pyspark.serializers import dump_pickle, load_pickle
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
@@ -17,17 +17,6 @@ class RDD(object):
         self.is_cached = False
         self.ctx = ctx
 
-    @classmethod
-    def _get_pipe_command(cls, ctx, command, functions):
-        worker_args = [command]
-        for f in functions:
-            worker_args.append(b64enc(cloudpickle.dumps(f)))
-        broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
-        broadcast_vars = ListConverter().convert(broadcast_vars,
-                                                 ctx.gateway._gateway_client)
-        ctx._pickled_broadcast_vars.clear()
-        return (" ".join(worker_args), broadcast_vars)
-
     def cache(self):
         self.is_cached = True
         self._jrdd.cache()
@@ -66,14 +55,6 @@ class RDD(object):
         def func(iterator): return ifilter(f, iterator)
         return self.mapPartitions(func)
 
-    def _pipe(self, functions, command):
-        class_manifest = self._jrdd.classManifest()
-        (pipe_command, broadcast_vars) = \
-            RDD._get_pipe_command(self.ctx, command, functions)
-        python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
-            False, self.ctx.pythonExec, broadcast_vars, class_manifest)
-        return python_rdd.asJavaRDD()
-
     def distinct(self):
         """
         >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
@@ -89,7 +70,7 @@ class RDD(object):
     def takeSample(self, withReplacement, num, seed):
         vals = self._jrdd.takeSample(withReplacement, num, seed)
-        return [PickleSerializer.loads(bytes(x)) for x in vals]
+        return [load_pickle(bytes(x)) for x in vals]
 
     def union(self, other):
         """
@@ -148,7 +129,7 @@ class RDD(object):
     def collect(self):
         pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
-        return PickleSerializer.loads(bytes(pickle))
+        return load_pickle(bytes(pickle))
 
     def reduce(self, f):
         """
@@ -216,19 +197,17 @@ class RDD(object):
         [2, 3]
         """
         pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
-        return PickleSerializer.loads(bytes(pickle))
+        return load_pickle(bytes(pickle))
 
     def first(self):
         """
         >>> sc.parallelize([2, 3, 4]).first()
         2
         """
-        return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first())))
+        return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
 
     # TODO: saveAsTextFile
-    # TODO: saveAsObjectFile
 
     # Pair functions
 
     def collectAsMap(self):
@@ -303,19 +282,18 @@ class RDD(object):
         """
         return python_right_outer_join(self, other, numSplits)
 
-    # TODO: pipelining
-    # TODO: optimizations
     def partitionBy(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        (pipe_command, broadcast_vars) = \
-            RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
-        class_manifest = self._jrdd.classManifest()
-        python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
-            pipe_command, False, self.ctx.pythonExec, broadcast_vars,
-            class_manifest)
+        def add_shuffle_key(iterator):
+            for (k, v) in iterator:
+                yield str(hashFunc(k))
+                yield dump_pickle((k, v))
+        keyed = PipelinedRDD(self, add_shuffle_key)
+        keyed._bypass_serializer = True
+        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
         partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
-        jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
+        jrdd = pairRDD.partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
         return RDD(jrdd, self.ctx)
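The pipelined map step replaces the old 'shuffle_map_step' worker command: add_shuffle_key runs inside an ordinary PipelinedRDD (relying on the _bypass_serializer flag added in the next hunk) and interleaves the stringified hash of each key with the pickled (key, value) pair, which is exactly the alternating stream PairwiseRDD re-pairs on the JVM side. A rough sketch of what one partition emits, with illustrative input:

    # Sketch of the per-partition output of add_shuffle_key.
    from pyspark.serializers import dump_pickle

    def add_shuffle_key(iterator, hashFunc=hash):
        for (k, v) in iterator:
            yield str(hashFunc(k))      # becomes the PairwiseRDD key (the shuffle key)
            yield dump_pickle((k, v))   # becomes the PairwiseRDD value

    out = list(add_shuffle_key([("a", 1), ("b", 2)]))
    # out alternates: [str(hash("a")), pickle of ("a", 1), str(hash("b")), pickle of ("b", 2)]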
@@ -430,17 +408,23 @@ class PipelinedRDD(RDD):
         self.ctx = prev.ctx
         self.prev = prev
         self._jrdd_val = None
+        self._bypass_serializer = False
 
     @property
     def _jrdd(self):
-        if not self._jrdd_val:
-            (pipe_command, broadcast_vars) = \
-                RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
-            class_manifest = self._prev_jrdd.classManifest()
-            python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
-                pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
-                broadcast_vars, class_manifest)
-            self._jrdd_val = python_rdd.asJavaRDD()
+        if self._jrdd_val:
+            return self._jrdd_val
+        funcs = [self.func, self._bypass_serializer]
+        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
+            self.ctx.gateway._gateway_client)
+        self.ctx._pickled_broadcast_vars.clear()
+        class_manifest = self._prev_jrdd.classManifest()
+        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+            pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
+            broadcast_vars, class_manifest)
+        self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
...
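With _get_pipe_command gone, PipelinedRDD builds the command string itself: the pipelined function and the _bypass_serializer flag are each cloudpickled, base64-encoded, and joined with spaces; the JVM side tokenizes that string on whitespace (per the StringTokenizer comment in PythonRDD.scala), and the worker decodes one token at a time via load_obj. A standalone sketch of the round trip, with plain cPickle standing in for cloudpickle and square_partition a made-up function:

    # Sketch of the pipe_command encoding/decoding round trip (cPickle stands in for cloudpickle).
    from base64 import b64encode as b64enc, standard_b64decode
    import cPickle

    def square_partition(iterator):
        return (x * x for x in iterator)

    funcs = [square_partition, False]   # [self.func, self._bypass_serializer]
    pipe_command = ' '.join(b64enc(cPickle.dumps(f, 2)) for f in funcs)

    func, bypass = [cPickle.loads(standard_b64decode(t)) for t in pipe_command.split(' ')]
    print list(func([1, 2, 3])), bypass   # [1, 4, 9] False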
"""
Data serialization methods.
The Spark Python API is built on top of the Spark Java API. RDDs created in
Python are stored in Java as RDD[Array[Byte]]. Python objects are
automatically serialized/deserialized, so this representation is transparent to
the end-user.
"""
from collections import namedtuple
import cPickle
import struct import struct
import cPickle
Serializer = namedtuple("Serializer", ["dumps","loads"]) def dump_pickle(obj):
return cPickle.dumps(obj, 2)
PickleSerializer = Serializer( load_pickle = cPickle.loads
lambda obj: cPickle.dumps(obj, -1),
cPickle.loads)
def dumps(obj, stream): def write_with_length(obj, stream):
# TODO: determining the length of non-byte objects.
stream.write(struct.pack("!i", len(obj))) stream.write(struct.pack("!i", len(obj)))
stream.write(obj) stream.write(obj)
def loads(stream): def read_with_length(stream):
length = stream.read(4) length = stream.read(4)
if length == "": if length == "":
raise EOFError raise EOFError
......
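serializers.py is reduced to four small pieces: dump_pickle/load_pickle for pickling, and write_with_length/read_with_length for the 4-byte big-endian length framing used on the worker pipes. A quick round-trip sketch of the new helpers, with a StringIO standing in for the worker's stdin/stdout:

    # Round-trip sketch for the new helpers; StringIO stands in for the worker pipe.
    from StringIO import StringIO
    from pyspark.serializers import dump_pickle, load_pickle, \
        write_with_length, read_with_length

    pipe = StringIO()
    for obj in [{"a": 1}, [2, 3], u"four"]:
        write_with_length(dump_pickle(obj), pipe)   # length prefix, then pickle bytes

    pipe.seek(0)
    print [load_pickle(read_with_length(pipe)) for _ in range(3)]
    # [{'a': 1}, [2, 3], u'four']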
[worker.py]
@@ -7,61 +7,41 @@ from base64 import standard_b64decode
 # copy_reg module.
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import dumps, loads, PickleSerializer
-import cPickle
+from pyspark.serializers import write_with_length, read_with_length, \
+    dump_pickle, load_pickle
 
 # Redirect stdout to stderr so that users must return values from functions.
 old_stdout = sys.stdout
 sys.stdout = sys.stderr
 
-def load_function():
-    return cPickle.loads(standard_b64decode(sys.stdin.readline().strip()))
-
-def output(x):
-    dumps(x, old_stdout)
+def load_obj():
+    return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
 
 def read_input():
     try:
         while True:
-            yield cPickle.loads(loads(sys.stdin))
+            yield load_pickle(read_with_length(sys.stdin))
     except EOFError:
         return
 
-def do_pipeline():
-    f = load_function()
-    for obj in f(read_input()):
-        output(PickleSerializer.dumps(obj))
-
-def do_shuffle_map_step():
-    hashFunc = load_function()
-    while True:
-        try:
-            pickled = loads(sys.stdin)
-        except EOFError:
-            return
-        key = cPickle.loads(pickled)[0]
-        output(str(hashFunc(key)))
-        output(pickled)
-
 def main():
     num_broadcast_variables = int(sys.stdin.readline().strip())
     for _ in range(num_broadcast_variables):
         uuid = sys.stdin.read(36)
-        value = loads(sys.stdin)
-        _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
-    command = sys.stdin.readline().strip()
-    if command == "pipeline":
-        do_pipeline()
-    elif command == "shuffle_map_step":
-        do_shuffle_map_step()
-    else:
-        raise Exception("Unsupported command %s" % command)
+        value = read_with_length(sys.stdin)
+        _broadcastRegistry[uuid] = Broadcast(uuid, load_pickle(value))
+    func = load_obj()
+    bypassSerializer = load_obj()
+    if bypassSerializer:
+        dumps = lambda x: x
+    else:
+        dumps = dump_pickle
+    for obj in func(read_input()):
+        write_with_length(dumps(obj), old_stdout)
 
 
 if __name__ == '__main__':
...
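The worker no longer dispatches on a 'pipeline' / 'shuffle_map_step' command. Every task now follows one shape: read the broadcast variables, read a pickled function and a bypass-serializer flag, then map the function over the length-framed input stream and write length-framed output. A hedged sketch of the stdin a driver would feed this simplified worker, with StringIO standing in for the real pipe, cPickle for cloudpickle, and made-up values:

    # Sketch of the stream main() expects on stdin (illustrative values only).
    from base64 import b64encode as b64enc
    from StringIO import StringIO
    import cPickle, struct, uuid

    def write_with_length(obj, stream):
        stream.write(struct.pack("!i", len(obj)))
        stream.write(obj)

    def identity(iterator):
        return iterator

    stdin = StringIO()
    stdin.write("1\n")                                          # num_broadcast_variables
    stdin.write(str(uuid.uuid4()))                              # 36-char broadcast uuid
    write_with_length(cPickle.dumps({"lookup": 42}, 2), stdin)  # broadcast value
    stdin.write(b64enc(cPickle.dumps(identity, 2)) + "\n")      # func, one b64 token per line
    stdin.write(b64enc(cPickle.dumps(False, 2)) + "\n")         # bypassSerializer flag
    for record in ["x", "y"]:
        write_with_length(cPickle.dumps(record, 2), stdin)      # length-framed task input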