Skip to content
Snippets Groups Projects
Commit 9f54d7e1 authored by Josh Rosen's avatar Josh Rosen
Browse files

Merge pull request #387 from mateiz/python-accumulators

Add accumulators to PySpark
parents 2a8c2a67 ee5a0795
No related branches found
No related tags found
No related merge requests found
package spark.api.python
import java.io._
import java.util.{List => JList}
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
......@@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
import java.util
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: java.util.Map[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: java.util.List[Broadcast[Array[Byte]]])
parent: RDD[T],
command: Seq[String],
envVars: java.util.Map[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent.context) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
broadcastVars)
broadcastVars, accumulator)
override def splits = parent.splits
......@@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest](
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
return new Iterator[Array[Byte]] {
def next() = {
def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
obj
}
private def read() = {
private def read(): Array[Byte] = {
try {
val length = stream.readInt()
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
if (length != -1) {
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
} else {
// We've finished the data section of the output, but we can still read some
// accumulator updates; let's do that, breaking when we get EOFException
while (true) {
val len2 = stream.readInt()
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
......@@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte]
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
/**
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
: JList[Array[Byte]] = {
if (serverHost == null) {
// This happens on the worker node, where we just want to remember all the updates
val1.addAll(val2)
val1
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
val out = new DataOutputStream(socket.getOutputStream)
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
out.write(array)
}
out.flush()
// Wait for a byte from the Python side as an acknowledgement
val byteRead = in.read()
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
socket.close()
null
}
}
}
......@@ -16,7 +16,6 @@ There are a few key differences between the Python and Scala APIs:
* Python is dynamically typed, so RDDs can hold objects of different types.
* PySpark does not currently support the following Spark features:
- Accumulators
- Special functions on RDDs of doubles, such as `mean` and `stdev`
- `lookup`
- `persist` at storage levels other than `MEMORY_ONLY`
......
......@@ -7,6 +7,10 @@ Public classes:
Main entry point for Spark functionality.
- L{RDD<pyspark.rdd.RDD>}
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
- L{Broadcast<pyspark.broadcast.Broadcast>}
A broadcast variable that gets reused across tasks.
- L{Accumulator<pyspark.accumulators.Accumulator>}
An "add-only" shared variable that tasks can only add values to.
"""
import sys
import os
......
"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7
>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
... global a
... a += x
>>> rdd.foreach(f)
>>> a.value
13
>>> class VectorAccumulatorParam(object):
... def zero(self, value):
... return [0.0] * len(value)
... def addInPlace(self, val1, val2):
... for i in xrange(len(val1)):
... val1[i] += val2[i]
... return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
... global va
... va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]
>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
Py4JJavaError:...
>>> def h(x):
... global a
... a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
Py4JJavaError:...
>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
Exception:...
"""
import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, read_with_length, load_pickle
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}
def _deserialize_accumulator(aid, zero_value, accum_param):
from pyspark.accumulators import _accumulatorRegistry
accum = Accumulator(aid, zero_value, accum_param)
accum._deserialized = True
_accumulatorRegistry[aid] = accum
return accum
class Accumulator(object):
"""
A shared variable that can be accumulated, i.e., has a commutative and associative "add"
operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
operator, but only the driver program is allowed to access its value, using C{value}.
Updates from the workers get propagated automatically to the driver program.
While C{SparkContext} supports accumulators for primitive data types like C{int} and
C{float}, users can also define accumulators for custom types by providing a custom
C{AccumulatorParam} object with a C{zero} and C{addInPlace} method. Refer to the doctest
of this module for an example.
"""
def __init__(self, aid, value, accum_param):
"""Create a new Accumulator with a given initial value and AccumulatorParam object"""
from pyspark.accumulators import _accumulatorRegistry
self.aid = aid
self.accum_param = accum_param
self._value = value
self._deserialized = False
_accumulatorRegistry[aid] = self
def __reduce__(self):
"""Custom serialization; saves the zero value from our AccumulatorParam"""
param = self.accum_param
return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))
@property
def value(self):
"""Get the accumulator's value; only usable in driver program"""
if self._deserialized:
raise Exception("Accumulator.value cannot be accessed inside tasks")
return self._value
@value.setter
def value(self, value):
"""Sets the accumulator's value; only usable in driver program"""
if self._deserialized:
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
self._value = self.accum_param.addInPlace(self._value, term)
return self
def __str__(self):
return str(self._value)
class AddingAccumulatorParam(object):
"""
An AccumulatorParam that uses the + operators to add values. Designed for simple types
such as integers, floats, and lists. Requires the zero value for the underlying type
as a parameter.
"""
def __init__(self, zero_value):
self.zero_value = zero_value
def zero(self, value):
return self.zero_value
def addInPlace(self, value1, value2):
value1 += value2
return value1
# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
DOUBLE_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
def handle(self):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
(aid, update) = load_pickle(read_with_length(self.rfile))
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))
def _start_update_server():
"""Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()
return server
def _test():
import doctest
doctest.testmod()
if __name__ == "__main__":
_test()
......@@ -2,6 +2,8 @@ import os
import atexit
from tempfile import NamedTemporaryFile
from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length, batched
......@@ -22,6 +24,7 @@ class SparkContext(object):
_readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
_writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
_takePartition = jvm.PythonRDD.takePartition
_next_accum_id = 0
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
......@@ -52,6 +55,14 @@ class SparkContext(object):
self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
empty_string_array)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
self._accumulatorServer = accumulators._start_update_server()
(host, port) = self._accumulatorServer.server_address
self._javaAccumulator = self._jsc.accumulator(
self.jvm.java.util.ArrayList(),
self.jvm.PythonAccumulatorParam(host, port))
self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
# Broadcast's __reduce__ method stores Broadcast instances here.
# This allows other code to determine which Broadcast instances have
......@@ -74,6 +85,8 @@ class SparkContext(object):
def __del__(self):
if self._jsc:
self._jsc.stop()
if self._accumulatorServer:
self._accumulatorServer.shutdown()
def stop(self):
"""
......@@ -129,6 +142,31 @@ class SparkContext(object):
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
def accumulator(self, value, accum_param=None):
"""
Create an C{Accumulator} with the given initial value, using a given
AccumulatorParam helper object to define how to add values of the data
type if provided. Default AccumulatorParams are used for integers and
floating-point numbers if you do not provide one. For other types, the
AccumulatorParam must implement two methods:
- C{zero(value)}: provide a "zero value" for the type, compatible in
dimensions with the provided C{value} (e.g., a zero vector).
- C{addInPlace(val1, val2)}: add two values of the accumulator's data
type, returning a new value; for efficiency, can also update C{val1}
in place and return it.
"""
if accum_param == None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):
accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
elif isinstance(value, complex):
accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
else:
raise Exception("No default accumulator param for type %s" % type(value))
SparkContext._next_accum_id += 1
return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)
def addFile(self, path):
"""
Add a file to be downloaded into the working directory of this Spark
......
......@@ -703,7 +703,7 @@ class PipelinedRDD(RDD):
env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, class_manifest)
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
......
......@@ -52,8 +52,13 @@ def read_int(stream):
raise EOFError
return struct.unpack("!i", length)[0]
def write_int(value, stream):
stream.write(struct.pack("!i", value))
def write_with_length(obj, stream):
stream.write(struct.pack("!i", len(obj)))
write_int(len(obj), stream)
stream.write(obj)
......
"""
An interactive shell.
This fle is designed to be launched as a PYTHONSTARTUP script.
This file is designed to be launched as a PYTHONSTARTUP script.
"""
import os
from pyspark.context import SparkContext
......@@ -14,4 +14,4 @@ print "Spark context avaiable as sc."
# which allows us to execute the user's PYTHONSTARTUP file:
_pythonstartup = os.environ.get('OLD_PYTHONSTARTUP')
if _pythonstartup and os.path.isfile(_pythonstartup):
execfile(_pythonstartup)
execfile(_pythonstartup)
......@@ -5,9 +5,10 @@ import sys
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import write_with_length, read_with_length, \
from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
......@@ -36,6 +37,10 @@ def main():
iterator = read_from_pickle_file(sys.stdin)
for obj in func(split_index, iterator):
write_with_length(dumps(obj), old_stdout)
# Mark the beginning of the accumulators section of the output
write_int(-1, old_stdout)
for aid, accum in _accumulatorRegistry.items():
write_with_length(dump_pickle((aid, accum._value)), old_stdout)
if __name__ == '__main__':
......
......@@ -11,6 +11,9 @@ FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/broadcast.py
FAILED=$(($?||$FAILED))
$FWDIR/pyspark -m doctest pyspark/accumulators.py
FAILED=$(($?||$FAILED))
if [[ $FAILED != 0 ]]; then
echo -en "\033[31m" # Red
echo "Had test failures; see logs."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment