Commit f79a1e4d authored by Josh Rosen

Add broadcast variables to Python API.

parent 65e84060
@@ -7,14 +7,13 @@ import scala.collection.JavaConversions._
 import scala.io.Source
 import spark._
 import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
-import scala.{collection, Some}
-import collection.parallel.mutable
+import broadcast.Broadcast
 import scala.collection
-import scala.Some

 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
-      command: Seq[String], parent: RDD[T], pythonExec: String): Iterator[Array[Byte]] = {
+      command: Seq[String], parent: RDD[T], pythonExec: String,
+      broadcastVars: java.util.List[Broadcast[Array[Byte]]]): Iterator[Array[Byte]] = {
     val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
     val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/pyspark/pyspark/worker.py"))
@@ -42,11 +41,18 @@ trait PythonRDDBase {
       override def run() {
         SparkEnv.set(env)
         val out = new PrintWriter(proc.getOutputStream)
+        val dOut = new DataOutputStream(proc.getOutputStream)
+        out.println(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          out.print(broadcast.uuid.toString)
+          dOut.writeInt(broadcast.value.length)
+          dOut.write(broadcast.value)
+          dOut.flush()
+        }
         for (elem <- command) {
           out.println(elem)
         }
         out.flush()
-        val dOut = new DataOutputStream(proc.getOutputStream)
         for (elem <- parent.iterator(split)) {
           if (elem.isInstanceOf[Array[Byte]]) {
             val arr = elem.asInstanceOf[Array[Byte]]
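For orientation, the handshake written above is what the Python worker parses from its stdin before the usual command lines: a count on its own line, then for each broadcast a 36-character UUID string, a 4-byte big-endian length (DataOutputStream.writeInt), and that many bytes of pickled data. A minimal stand-alone reader sketch; the read_broadcasts helper is hypothetical, and the committed worker.py instead uses pyspark.serializers.loads for the length-prefixed read:

import cPickle
import struct
import sys

def read_broadcasts(stream=sys.stdin):
    # Hypothetical sketch of the worker-side half of the handshake above.
    variables = {}
    num_broadcasts = int(stream.readline().strip())      # written with out.println()
    for _ in range(num_broadcasts):
        uuid = stream.read(36)                            # UUID string, no trailing newline
        (length,) = struct.unpack(">i", stream.read(4))   # dOut.writeInt() is big-endian
        variables[uuid] = cPickle.loads(stream.read(length))
    return variables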
@@ -121,16 +127,17 @@ trait PythonRDDBase {
 class PythonRDD[T: ClassManifest](
     parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-    preservePartitoning: Boolean, pythonExec: String)
+    preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
   extends RDD[Array[Byte]](parent.context) with PythonRDDBase {

-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, command, Map(), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+      pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)

   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)

   override def splits = parent.splits
@@ -139,23 +146,25 @@ class PythonRDD[T: ClassManifest](
   override val partitioner = if (preservePartitoning) parent.partitioner else None

   override def compute(split: Split): Iterator[Array[Byte]] =
-    compute(split, envVars, command, parent, pythonExec)
+    compute(split, envVars, command, parent, pythonExec, broadcastVars)

   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }

 class PythonPairRDD[T: ClassManifest] (
     parent: RDD[T], command: Seq[String], envVars: Map[String, String],
-    preservePartitoning: Boolean, pythonExec: String)
+    preservePartitoning: Boolean, pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]])
   extends RDD[(Array[Byte], Array[Byte])](parent.context) with PythonRDDBase {

-  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, command, Map(), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: Seq[String], preservePartitoning: Boolean,
+      pythonExec: String, broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, command, Map(), preservePartitoning, pythonExec, broadcastVars)

   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String) =
-    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec)
+  def this(parent: RDD[T], command: String, preservePartitoning: Boolean, pythonExec: String,
+      broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+    this(parent, PipedRDD.tokenize(command), preservePartitoning, pythonExec, broadcastVars)

   override def splits = parent.splits
@@ -164,7 +173,7 @@ class PythonPairRDD[T: ClassManifest] (
   override val partitioner = if (preservePartitoning) parent.partitioner else None

   override def compute(split: Split): Iterator[(Array[Byte], Array[Byte])] = {
-    compute(split, envVars, command, parent, pythonExec).grouped(2).map {
+    compute(split, envVars, command, parent, pythonExec, broadcastVars).grouped(2).map {
       case Seq(a, b) => (a, b)
       case x => throw new Exception("PythonPairRDD: unexpected value: " + x)
     }
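The PythonPairRDD.compute change keeps the existing pairing trick: the worker emits keys and values as alternating byte arrays, and grouped(2) stitches them back into tuples. A rough Python equivalent of that pairing, purely for illustration (the pair_up helper is hypothetical, not part of the commit):

def pair_up(byte_arrays):
    # Turn a flat stream [k1, v1, k2, v2, ...] into (key, value) pairs.
    it = iter(byte_arrays)
    for key in it:
        try:
            value = next(it)
        except StopIteration:
            raise Exception("PythonPairRDD: unexpected value: " + repr(key))
        yield (key, value)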
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> b = sc.broadcast([1, 2, 3, 4, 5])
>>> b.value
[1, 2, 3, 4, 5]
>>> from pyspark.broadcast import _broadcastRegistry
>>> _broadcastRegistry[b.uuid] = b
>>> from cPickle import dumps, loads
>>> loads(dumps(b)).value
[1, 2, 3, 4, 5]
>>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
"""
# Holds broadcasted data received from Java, keyed by UUID.
_broadcastRegistry = {}
def _from_uuid(uuid):
from pyspark.broadcast import _broadcastRegistry
if uuid not in _broadcastRegistry:
raise Exception("Broadcast variable '%s' not loaded!" % uuid)
return _broadcastRegistry[uuid]
class Broadcast(object):
def __init__(self, uuid, value, java_broadcast=None, pickle_registry=None):
self.value = value
self.uuid = uuid
self._jbroadcast = java_broadcast
self._pickle_registry = pickle_registry
def __reduce__(self):
self._pickle_registry.add(self)
return (_from_uuid, (self.uuid, ))
def _test():
import doctest
doctest.testmod()
if __name__ == "__main__":
_test()
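The __reduce__ hook above is the heart of the scheme: pickling a Broadcast on the driver records the instance in the context's pickle registry and serializes only its UUID, while unpickling on a worker goes through _from_uuid and the module-level _broadcastRegistry. A stand-alone sketch of that round trip with no SparkContext or JVM involved (the placeholder UUID and the plain set used as the pickle registry are illustrative assumptions):

from cPickle import dumps, loads

from pyspark.broadcast import Broadcast, _broadcastRegistry

# Driver side: a Broadcast whose pickle registry is just a set we can inspect.
shipped = set()
b = Broadcast("00000000-0000-0000-0000-000000000000", [1, 2, 3],
              pickle_registry=shipped)

payload = dumps(b)     # __reduce__ fires: only the UUID gets serialized...
assert b in shipped    # ...and the instance is flagged as "needs to be sent to Java"

# Worker side: the registry must already hold the value, otherwise
# _from_uuid raises "Broadcast variable ... not loaded!".
_broadcastRegistry[b.uuid] = Broadcast(b.uuid, [1, 2, 3])
assert loads(payload).value == [1, 2, 3]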
@@ -2,6 +2,7 @@ import os
 import atexit
 from tempfile import NamedTemporaryFile

+from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, dumps
 from pyspark.rdd import RDD
@@ -24,6 +25,11 @@ class SparkContext(object):
         self.defaultParallelism = \
             defaultParallelism or self._jsc.sc().defaultParallelism()
         self.pythonExec = pythonExec
+        # Broadcast's __reduce__ method stores Broadcast instances here.
+        # This allows other code to determine which Broadcast instances have
+        # been pickled, so it can determine which Java broadcast objects to
+        # send.
+        self._pickled_broadcast_vars = set()

     def __del__(self):
         if self._jsc:
@@ -52,7 +58,12 @@ class SparkContext(object):
         jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)

-    def textFile(self, name, numSlices=None):
-        numSlices = numSlices or self.defaultParallelism
-        jrdd = self._jsc.textFile(name, numSlices)
+    def textFile(self, name, minSplits=None):
+        minSplits = minSplits or min(self.defaultParallelism, 2)
+        jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
+
+    def broadcast(self, value):
+        jbroadcast = self._jsc.broadcast(bytearray(PickleSerializer.dumps(value)))
+        return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
+                         self._pickled_broadcast_vars)
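Putting the driver-side pieces together: SparkContext.broadcast pickles the value once, hands the bytes to the JVM as a broadcast of Array[Byte], and returns a Python Broadcast that keeps the original value locally, so b.value never round-trips through Java on the driver. A hedged usage sketch; it assumes a local SparkContext and that RDD.map/collect behave as in the broadcast.py doctest, with the expected result shown as a comment:

from pyspark.context import SparkContext

sc = SparkContext('local', 'broadcast-example')
lookup = sc.broadcast({"a": 1, "b": 2})   # pickled once, broadcast by the JVM

# The lambda captures the Broadcast, so pickling it registers the broadcast
# for shipping; on workers, worker.py repopulates _broadcastRegistry first.
rdd = sc.parallelize(["a", "b", "a"])
print rdd.map(lambda k: lookup.value[k]).collect()   # expected: [1, 2, 1]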
@@ -6,6 +6,8 @@ from pyspark.serializers import PickleSerializer
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup

+from py4j.java_collections import ListConverter
+

 class RDD(object):
@@ -15,11 +17,15 @@ class RDD(object):
         self.ctx = ctx

     @classmethod
-    def _get_pipe_command(cls, command, functions):
+    def _get_pipe_command(cls, ctx, command, functions):
         worker_args = [command]
         for f in functions:
             worker_args.append(b64enc(cloudpickle.dumps(f)))
-        return " ".join(worker_args)
+        broadcast_vars = [x._jbroadcast for x in ctx._pickled_broadcast_vars]
+        broadcast_vars = ListConverter().convert(broadcast_vars,
+                                                 ctx.gateway._gateway_client)
+        ctx._pickled_broadcast_vars.clear()
+        return (" ".join(worker_args), broadcast_vars)

     def cache(self):
         self.is_cached = True
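What is new in _get_pipe_command is the bookkeeping: serializing the user's functions is what fills ctx._pickled_broadcast_vars (through Broadcast.__reduce__), the collected Java handles are converted with py4j's ListConverter into the java.util.List the PythonRDD constructor expects, and the set is cleared so the next command only ships broadcasts its own closures touched. A stand-alone sketch of that collect-and-clear behavior, substituting plain cPickle for cloudpickle and a stand-in FakeContext for the real SparkContext (both substitutions are assumptions for illustration):

from cPickle import dumps

from pyspark.broadcast import Broadcast

class FakeContext(object):
    # Stand-in for SparkContext: only the set that Broadcast.__reduce__ writes into.
    def __init__(self):
        self._pickled_broadcast_vars = set()

ctx = FakeContext()
b = Broadcast("11111111-1111-1111-1111-111111111111", [1, 2, 3],
              pickle_registry=ctx._pickled_broadcast_vars)

dumps({"captured": b})   # serializing anything that references b registers it
to_ship = [x._jbroadcast for x in ctx._pickled_broadcast_vars]   # [None] here; real Java objects in Spark
ctx._pickled_broadcast_vars.clear()   # next pipe command starts from a clean slate
assert to_ship == [None] and ctx._pickled_broadcast_vars == set()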
@@ -52,9 +58,10 @@ class RDD(object):

     def _pipe(self, functions, command):
         class_manifest = self._jrdd.classManifest()
-        pipe_command = RDD._get_pipe_command(command, functions)
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, command, functions)
         python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
-            False, self.ctx.pythonExec, class_manifest)
+            False, self.ctx.pythonExec, broadcast_vars, class_manifest)
         return python_rdd.asJavaRDD()

     def distinct(self):
@@ -249,10 +256,12 @@ class RDD(object):
     def shuffle(self, numSplits, hashFunc=hash):
         if numSplits is None:
             numSplits = self.ctx.defaultParallelism
-        pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
+        (pipe_command, broadcast_vars) = \
+            RDD._get_pipe_command(self.ctx, 'shuffle_map_step', [hashFunc])
         class_manifest = self._jrdd.classManifest()
         python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
-            pipe_command, False, self.ctx.pythonExec, class_manifest)
+            pipe_command, False, self.ctx.pythonExec, broadcast_vars,
+            class_manifest)
         partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
         jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
         jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
@@ -360,12 +369,12 @@ class PipelinedRDD(RDD):
     @property
     def _jrdd(self):
         if not self._jrdd_val:
-            funcs = [self.func]
-            pipe_command = RDD._get_pipe_command("pipeline", funcs)
+            (pipe_command, broadcast_vars) = \
+                RDD._get_pipe_command(self.ctx, "pipeline", [self.func])
             class_manifest = self._prev_jrdd.classManifest()
             python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                 pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
-                class_manifest)
+                broadcast_vars, class_manifest)
             self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
@@ -5,6 +5,7 @@ import sys
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
+from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.serializers import dumps, loads, PickleSerializer
 import cPickle

@@ -63,6 +64,11 @@ def do_shuffle_map_step():

 def main():
+    num_broadcast_variables = int(sys.stdin.readline().strip())
+    for _ in range(num_broadcast_variables):
+        uuid = sys.stdin.read(36)
+        value = loads(sys.stdin)
+        _broadcastRegistry[uuid] = Broadcast(uuid, cPickle.loads(value))
     command = sys.stdin.readline().strip()
     if command == "pipeline":
         do_pipeline()