From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001 From: Matei Zaharia <matei@eecs.berkeley.edu> Date: Sun, 20 Jan 2013 01:57:44 -0800 Subject: [PATCH] Added accumulators to PySpark --- .../scala/spark/api/python/PythonRDD.scala | 83 +++++++-- python/pyspark/__init__.py | 4 + python/pyspark/accumulators.py | 166 ++++++++++++++++++ python/pyspark/context.py | 38 ++++ python/pyspark/rdd.py | 2 +- python/pyspark/serializers.py | 7 +- python/pyspark/shell.py | 4 +- python/pyspark/worker.py | 7 +- 8 files changed, 290 insertions(+), 21 deletions(-) create mode 100644 python/pyspark/accumulators.py diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index f431ef28d3..fb13e84658 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -1,7 +1,8 @@ package spark.api.python import java.io._ -import java.util.{List => JList} +import java.net._ +import java.util.{List => JList, ArrayList => JArrayList, Collections} import scala.collection.JavaConversions._ import scala.io.Source @@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import spark.broadcast.Broadcast import spark._ import spark.rdd.PipedRDD -import java.util private[spark] class PythonRDD[T: ClassManifest]( - parent: RDD[T], - command: Seq[String], - envVars: java.util.Map[String, String], - preservePartitoning: Boolean, - pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) + parent: RDD[T], + command: Seq[String], + envVars: java.util.Map[String, String], + preservePartitoning: Boolean, + pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent.context) { // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: java.util.List[Broadcast[Array[Byte]]]) = + preservePartitoning: Boolean, pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], + accumulator: Accumulator[JList[Array[Byte]]]) = this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec, - broadcastVars) + broadcastVars, accumulator) override def splits = parent.splits @@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest]( // Return an iterator that read lines from the process's stdout val stream = new DataInputStream(proc.getInputStream) return new Iterator[Array[Byte]] { - def next() = { + def next(): Array[Byte] = { val obj = _nextObj _nextObj = read() obj } - private def read() = { + private def read(): Array[Byte] = { try { val length = stream.readInt() - val obj = new Array[Byte](length) - stream.readFully(obj) - obj + if (length != -1) { + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + } else { + // We've finished the data section of the output, but we can still read some + // accumulator updates; let's do that, breaking when we get EOFException + while (true) { + val len2 = stream.readInt() + val update = new Array[Byte](len2) + stream.readFully(update) + accumulator += Collections.singletonList(update) + } + new Array[Byte](0) + } } catch { case eof: EOFException => { val exitStatus = proc.waitFor() @@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte] private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } + +/** + * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it + * collects a list of pickled strings that we pass to Python through a socket. + */ +class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) + extends AccumulatorParam[JList[Array[Byte]]] { + + override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList + + override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]]) + : JList[Array[Byte]] = { + if (serverHost == null) { + // This happens on the worker node, where we just want to remember all the updates + val1.addAll(val2) + val1 + } else { + // This happens on the master, where we pass the updates to Python through a socket + val socket = new Socket(serverHost, serverPort) + val in = socket.getInputStream + val out = new DataOutputStream(socket.getOutputStream) + out.writeInt(val2.size) + for (array <- val2) { + out.writeInt(array.length) + out.write(array) + } + out.flush() + // Wait for a byte from the Python side as an acknowledgement + val byteRead = in.read() + if (byteRead == -1) { + throw new SparkException("EOF reached before Python server acknowledged") + } + socket.close() + null + } + } +} diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c595ae0842..00666bc0a3 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -7,6 +7,10 @@ Public classes: Main entry point for Spark functionality. - L{RDD<pyspark.rdd.RDD>} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. + - L{Broadcast<pyspark.broadcast.Broadcast>} + A broadcast variable that gets reused across tasks. + - L{Accumulator<pyspark.accumulators.Accumulator>} + An "add-only" shared variable that tasks can only add values to. """ import sys import os diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py new file mode 100644 index 0000000000..438af4cfc0 --- /dev/null +++ b/python/pyspark/accumulators.py @@ -0,0 +1,166 @@ +""" +>>> from pyspark.context import SparkContext +>>> sc = SparkContext('local', 'test') +>>> a = sc.accumulator(1) +>>> a.value +1 +>>> a.value = 2 +>>> a.value +2 +>>> a += 5 +>>> a.value +7 + +>>> rdd = sc.parallelize([1,2,3]) +>>> def f(x): +... global a +... a += x +>>> rdd.foreach(f) +>>> a.value +13 + +>>> class VectorAccumulatorParam(object): +... def zero(self, value): +... return [0.0] * len(value) +... def addInPlace(self, val1, val2): +... for i in xrange(len(val1)): +... val1[i] += val2[i] +... return val1 +>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam()) +>>> va.value +[1.0, 2.0, 3.0] +>>> def g(x): +... global va +... va += [x] * 3 +>>> rdd.foreach(g) +>>> va.value +[7.0, 8.0, 9.0] + +>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> def h(x): +... global a +... a.value = 7 +>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Py4JJavaError:... + +>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL +Traceback (most recent call last): + ... +Exception:... +""" + +import struct +import SocketServer +import threading +from pyspark.cloudpickle import CloudPickler +from pyspark.serializers import read_int, read_with_length, load_pickle + + +# Holds accumulators registered on the current machine, keyed by ID. This is then used to send +# the local accumulator updates back to the driver program at the end of a task. +_accumulatorRegistry = {} + + +def _deserialize_accumulator(aid, zero_value, accum_param): + from pyspark.accumulators import _accumulatorRegistry + accum = Accumulator(aid, zero_value, accum_param) + accum._deserialized = True + _accumulatorRegistry[aid] = accum + return accum + + +class Accumulator(object): + def __init__(self, aid, value, accum_param): + """Create a new Accumulator with a given initial value and AccumulatorParam object""" + from pyspark.accumulators import _accumulatorRegistry + self.aid = aid + self.accum_param = accum_param + self._value = value + self._deserialized = False + _accumulatorRegistry[aid] = self + + def __reduce__(self): + """Custom serialization; saves the zero value from our AccumulatorParam""" + param = self.accum_param + return (_deserialize_accumulator, (self.aid, param.zero(self._value), param)) + + @property + def value(self): + """Get the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + return self._value + + @value.setter + def value(self, value): + """Sets the accumulator's value; only usable in driver program""" + if self._deserialized: + raise Exception("Accumulator.value cannot be accessed inside tasks") + self._value = value + + def __iadd__(self, term): + """The += operator; adds a term to this accumulator's value""" + self._value = self.accum_param.addInPlace(self._value, term) + return self + + def __str__(self): + return str(self._value) + + +class AddingAccumulatorParam(object): + """ + An AccumulatorParam that uses the + operators to add values. Designed for simple types + such as integers, floats, and lists. Requires the zero value for the underlying type + as a parameter. + """ + + def __init__(self, zero_value): + self.zero_value = zero_value + + def zero(self, value): + return self.zero_value + + def addInPlace(self, value1, value2): + value1 += value2 + return value1 + + +# Singleton accumulator params for some standard types +INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0) +DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0) +COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) + + +class _UpdateRequestHandler(SocketServer.StreamRequestHandler): + def handle(self): + from pyspark.accumulators import _accumulatorRegistry + num_updates = read_int(self.rfile) + for _ in range(num_updates): + (aid, update) = load_pickle(read_with_length(self.rfile)) + _accumulatorRegistry[aid] += update + # Write a byte in acknowledgement + self.wfile.write(struct.pack("!b", 1)) + + +def _start_update_server(): + """Start a TCP server to receive accumulator updates in a daemon thread, and returns it""" + server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + return server + + +def _test(): + import doctest + doctest.testmod() + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e486f206b0..1e2f845f9c 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -2,6 +2,8 @@ import os import atexit from tempfile import NamedTemporaryFile +from pyspark import accumulators +from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.java_gateway import launch_gateway from pyspark.serializers import dump_pickle, write_with_length, batched @@ -22,6 +24,7 @@ class SparkContext(object): _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile _takePartition = jvm.PythonRDD.takePartition + _next_accum_id = 0 def __init__(self, master, jobName, sparkHome=None, pyFiles=None, environment=None, batchSize=1024): @@ -52,6 +55,14 @@ class SparkContext(object): self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome, empty_string_array) + # Create a single Accumulator in Java that we'll send all our updates through; + # they will be passed back to us through a TCP server + self._accumulatorServer = accumulators._start_update_server() + (host, port) = self._accumulatorServer.server_address + self._javaAccumulator = self._jsc.accumulator( + self.jvm.java.util.ArrayList(), + self.jvm.PythonAccumulatorParam(host, port)) + self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python') # Broadcast's __reduce__ method stores Broadcast instances here. # This allows other code to determine which Broadcast instances have @@ -74,6 +85,8 @@ class SparkContext(object): def __del__(self): if self._jsc: self._jsc.stop() + if self._accumulatorServer: + self._accumulatorServer.shutdown() def stop(self): """ @@ -129,6 +142,31 @@ class SparkContext(object): return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) + def accumulator(self, value, accum_param=None): + """ + Create an C{Accumulator} with the given initial value, using a given + AccumulatorParam helper object to define how to add values of the data + type if provided. Default AccumulatorParams are used for integers and + floating-point numbers if you do not provide one. For other types, the + AccumulatorParam must implement two methods: + - C{zero(value)}: provide a "zero value" for the type, compatible in + dimensions with the provided C{value} (e.g., a zero vector). + - C{addInPlace(val1, val2)}: add two values of the accumulator's data + type, returning a new value; for efficiency, can also update C{val1} + in place and return it. + """ + if accum_param == None: + if isinstance(value, int): + accum_param = accumulators.INT_ACCUMULATOR_PARAM + elif isinstance(value, float): + accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM + elif isinstance(value, complex): + accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM + else: + raise Exception("No default accumulator param for type %s" % type(value)) + SparkContext._next_accum_id += 1 + return Accumulator(SparkContext._next_accum_id - 1, value, accum_param) + def addFile(self, path): """ Add a file to be downloaded into the working directory of this Spark diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1d36da42b0..d705f0f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -703,7 +703,7 @@ class PipelinedRDD(RDD): env = MapConverter().convert(env, self.ctx.gateway._gateway_client) python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(), pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, class_manifest) + broadcast_vars, self.ctx._javaAccumulator, class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9a5151ea00..115cf28cc2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -52,8 +52,13 @@ def read_int(stream): raise EOFError return struct.unpack("!i", length)[0] + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + + def write_with_length(obj, stream): - stream.write(struct.pack("!i", len(obj))) + write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e6ad3aa76..f6328c561f 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -1,7 +1,7 @@ """ An interactive shell. -This fle is designed to be launched as a PYTHONSTARTUP script. +This file is designed to be launched as a PYTHONSTARTUP script. """ import os from pyspark.context import SparkContext @@ -14,4 +14,4 @@ print "Spark context avaiable as sc." # which allows us to execute the user's PYTHONSTARTUP file: _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP') if _pythonstartup and os.path.isfile(_pythonstartup): - execfile(_pythonstartup) + execfile(_pythonstartup) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 3d792bbaa2..b2b9288089 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -5,9 +5,10 @@ import sys from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. +from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import write_with_length, read_with_length, \ +from pyspark.serializers import write_with_length, read_with_length, write_int, \ read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file @@ -36,6 +37,10 @@ def main(): iterator = read_from_pickle_file(sys.stdin) for obj in func(split_index, iterator): write_with_length(dumps(obj), old_stdout) + # Mark the beginning of the accumulators section of the output + write_int(-1, old_stdout) + for aid, accum in _accumulatorRegistry.items(): + write_with_length(dump_pickle((aid, accum._value)), old_stdout) if __name__ == '__main__': -- GitLab