diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 45beb8fc8c9259cd3c4f8bb2428395c9b5347ef8..b80c771d58a8f9e99dca9a902a40987e9d47ba4b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,7 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
-    broadcastVars: JList[Broadcast[Array[Byte]]],
+    broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {

@@ -230,8 +230,8 @@ private[spark] class PythonRDD(
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
+            dataOut.writeLong(broadcast.value.map(_.length.toLong).sum)
+            broadcast.value.foreach(dataOut.write)
             oldBids.add(broadcast.id)
           }
         }
@@ -368,16 +368,24 @@ private[spark] object PythonRDD extends Logging {
     }
   }

-  def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
+  def readBroadcastFromFile(
+      sc: JavaSparkContext,
+      filename: String): Broadcast[Array[Array[Byte]]] = {
+    val size = new File(filename).length()
     val file = new DataInputStream(new FileInputStream(filename))
+    val blockSize = 1 << 20
+    val n = ((size + blockSize - 1) / blockSize).toInt
+    val obj = new Array[Array[Byte]](n)
     try {
-      val length = file.readInt()
-      val obj = new Array[Byte](length)
-      file.readFully(obj)
-      sc.broadcast(obj)
+      for (i <- 0 until n) {
+        val length = if (i < (n - 1)) blockSize else (size % blockSize).toInt
+        obj(i) = new Array[Byte](length)
+        file.readFully(obj(i))
+      }
     } finally {
       file.close()
     }
+    sc.broadcast(obj)
   }

   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index f124dc6c07575627aa814b8db428c6ae4e22df1c..01cac3c72c690f7555b4237cc7a654ed305c2fcf 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -29,7 +29,7 @@
 """
 import os

-from pyspark.serializers import CompressedSerializer, PickleSerializer
+from pyspark.serializers import LargeObjectSerializer

 __all__ = ['Broadcast']

@@ -73,7 +73,7 @@ class Broadcast(object):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self.path is not None:
-            ser = CompressedSerializer(PickleSerializer())
+            ser = LargeObjectSerializer()
             self._value = ser.load_stream(open(self.path)).next()
         return self._value

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b6c991453d4defb544d97d2a644440e42546ac6e..ec67ec8d0f824bf77ad6daffa04d2ce7678f98f5 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@ from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer
+    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -624,7 +624,8 @@ class SparkContext(object):
         object for reading it in distributed functions.
         The variable will be sent to each cluster only once.
         """
-        ser = CompressedSerializer(PickleSerializer())
+        ser = LargeObjectSerializer()
+        # pass large object by py4j is very slow and need much memory
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         ser.dump_stream([value], tempFile)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d597cbf94e1b1fc6623f92d4e8d44fd54b662e45..760a509f0ef6d131aa508111f53cc19647be947a 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -133,6 +133,8 @@ class FramedSerializer(Serializer):

     def _write_with_length(self, obj, stream):
         serialized = self.dumps(obj)
+        if len(serialized) > (1 << 31):
+            raise ValueError("can not serialize object larger than 2G")
         write_int(len(serialized), stream)
         if self._only_write_strings:
             stream.write(str(serialized))
@@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer):
             raise ValueError("invalid sevialization type: %s" % _type)


-class CompressedSerializer(FramedSerializer):
+class SizeLimitedStream(object):
     """
-    Compress the serialized data
+    Read at most `limit` bytes from underlying stream
+
+    >>> from StringIO import StringIO
+    >>> io = StringIO()
+    >>> io.write("Hello world")
+    >>> io.seek(0)
+    >>> lio = SizeLimitedStream(io, 5)
+    >>> lio.read()
+    'Hello'
+    """
+    def __init__(self, stream, limit):
+        self.stream = stream
+        self.limit = limit
+
+    def read(self, n=0):
+        if n > self.limit or n == 0:
+            n = self.limit
+        buf = self.stream.read(n)
+        self.limit -= len(buf)
+        return buf
+
+
+class CompressedStream(object):
+    """
+    Compress the data using zlib
+
+    >>> from StringIO import StringIO
+    >>> io = StringIO()
+    >>> wio = CompressedStream(io, 'w')
+    >>> wio.write("Hello world")
+    >>> wio.flush()
+    >>> io.seek(0)
+    >>> rio = CompressedStream(io, 'r')
+    >>> rio.read()
+    'Hello world'
+    >>> rio.read()
+    ''
+    """
+    MAX_BATCH = 1 << 20  # 1MB
+
+    def __init__(self, stream, mode='w', level=1):
+        self.stream = stream
+        self.mode = mode
+        if mode == 'w':
+            self.compresser = zlib.compressobj(level)
+        elif mode == 'r':
+            self.decompresser = zlib.decompressobj()
+            self.buf = ''
+        else:
+            raise ValueError("can only support mode 'w' or 'r' ")
+
+    def write(self, buf):
+        assert self.mode == 'w', "It's not opened for write"
+        if len(buf) > self.MAX_BATCH:
+            # zlib can not compress string larger than 2G
+            batches = len(buf) / self.MAX_BATCH + 1  # last one may be empty
+            for i in xrange(batches):
+                self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
+        else:
+            compressed = self.compresser.compress(buf)
+            self.stream.write(compressed)
+
+    def flush(self, mode=zlib.Z_FULL_FLUSH):
+        if self.mode == 'w':
+            d = self.compresser.flush(mode)
+            self.stream.write(d)
+            self.stream.flush()
+
+    def close(self):
+        if self.mode == 'w':
+            self.flush(zlib.Z_FINISH)
+        self.stream.close()
+
+    def read(self, size=0):
+        assert self.mode == 'r', "It's not opened for read"
+        if not size:
+            data = self.stream.read()
+            result = self.decompresser.decompress(data)
+            last = self.decompresser.flush()
+            return self.buf + result + last
+
+        # fast path for small read()
+        if size <= len(self.buf):
+            result = self.buf[:size]
+            self.buf = self.buf[size:]
+            return result
+
+        result = [self.buf]
+        size -= len(self.buf)
+        self.buf = ''
+        while size:
+            need = min(size, self.MAX_BATCH)
+            input = self.stream.read(need)
+            if input:
+                buf = self.decompresser.decompress(input)
+            else:
+                buf = self.decompresser.flush()
+
+            if len(buf) >= size:
+                self.buf = buf[size:]
+                result.append(buf[:size])
+                return ''.join(result)
+
+            size -= len(buf)
+            result.append(buf)
+            if not input:
+                return ''.join(result)
+
+    def readline(self):
+        """
+        This is needed for pickle, but not used in protocol 2
+        """
+        line = []
+        b = self.read(1)
+        while b and b != '\n':
+            line.append(b)
+            b = self.read(1)
+        line.append(b)
+        return ''.join(line)
+
+
+class LargeObjectSerializer(Serializer):
+    """
+    Serialize large object which could be larger than 2G
+
+    It uses cPickle to serialize the objects
     """
+    def dump_stream(self, iterator, stream):
+        stream = CompressedStream(stream, 'w')
+        for value in iterator:
+            if isinstance(value, basestring):
+                if isinstance(value, unicode):
+                    stream.write('U')
+                    value = value.encode("utf-8")
+                else:
+                    stream.write('S')
+                write_long(len(value), stream)
+                stream.write(value)
+            else:
+                stream.write('P')
+                cPickle.dump(value, stream, 2)
+        stream.flush()

+    def load_stream(self, stream):
+        stream = CompressedStream(stream, 'r')
+        while True:
+            type = stream.read(1)
+            if not type:
+                return
+            if type in ('S', 'U'):
+                length = read_long(stream)
+                value = stream.read(length)
+                if type == 'U':
+                    value = value.decode('utf-8')
+                yield value
+            elif type == 'P':
+                yield cPickle.load(stream)
+            else:
+                raise ValueError("unknown type: %s" % type)
+
+
+class CompressedSerializer(Serializer):
+    """
+    Compress the serialized data
+    """
     def __init__(self, serializer):
-        FramedSerializer.__init__(self)
         self.serializer = serializer

-    def dumps(self, obj):
-        return zlib.compress(self.serializer.dumps(obj), 1)
+    def load_stream(self, stream):
+        stream = CompressedStream(stream, "r")
+        return self.serializer.load_stream(stream)

-    def loads(self, obj):
-        return self.serializer.loads(zlib.decompress(obj))
+    def dump_stream(self, iterator, stream):
+        stream = CompressedStream(stream, "w")
+        self.serializer.dump_stream(iterator, stream)
+        stream.flush()


 class UTF8Deserializer(Serializer):
@@ -517,3 +683,8 @@ def write_int(value, stream):
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 491e445a216bf65bf0a02f2827007dc97899ffe8..a01bd8d4157874fd36e8f2a73c6503f489076646 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -32,6 +32,7 @@ import time
 import zipfile
 import random
 import threading
+import hashlib

 if sys.version_info[:2] <= (2, 6):
     try:
@@ -47,7 +48,7 @@ from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer
+    CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -236,6 +237,27 @@ class SerializationTestCase(unittest.TestCase):
         self.assertTrue("exit" in foo.func_code.co_names)
         ser.dumps(foo)

+    def _test_serializer(self, ser):
+        from StringIO import StringIO
+        io = StringIO()
+        ser.dump_stream(["abc", u"123", range(5)], io)
+        io.seek(0)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+        size = io.tell()
+        ser.dump_stream(range(1000), io)
+        io.seek(0)
+        first = SizeLimitedStream(io, size)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
+        self.assertEqual(range(1000), list(ser.load_stream(io)))
+
+    def test_compressed_serializer(self):
+        ser = CompressedSerializer(PickleSerializer())
+        self._test_serializer(ser)
+
+    def test_large_object_serializer(self):
+        ser = LargeObjectSerializer()
+        self._test_serializer(ser)
+

 class PySparkTestCase(unittest.TestCase):

@@ -440,7 +462,7 @@ class RDDTests(ReusedPySparkTestCase):
         subset = data.takeSample(False, 10)
         self.assertEqual(len(subset), 10)

-    def testAggregateByKey(self):
+    def test_aggregate_by_key(self):
         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

         def seqOp(x, y):
@@ -478,6 +500,32 @@ class RDDTests(ReusedPySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)

+    def test_multiple_broadcasts(self):
+        N = 1 << 21
+        b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
+        r = range(1 << 15)
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2bdccb5e93f0953fc68f8a057e9dd8311ea2eda0..e1552a0b0b4ff795af292cd81e69927e63d21309 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    CompressedSerializer
+    SizeLimitedStream, LargeObjectSerializer
 from pyspark import shuffle

 pickleSer = PickleSerializer()
@@ -78,11 +78,13 @@ def main(infile, outfile):

         # fetch names and values of broadcast variables
         num_broadcast_variables = read_int(infile)
-        ser = CompressedSerializer(pickleSer)
+        bser = LargeObjectSerializer()
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
             if bid >= 0:
-                value = ser._read_with_length(infile)
+                size = read_long(infile)
+                s = SizeLimitedStream(infile, size)
+                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                 _broadcastRegistry[bid] = Broadcast(bid, value)
             else:
                 bid = - bid - 1
diff --git a/python/run-tests b/python/run-tests
index e66854b44dfa6926b68aaeca4252041a5b6a5145..9ee19ed6e6b2654b003bab9b06e165842d64ed0a 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -56,7 +56,7 @@ function run_core_tests() {
     run_test "pyspark/conf.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
-    PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+    run_test "pyspark/serializers.py"
     run_test "pyspark/shuffle.py"
     run_test "pyspark/tests.py"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 6d4c0d82ac7af3e3f6ddf0fd5f6e5f4600a1df3a..ddcb5db6c3a213e6f969660fb4a96eeca537de5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -39,7 +39,7 @@ private[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Byte]]],
+      broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index a83cf5d441d1e395c189bfcbb4383651609bd532..f98cae3f17e4a26f689ff01c729083a9e3f57d8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -45,7 +45,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
-    broadcastVars: JList[Broadcast[Array[Byte]]],
+    broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {
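
For reference, a minimal round-trip of the new serializers, condensed from the _test_serializer helper added to tests.py above (Python 2; a StringIO buffer stands in for the temp file on the driver and the socket stream in the worker):

    from StringIO import StringIO

    from pyspark.serializers import LargeObjectSerializer, SizeLimitedStream

    ser = LargeObjectSerializer()

    # Dump two batches back to back into one buffer, as the broadcast path does.
    io = StringIO()
    ser.dump_stream(["abc", u"123", range(5)], io)
    size = io.tell()                     # byte length of the first batch
    ser.dump_stream(range(1000), io)

    # SizeLimitedStream caps the reader at the first batch, which is how
    # worker.py reads exactly one broadcast value without consuming the
    # bytes that follow it on the same stream.
    io.seek(0)
    first = SizeLimitedStream(io, size)
    assert list(ser.load_stream(first)) == ["abc", u"123", range(5)]
    assert list(ser.load_stream(io)) == range(1000)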
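
On the JVM side, readBroadcastFromFile no longer materializes the broadcast as a single byte array: it slices the dumped file into 1MB blocks (hence Broadcast[Array[Array[Byte]]]), and PythonRDD writes the summed length as a long followed by the raw blocks, so the worker still sees one contiguous payload. A small illustrative sketch of that chunking arithmetic, written in Python rather than Scala; the helper name split_into_blocks is made up for this note and is not part of the patch:

    BLOCK_SIZE = 1 << 20  # matches blockSize in readBroadcastFromFile


    def split_into_blocks(data, block_size=BLOCK_SIZE):
        # ceil division, same as ((size + blockSize - 1) / blockSize).toInt
        n = (len(data) + block_size - 1) // block_size
        return [data[i * block_size:(i + 1) * block_size] for i in range(n)]


    data = "x" * (5 * BLOCK_SIZE + 123)               # 5 full blocks plus a short tail
    blocks = split_into_blocks(data)
    assert len(blocks) == 6
    assert sum(len(b) for b in blocks) == len(data)   # the long that PythonRDD writes
    assert "".join(blocks) == data                    # what the worker reads back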