From 8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Sun, 20 Jan 2013 01:57:44 -0800
Subject: [PATCH] Added accumulators to PySpark

---
 .../scala/spark/api/python/PythonRDD.scala    |  83 +++++++--
 python/pyspark/__init__.py                    |   4 +
 python/pyspark/accumulators.py                | 166 ++++++++++++++++++
 python/pyspark/context.py                     |  38 ++++
 python/pyspark/rdd.py                         |   2 +-
 python/pyspark/serializers.py                 |   7 +-
 python/pyspark/shell.py                       |   4 +-
 python/pyspark/worker.py                      |   7 +-
 8 files changed, 290 insertions(+), 21 deletions(-)
 create mode 100644 python/pyspark/accumulators.py

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f431ef28d3..fb13e84658 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,7 +1,8 @@
 package spark.api.python
 
 import java.io._
-import java.util.{List => JList}
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
 
 import scala.collection.JavaConversions._
 import scala.io.Source
@@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
 import spark._
 import spark.rdd.PipedRDD
-import java.util
 
 
 private[spark] class PythonRDD[T: ClassManifest](
-  parent: RDD[T],
-  command: Seq[String],
-  envVars: java.util.Map[String, String],
-  preservePartitoning: Boolean,
-  pythonExec: String,
-  broadcastVars: java.util.List[Broadcast[Array[Byte]]])
+    parent: RDD[T],
+    command: Seq[String],
+    envVars: java.util.Map[String, String],
+    preservePartitoning: Boolean,
+    pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent.context) {
 
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
   def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
-    preservePartitoning: Boolean, pythonExec: String,
-    broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+      preservePartitoning: Boolean, pythonExec: String,
+      broadcastVars: JList[Broadcast[Array[Byte]]],
+      accumulator: Accumulator[JList[Array[Byte]]]) =
     this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
-      broadcastVars)
+      broadcastVars, accumulator)
 
   override def splits = parent.splits
 
@@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest](
     // Return an iterator that read lines from the process's stdout
     val stream = new DataInputStream(proc.getInputStream)
     return new Iterator[Array[Byte]] {
-      def next() = {
+      def next(): Array[Byte] = {
         val obj = _nextObj
         _nextObj = read()
         obj
       }
 
-      private def read() = {
+      private def read(): Array[Byte] = {
         try {
           val length = stream.readInt()
-          val obj = new Array[Byte](length)
-          stream.readFully(obj)
-          obj
+          if (length != -1) {
+            val obj = new Array[Byte](length)
+            stream.readFully(obj)
+            obj
+          } else {
+            // We've finished the data section of the output, but we can still read some
+            // accumulator updates; let's do that, breaking when we get EOFException
+            while (true) {
+              val len2 = stream.readInt()
+              val update = new Array[Byte](len2)
+              stream.readFully(update)
+              accumulator += Collections.singletonList(update)
+            }
+            new Array[Byte](0)
+          }
         } catch {
           case eof: EOFException => {
             val exitStatus = proc.waitFor()
@@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte]
 private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
   override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
 }
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+  extends AccumulatorParam[JList[Array[Byte]]] {
+  
+  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+      : JList[Array[Byte]] = {
+    if (serverHost == null) {
+      // This happens on the worker node, where we just want to remember all the updates
+      val1.addAll(val2)
+      val1
+    } else {
+      // This happens on the master, where we pass the updates to Python through a socket
+      val socket = new Socket(serverHost, serverPort)
+      val in = socket.getInputStream
+      val out = new DataOutputStream(socket.getOutputStream)
+      out.writeInt(val2.size)
+      for (array <- val2) {
+        out.writeInt(array.length)
+        out.write(array)
+      }
+      out.flush()
+      // Wait for a byte from the Python side as an acknowledgement
+      val byteRead = in.read()
+      if (byteRead == -1) {
+        throw new SparkException("EOF reached before Python server acknowledged")
+      }
+      socket.close()
+      null
+    }
+  }
+}
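
PythonAccumulatorParam.addInPlace above defines a small wire protocol between the JVM and the Python driver: the number of updates as a big-endian int, then each pickled update as a length-prefixed byte string, answered by a single acknowledgement byte from Python. As a rough illustration of that framing only, here is a Python 2 sketch of a client speaking the same protocol (send_updates is a hypothetical test helper, not part of this patch; the real sender is the Scala code above and the real receiver is _UpdateRequestHandler in accumulators.py below):

    import socket
    import struct
    import cPickle

    def send_updates(host, port, updates):
        # updates: list of (accumulator_id, value) pairs, pickled the same way
        # worker.py pickles them before sending them back to the JVM
        payloads = [cPickle.dumps(u, 2) for u in updates]
        sock = socket.create_connection((host, port))
        try:
            sock.sendall(struct.pack("!i", len(payloads)))   # number of updates
            for p in payloads:
                sock.sendall(struct.pack("!i", len(p)))      # length prefix
                sock.sendall(p)                              # pickled (aid, value)
            ack = sock.recv(1)                               # wait for the 1-byte ack
            if not ack:
                raise EOFError("accumulator server closed the connection without an ack")
        finally:
            sock.close()
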
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index c595ae0842..00666bc0a3 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -7,6 +7,10 @@ Public classes:
         Main entry point for Spark functionality.
     - L{RDD<pyspark.rdd.RDD>}
         A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
+    - L{Broadcast<pyspark.broadcast.Broadcast>}
+        A broadcast variable that gets reused across tasks.
+    - L{Accumulator<pyspark.accumulators.Accumulator>}
+        An "add-only" shared variable that tasks can only add values to.
 """
 import sys
 import os
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
new file mode 100644
index 0000000000..438af4cfc0
--- /dev/null
+++ b/python/pyspark/accumulators.py
@@ -0,0 +1,166 @@
+"""
+>>> from pyspark.context import SparkContext
+>>> sc = SparkContext('local', 'test')
+>>> a = sc.accumulator(1)
+>>> a.value
+1
+>>> a.value = 2
+>>> a.value
+2
+>>> a += 5
+>>> a.value
+7
+
+>>> rdd = sc.parallelize([1,2,3])
+>>> def f(x):
+...     global a
+...     a += x
+>>> rdd.foreach(f)
+>>> a.value
+13
+
+>>> class VectorAccumulatorParam(object):
+...     def zero(self, value):
+...         return [0.0] * len(value)
+...     def addInPlace(self, val1, val2):
+...         for i in xrange(len(val1)):
+...              val1[i] += val2[i]
+...         return val1
+>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
+>>> va.value
+[1.0, 2.0, 3.0]
+>>> def g(x):
+...     global va
+...     va += [x] * 3
+>>> rdd.foreach(g)
+>>> va.value
+[7.0, 8.0, 9.0]
+
+>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+    ...
+Py4JJavaError:...
+
+>>> def h(x):
+...     global a
+...     a.value = 7
+>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+    ...
+Py4JJavaError:...
+
+>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
+Traceback (most recent call last):
+    ...
+Exception:...
+"""
+
+import struct
+import SocketServer
+import threading
+from pyspark.cloudpickle import CloudPickler
+from pyspark.serializers import read_int, read_with_length, load_pickle
+
+
+# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
+# the local accumulator updates back to the driver program at the end of a task.
+_accumulatorRegistry = {}
+
+
+def _deserialize_accumulator(aid, zero_value, accum_param):
+    from pyspark.accumulators import _accumulatorRegistry
+    accum = Accumulator(aid, zero_value, accum_param)
+    accum._deserialized = True
+    _accumulatorRegistry[aid] = accum
+    return accum
+
+
+class Accumulator(object):
+    def __init__(self, aid, value, accum_param):
+        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
+        from pyspark.accumulators import _accumulatorRegistry
+        self.aid = aid
+        self.accum_param = accum_param
+        self._value = value
+        self._deserialized = False
+        _accumulatorRegistry[aid] = self
+
+    def __reduce__(self):
+        """Custom serialization; saves the zero value from our AccumulatorParam"""
+        param = self.accum_param
+        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
+
+    @property
+    def value(self):
+        """Get the accumulator's value; only usable in driver program"""
+        if self._deserialized:
+            raise Exception("Accumulator.value cannot be accessed inside tasks")
+        return self._value
+
+    @value.setter
+    def value(self, value):
+        """Sets the accumulator's value; only usable in driver program"""
+        if self._deserialized:
+            raise Exception("Accumulator.value cannot be accessed inside tasks")
+        self._value = value
+
+    def __iadd__(self, term):
+        """The += operator; adds a term to this accumulator's value"""
+        self._value = self.accum_param.addInPlace(self._value, term)
+        return self
+
+    def __str__(self):
+        return str(self._value)
+
+
+class AddingAccumulatorParam(object):
+    """
+    An AccumulatorParam that uses the + operator to add values. Designed for simple types
+    such as integers, floats, and lists. Requires the zero value for the underlying type
+    as a parameter.
+    """
+
+    def __init__(self, zero_value):
+        self.zero_value = zero_value
+
+    def zero(self, value):
+        return self.zero_value
+
+    def addInPlace(self, value1, value2):
+        value1 += value2
+        return value1
+
+
+# Singleton accumulator params for some standard types
+INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
+DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
+COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
+
+
+class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+    def handle(self):
+        from pyspark.accumulators import _accumulatorRegistry
+        num_updates = read_int(self.rfile)
+        for _ in range(num_updates):
+            (aid, update) = load_pickle(read_with_length(self.rfile))
+            _accumulatorRegistry[aid] += update
+        # Write a byte in acknowledgement
+        self.wfile.write(struct.pack("!b", 1))
+
+
+def _start_update_server():
+    """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
+    server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+    thread = threading.Thread(target=server.serve_forever)
+    thread.daemon = True
+    thread.start()
+    return server
+
+
+def _test():
+    import doctest
+    doctest.testmod()
+
+
+if __name__ == "__main__":
+    _test()
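
Accumulator.__reduce__ above is what makes accumulators write-only on workers: unpickling yields a copy that starts from the AccumulatorParam's zero value and is flagged as deserialized, so it accepts += but refuses .value. A minimal local sketch of that behaviour, run outside any Spark job (constructing Accumulator directly with an explicit id is for illustration only; normal code goes through SparkContext.accumulator):

    import cPickle
    from pyspark.accumulators import Accumulator, INT_ACCUMULATOR_PARAM

    acc = Accumulator(0, 10, INT_ACCUMULATOR_PARAM)   # driver-side accumulator, value 10
    copy = cPickle.loads(cPickle.dumps(acc))          # what a worker would receive

    copy += 5          # workers may add to their copy...
    print copy         # prints 5: the copy started from the zero value, not 10
    print acc.value    # prints 10: the driver-side value changes only when updates arrive
    # copy.value would raise an Exception here, because the copy is marked as deserialized
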
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index e486f206b0..1e2f845f9c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -2,6 +2,8 @@ import os
 import atexit
 from tempfile import NamedTemporaryFile
 
+from pyspark import accumulators
+from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
@@ -22,6 +24,7 @@ class SparkContext(object):
     _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
     _takePartition = jvm.PythonRDD.takePartition
+    _next_accum_id = 0
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -52,6 +55,14 @@ class SparkContext(object):
         self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)
 
+        # Create a single Accumulator in Java that we'll send all our updates through;
+        # they will be passed back to us through a TCP server
+        self._accumulatorServer = accumulators._start_update_server()
+        (host, port) = self._accumulatorServer.server_address
+        self._javaAccumulator = self._jsc.accumulator(
+                self.jvm.java.util.ArrayList(),
+                self.jvm.PythonAccumulatorParam(host, port))
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
@@ -74,6 +85,8 @@ class SparkContext(object):
     def __del__(self):
         if self._jsc:
             self._jsc.stop()
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
 
     def stop(self):
         """
@@ -129,6 +142,31 @@ class SparkContext(object):
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
 
+    def accumulator(self, value, accum_param=None):
+        """
+        Create an C{Accumulator} with the given initial value, using a given
+        AccumulatorParam helper object to define how to add values of the data 
+        type if provided. Default AccumulatorParams are used for integers and
+        floating-point numbers if you do not provide one. For other types, the
+        AccumulatorParam must implement two methods:
+        - C{zero(value)}: provide a "zero value" for the type, compatible in
+          dimensions with the provided C{value} (e.g., a zero vector).
+        - C{addInPlace(val1, val2)}: add two values of the accumulator's data
+          type, returning a new value; for efficiency, can also update C{val1}
+          in place and return it.
+        """
+        if accum_param == None:
+            if isinstance(value, int):
+                accum_param = accumulators.INT_ACCUMULATOR_PARAM
+            elif isinstance(value, float):
+                accum_param = accumulators.DOUBLE_ACCUMULATOR_PARAM
+            elif isinstance(value, complex):
+                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
+            else:
+                raise Exception("No default accumulator param for type %s" % type(value))
+        SparkContext._next_accum_id += 1
+        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
+
     def addFile(self, path):
         """
         Add a file to be downloaded into the working directory of this Spark
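
The accumulator() docstring above spells out the two-method contract for a custom AccumulatorParam. As one more hedged sketch beyond the VectorAccumulatorParam doctest in accumulators.py (SetAccumulatorParam, remember and seen are illustrative names, not part of the patch; assumes an active SparkContext named sc):

    class SetAccumulatorParam(object):
        """Accumulates values into a set; zero() ignores the initial value's contents."""
        def zero(self, value):
            return set()
        def addInPlace(self, val1, val2):
            val1 |= val2      # update val1 in place and return it, as the contract allows
            return val1

    seen = sc.accumulator(set(), SetAccumulatorParam())
    def remember(x):
        global seen
        seen += set([x])
    sc.parallelize(["a", "b", "a"]).foreach(remember)
    # seen.value is now set(["a", "b"]) on the driver
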
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1d36da42b0..d705f0f9e1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -703,7 +703,7 @@ class PipelinedRDD(RDD):
         env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
         python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
             pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
-            broadcast_vars, class_manifest)
+            broadcast_vars, self.ctx._javaAccumulator, class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
         return self._jrdd_val
 
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 9a5151ea00..115cf28cc2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -52,8 +52,13 @@ def read_int(stream):
         raise EOFError
     return struct.unpack("!i", length)[0]
 
+
+def write_int(value, stream):
+    stream.write(struct.pack("!i", value))
+
+
 def write_with_length(obj, stream):
-    stream.write(struct.pack("!i", len(obj)))
+    write_int(len(obj), stream)
     stream.write(obj)
 
 
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 7e6ad3aa76..f6328c561f 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -1,7 +1,7 @@
 """
 An interactive shell.
 
-This fle is designed to be launched as a PYTHONSTARTUP script.
+This file is designed to be launched as a PYTHONSTARTUP script.
 """
 import os
 from pyspark.context import SparkContext
@@ -14,4 +14,4 @@ print "Spark context avaiable as sc."
 # which allows us to execute the user's PYTHONSTARTUP file:
 _pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
 if _pythonstartup and os.path.isfile(_pythonstartup):
-        execfile(_pythonstartup)
+    execfile(_pythonstartup)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3d792bbaa2..b2b9288089 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -5,9 +5,10 @@ import sys
 from base64 import standard_b64decode
 # CloudPickler needs to be imported so that depicklers are registered using the
 # copy_reg module.
+from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import write_with_length, read_with_length, \
+from pyspark.serializers import write_with_length, read_with_length, write_int, \
     read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
 
 
@@ -36,6 +37,10 @@ def main():
     iterator = read_from_pickle_file(sys.stdin)
     for obj in func(split_index, iterator):
         write_with_length(dumps(obj), old_stdout)
+    # Mark the beginning of the accumulators section of the output
+    write_int(-1, old_stdout)
+    for aid, accum in _accumulatorRegistry.items():
+        write_with_length(dump_pickle((aid, accum._value)), old_stdout)
 
 
 if __name__ == '__main__':
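
The write_int(-1, old_stdout) call above is the sentinel that PythonRDD.read() on the Scala side uses to switch from reading data records to reading accumulator updates. For illustration only, a rough Python 2 sketch of a reader for that stdout framing (read_worker_output is a hypothetical helper, not part of the patch):

    import struct

    def read_worker_output(stream):
        """Parse the worker's output framing: length-prefixed data records,
        a -1 sentinel, then length-prefixed pickled accumulator updates until EOF."""
        records, updates = [], []
        while True:
            header = stream.read(4)
            if len(header) < 4:
                break                                # EOF before the sentinel
            (length,) = struct.unpack("!i", header)
            if length == -1:
                # End of the data section; everything that follows is accumulator updates
                while True:
                    header = stream.read(4)
                    if len(header) < 4:
                        break
                    (ulen,) = struct.unpack("!i", header)
                    updates.append(stream.read(ulen))
                break
            records.append(stream.read(length))
        return records, updates
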
-- 
GitLab