Commit 4608902f authored by Josh Rosen

Use filesystem to collect RDDs in PySpark.

Passing large volumes of data through Py4J seems
to be slow.  It appears to be faster to write the
data to the local filesystem and read it back from
Python.
parent ccd075cf
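The pattern the commit message describes, sketched below in plain Python: instead of pushing every object through the Py4J socket, serialize the data as length-prefixed pickles into a temporary file on the local filesystem and stream it back. The helper names here are illustrative only (they are not PySpark API); the 4-byte big-endian length framing mirrors what writeAsPickle and read_with_length use in the diff below.

import os
import pickle
import struct
from tempfile import NamedTemporaryFile

def write_pickles_to_file(items, path):
    # Hypothetical helper: 4-byte big-endian length, then the pickled bytes.
    with open(path, 'wb') as f:
        for item in items:
            data = pickle.dumps(item, 2)
            f.write(struct.pack('>i', len(data)))
            f.write(data)

def read_pickles_from_file(path):
    # Hypothetical helper: read frames back until end-of-file.
    with open(path, 'rb') as f:
        while True:
            header = f.read(4)
            if len(header) < 4:
                return
            (length,) = struct.unpack('>i', header)
            yield pickle.loads(f.read(length))

tempFile = NamedTemporaryFile(delete=False)
tempFile.close()
write_pickles_to_file(range(10), tempFile.name)
print(list(read_pickles_from_file(tempFile.name)))  # [0, 1, ..., 9]
os.unlink(tempFile.name)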
package spark.api.python
import java.io._
import java.util.{List => JList}
import scala.collection.Map
import scala.collection.JavaConversions._
@@ -59,36 +60,7 @@ trait PythonRDDBase {
}
out.flush()
for (elem <- parent.iterator(split)) {
if (elem.isInstanceOf[Array[Byte]]) {
val arr = elem.asInstanceOf[Array[Byte]]
dOut.writeInt(arr.length)
dOut.write(arr)
} else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
val t = elem.asInstanceOf[scala.Tuple2[_, _]]
val t1 = t._1.asInstanceOf[Array[Byte]]
val t2 = t._2.asInstanceOf[Array[Byte]]
val length = t1.length + t2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(PythonRDD.stripPickle(t1))
dOut.write(PythonRDD.stripPickle(t2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
} else if (elem.isInstanceOf[String]) {
// For uniformity, strings are wrapped into Pickles.
val s = elem.asInstanceOf[String].getBytes("UTF-8")
val length = 2 + 1 + 4 + s.length + 1
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.writeByte(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
throw new Exception("Unexpected RDD type")
}
PythonRDD.writeAsPickle(elem, dOut)
}
dOut.flush()
out.flush()
@@ -174,36 +146,45 @@ object PythonRDD {
arr.slice(2, arr.length - 1)
}
def asPickle(elem: Any) : Array[Byte] = {
val baos = new ByteArrayOutputStream();
val dOut = new DataOutputStream(baos);
/**
* Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
* The data format is a 32-bit integer representing the pickled object's length (in bytes),
* followed by the pickled data.
* @param elem the object to write
* @param dOut a data output stream
*/
def writeAsPickle(elem: Any, dOut: DataOutputStream) {
if (elem.isInstanceOf[Array[Byte]]) {
elem.asInstanceOf[Array[Byte]]
val arr = elem.asInstanceOf[Array[Byte]]
dOut.writeInt(arr.length)
dOut.write(arr)
} else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(PythonRDD.stripPickle(t._1))
dOut.write(PythonRDD.stripPickle(t._2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
baos.toByteArray()
} else if (elem.isInstanceOf[String]) {
// For uniformity, strings are wrapped into Pickles.
val s = elem.asInstanceOf[String].getBytes("UTF-8")
val length = 2 + 1 + 4 + s.length + 1
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
baos.toByteArray()
} else {
throw new Exception("Unexpected RDD type")
}
}
def pickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -221,11 +202,12 @@ object PythonRDD {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def arrayAsPickle(arr : Any) : Array[Byte] = {
val pickles : Array[Byte] = arr.asInstanceOf[Array[Any]].map(asPickle).map(stripPickle).flatten
Array[Byte](Pickle.PROTO, Pickle.TWO, Pickle.EMPTY_LIST, Pickle.MARK) ++ pickles ++
Array[Byte] (Pickle.APPENDS, Pickle.STOP)
def writeArrayToPickleFile[T](items: Array[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
}
file.close()
}
}
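The Scaladoc on writeAsPickle above describes the outer framing (a 32-bit length followed by the pickled bytes). For (key, value) pairs the method also splices two existing pickles into a single pickled 2-tuple: stripPickle drops each pickle's 2-byte protocol header and 1-byte STOP opcode, and the payloads are re-wrapped with PROTO/TWO ... TUPLE2/STOP, which is where the t._1.length + t._2.length - 3 - 3 + 4 arithmetic comes from. A pure-Python demonstration of that splicing (not part of this patch; the opcode constants mirror the Scala Pickle object):

import pickle

PROTO, TWO, TUPLE2, STOP = b'\x80', b'\x02', b'\x86', b'.'

def strip_pickle(p):
    # Mirrors PythonRDD.stripPickle: drop the PROTO+version prefix (2 bytes)
    # and the trailing STOP opcode (1 byte).
    return p[2:-1]

key = pickle.dumps("a key", 2)
value = pickle.dumps([1, 2, 3], 2)
stitched = PROTO + TWO + strip_pickle(key) + strip_pickle(value) + TUPLE2 + STOP
assert len(stitched) == len(key) + len(value) - 3 - 3 + 4
print(pickle.loads(stitched))  # ('a key', [1, 2, 3])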
......
@@ -14,9 +14,8 @@ class SparkContext(object):
gateway = launch_gateway()
jvm = gateway.jvm
pickleFile = jvm.spark.api.python.PythonRDD.pickleFile
asPickle = jvm.spark.api.python.PythonRDD.asPickle
arrayAsPickle = jvm.spark.api.python.PythonRDD.arrayAsPickle
readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
def __init__(self, master, name, defaultParallelism=None):
self.master = master
@@ -45,11 +44,11 @@ class SparkContext(object):
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
tempFile = NamedTemporaryFile(delete=False)
atexit.register(lambda: os.unlink(tempFile.name))
for x in c:
write_with_length(dump_pickle(x), tempFile)
tempFile.close()
atexit.register(lambda: os.unlink(tempFile.name))
jrdd = self.pickleFile(self._jsc, tempFile.name, numSlices)
jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
return RDD(jrdd, self)
def textFile(self, name, minSplits=None):
......
import atexit
from base64 import standard_b64encode as b64enc
from collections import defaultdict
from itertools import chain, ifilter, imap
import os
import shlex
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import dump_pickle, load_pickle
from pyspark.serializers import dump_pickle, load_pickle, read_from_pickle_file
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
@@ -145,10 +147,30 @@ class RDD(object):
self.map(f).collect() # Force evaluation
def collect(self):
# To minimize the number of transfers between Python and Java, we'll
# flatten each partition into a list before collecting it. Due to
# pipelining, this should add minimal overhead.
def asList(iterator):
yield list(iterator)
pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))
picklesInJava = self.mapPartitions(asList)._jrdd.rdd().collect()
return list(chain.from_iterable(self._collect_array_through_file(picklesInJava)))
def _collect_array_through_file(self, array):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
tempFile = NamedTemporaryFile(delete=False)
tempFile.close()
def clean_up_file():
try: os.unlink(tempFile.name)
except: pass
atexit.register(clean_up_file)
self.ctx.writeArrayToPickleFile(array, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
for item in read_from_pickle_file(tempFile):
yield item
os.unlink(tempFile.name)
def reduce(self, f):
"""
@@ -220,15 +242,15 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4]).take(2)
[2, 3]
"""
pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
return load_pickle(bytes(pickle))
picklesInJava = self._jrdd.rdd().take(num)
return list(self._collect_array_through_file(picklesInJava))
def first(self):
"""
>>> sc.parallelize([2, 3, 4]).first()
2
"""
return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
return self.take(1)[0]
def saveAsTextFile(self, path):
def func(iterator):
......
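Hypothetical interactive usage of the changed code paths, assuming a running SparkContext named sc as in the doctests above; collect(), take() and first() now move their results through a temporary pickle file rather than through the Py4J socket:

nums = sc.parallelize(range(100), 4)
print(nums.collect()[:5])   # [0, 1, 2, 3, 4]
print(nums.take(2))         # [0, 1]
print(nums.first())         # 0
pairs = sc.parallelize([("a", 1), ("b", 2)])
print(pairs.collect())      # [('a', 1), ('b', 2)]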
@@ -33,3 +33,11 @@ def read_with_length(stream):
if obj == "":
raise EOFError
return obj
def read_from_pickle_file(stream):
try:
while True:
yield load_pickle(read_with_length(stream))
except EOFError:
return
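The new read_from_pickle_file generator is the consumer side of the length framing that parallelize and writeAsPickle produce. A small round-trip sketch using only helpers this diff imports from pyspark.serializers (assumed to behave as they are used in context.py above):

import os
from tempfile import NamedTemporaryFile
from pyspark.serializers import dump_pickle, write_with_length, \
    read_from_pickle_file

tempFile = NamedTemporaryFile(delete=False)
for x in ["a", "b", "c"]:
    write_with_length(dump_pickle(x), tempFile)
tempFile.close()

with open(tempFile.name, 'rb') as stream:
    print(list(read_from_pickle_file(stream)))  # ['a', 'b', 'c']
os.unlink(tempFile.name)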
@@ -8,7 +8,7 @@ from base64 import standard_b64decode
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import write_with_length, read_with_length, \
read_long, read_int, dump_pickle, load_pickle
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions.
@@ -20,14 +20,6 @@ def load_obj():
return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
def read_input():
try:
while True:
yield load_pickle(read_with_length(sys.stdin))
except EOFError:
return
def main():
num_broadcast_variables = read_int(sys.stdin)
for _ in range(num_broadcast_variables):
@@ -40,7 +32,7 @@ def main():
dumps = lambda x: x
else:
dumps = dump_pickle
for obj in func(read_input()):
for obj in func(read_from_pickle_file(sys.stdin)):
write_with_length(dumps(obj), old_stdout)
......