From 4a377aff2d36b64a65b54192a987aba44b8f78e0 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Tue, 18 Nov 2014 16:17:51 -0800
Subject: [PATCH] [SPARK-3721] [PySpark] broadcast objects larger than 2G

This patch will bring support for broadcasting objects larger than 2G.

pickle, zlib, FrameSerializer and Array[Byte] all can not support objects larger than 2G, so this patch introduce LargeObjectSerializer to serialize broadcast objects, the object will be serialized and compressed into small chunks, it also change the type of Broadcast[Array[Byte]]] into Broadcast[Array[Array[Byte]]]].

Testing for support broadcast objects larger than 2G is slow and memory hungry, so this is tested manually, could be added into SparkPerf.

Author: Davies Liu <davies@databricks.com>
Author: Davies Liu <davies.liu@gmail.com>

Closes #2659 from davies/huge and squashes the following commits:

7b57a14 [Davies Liu] add more tests for broadcast
28acff9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
a2f6a02 [Davies Liu] bug fix
4820613 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
5875c73 [Davies Liu] address comments
10a349b [Davies Liu] address comments
0c33016 [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
6182c8f [Davies Liu] Merge branch 'master' into huge
d94b68f [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
2514848 [Davies Liu] address comments
fda395b [Davies Liu] Merge branch 'master' of github.com:apache/spark into huge
1c2d928 [Davies Liu] fix scala style
091b107 [Davies Liu] broadcast objects larger than 2G
---
 .../apache/spark/api/python/PythonRDD.scala   |  24 ++-
 python/pyspark/broadcast.py                   |   4 +-
 python/pyspark/context.py                     |   5 +-
 python/pyspark/serializers.py                 | 185 +++++++++++++++++-
 python/pyspark/tests.py                       |  52 ++++-
 python/pyspark/worker.py                      |   8 +-
 python/run-tests                              |   2 +-
 .../apache/spark/sql/UdfRegistration.scala    |   2 +-
 .../spark/sql/execution/pythonUdfs.scala      |   2 +-
 9 files changed, 257 insertions(+), 27 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 45beb8fc8c..b80c771d58 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -47,7 +47,7 @@ private[spark] class PythonRDD(
     pythonIncludes: JList[String],
     preservePartitoning: Boolean,
     pythonExec: String,
-    broadcastVars: JList[Broadcast[Array[Byte]]],
+    broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
     accumulator: Accumulator[JList[Array[Byte]]])
   extends RDD[Array[Byte]](parent) {
 
@@ -230,8 +230,8 @@ private[spark] class PythonRDD(
           if (!oldBids.contains(broadcast.id)) {
             // send new broadcast
             dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
+            dataOut.writeLong(broadcast.value.map(_.length.toLong).sum)
+            broadcast.value.foreach(dataOut.write)
             oldBids.add(broadcast.id)
           }
         }
@@ -368,16 +368,24 @@ private[spark] object PythonRDD extends Logging {
     }
   }
 
-  def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
+  def readBroadcastFromFile(
+      sc: JavaSparkContext,
+      filename: String): Broadcast[Array[Array[Byte]]] = {
+    val size = new File(filename).length()
     val file = new DataInputStream(new FileInputStream(filename))
+    val blockSize = 1 << 20
+    val n = ((size + blockSize - 1) / blockSize).toInt
+    val obj = new Array[Array[Byte]](n)
     try {
-      val length = file.readInt()
-      val obj = new Array[Byte](length)
-      file.readFully(obj)
-      sc.broadcast(obj)
+      for (i <- 0 until n) {
+        val length = if (i < (n - 1)) blockSize else (size % blockSize).toInt
+        obj(i) = new Array[Byte](length)
+        file.readFully(obj(i))
+      }
     } finally {
       file.close()
     }
+    sc.broadcast(obj)
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index f124dc6c07..01cac3c72c 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -29,7 +29,7 @@
 """
 import os
 
-from pyspark.serializers import CompressedSerializer, PickleSerializer
+from pyspark.serializers import LargeObjectSerializer
 
 
 __all__ = ['Broadcast']
@@ -73,7 +73,7 @@ class Broadcast(object):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self.path is not None:
-            ser = CompressedSerializer(PickleSerializer())
+            ser = LargeObjectSerializer()
             self._value = ser.load_stream(open(self.path)).next()
         return self._value
 
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index b6c991453d..ec67ec8d0f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@ from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, CompressedSerializer, AutoBatchedSerializer, NoOpSerializer
+    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, LargeObjectSerializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -624,7 +624,8 @@ class SparkContext(object):
         object for reading it in distributed functions. The variable will
         be sent to each cluster only once.
         """
-        ser = CompressedSerializer(PickleSerializer())
+        ser = LargeObjectSerializer()
+
         # pass large object by py4j is very slow and need much memory
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         ser.dump_stream([value], tempFile)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d597cbf94e..760a509f0e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -133,6 +133,8 @@ class FramedSerializer(Serializer):
 
     def _write_with_length(self, obj, stream):
         serialized = self.dumps(obj)
+        if len(serialized) > (1 << 31):
+            raise ValueError("can not serialize object larger than 2G")
         write_int(len(serialized), stream)
         if self._only_write_strings:
             stream.write(str(serialized))
@@ -446,20 +448,184 @@ class AutoSerializer(FramedSerializer):
             raise ValueError("invalid sevialization type: %s" % _type)
 
 
-class CompressedSerializer(FramedSerializer):
+class SizeLimitedStream(object):
     """
-    Compress the serialized data
+    Read at most `limit` bytes from underlying stream
+
+    >>> from StringIO import StringIO
+    >>> io = StringIO()
+    >>> io.write("Hello world")
+    >>> io.seek(0)
+    >>> lio = SizeLimitedStream(io, 5)
+    >>> lio.read()
+    'Hello'
+    """
+    def __init__(self, stream, limit):
+        self.stream = stream
+        self.limit = limit
+
+    def read(self, n=0):
+        if n > self.limit or n == 0:
+            n = self.limit
+        buf = self.stream.read(n)
+        self.limit -= len(buf)
+        return buf
+
+
+class CompressedStream(object):
+    """
+    Compress the data using zlib
+
+    >>> from StringIO import StringIO
+    >>> io = StringIO()
+    >>> wio = CompressedStream(io, 'w')
+    >>> wio.write("Hello world")
+    >>> wio.flush()
+    >>> io.seek(0)
+    >>> rio = CompressedStream(io, 'r')
+    >>> rio.read()
+    'Hello world'
+    >>> rio.read()
+    ''
+    """
+    MAX_BATCH = 1 << 20  # 1MB
+
+    def __init__(self, stream, mode='w', level=1):
+        self.stream = stream
+        self.mode = mode
+        if mode == 'w':
+            self.compresser = zlib.compressobj(level)
+        elif mode == 'r':
+            self.decompresser = zlib.decompressobj()
+            self.buf = ''
+        else:
+            raise ValueError("can only support mode 'w' or 'r' ")
+
+    def write(self, buf):
+        assert self.mode == 'w', "It's not opened for write"
+        if len(buf) > self.MAX_BATCH:
+            # zlib can not compress string larger than 2G
+            batches = len(buf) / self.MAX_BATCH + 1  # last one may be empty
+            for i in xrange(batches):
+                self.write(buf[i * self.MAX_BATCH:(i + 1) * self.MAX_BATCH])
+        else:
+            compressed = self.compresser.compress(buf)
+            self.stream.write(compressed)
+
+    def flush(self, mode=zlib.Z_FULL_FLUSH):
+        if self.mode == 'w':
+            d = self.compresser.flush(mode)
+            self.stream.write(d)
+            self.stream.flush()
+
+    def close(self):
+        if self.mode == 'w':
+            self.flush(zlib.Z_FINISH)
+            self.stream.close()
+
+    def read(self, size=0):
+        assert self.mode == 'r', "It's not opened for read"
+        if not size:
+            data = self.stream.read()
+            result = self.decompresser.decompress(data)
+            last = self.decompresser.flush()
+            return self.buf + result + last
+
+        # fast path for small read()
+        if size <= len(self.buf):
+            result = self.buf[:size]
+            self.buf = self.buf[size:]
+            return result
+
+        result = [self.buf]
+        size -= len(self.buf)
+        self.buf = ''
+        while size:
+            need = min(size, self.MAX_BATCH)
+            input = self.stream.read(need)
+            if input:
+                buf = self.decompresser.decompress(input)
+            else:
+                buf = self.decompresser.flush()
+
+            if len(buf) >= size:
+                self.buf = buf[size:]
+                result.append(buf[:size])
+                return ''.join(result)
+
+            size -= len(buf)
+            result.append(buf)
+            if not input:
+                return ''.join(result)
+
+    def readline(self):
+        """
+        This is needed for pickle, but not used in protocol 2
+        """
+        line = []
+        b = self.read(1)
+        while b and b != '\n':
+            line.append(b)
+            b = self.read(1)
+        line.append(b)
+        return ''.join(line)
+
+
+class LargeObjectSerializer(Serializer):
+    """
+    Serialize large object which could be larger than 2G
+
+    It uses cPickle to serialize the objects
     """
+    def dump_stream(self, iterator, stream):
+        stream = CompressedStream(stream, 'w')
+        for value in iterator:
+            if isinstance(value, basestring):
+                if isinstance(value, unicode):
+                    stream.write('U')
+                    value = value.encode("utf-8")
+                else:
+                    stream.write('S')
+                write_long(len(value), stream)
+                stream.write(value)
+            else:
+                stream.write('P')
+                cPickle.dump(value, stream, 2)
+        stream.flush()
 
+    def load_stream(self, stream):
+        stream = CompressedStream(stream, 'r')
+        while True:
+            type = stream.read(1)
+            if not type:
+                return
+            if type in ('S', 'U'):
+                length = read_long(stream)
+                value = stream.read(length)
+                if type == 'U':
+                    value = value.decode('utf-8')
+                yield value
+            elif type == 'P':
+                yield cPickle.load(stream)
+            else:
+                raise ValueError("unknown type: %s" % type)
+
+
+class CompressedSerializer(Serializer):
+    """
+    Compress the serialized data
+    """
     def __init__(self, serializer):
-        FramedSerializer.__init__(self)
         self.serializer = serializer
 
-    def dumps(self, obj):
-        return zlib.compress(self.serializer.dumps(obj), 1)
+    def load_stream(self, stream):
+        stream = CompressedStream(stream, "r")
+        return self.serializer.load_stream(stream)
 
-    def loads(self, obj):
-        return self.serializer.loads(zlib.decompress(obj))
+    def dump_stream(self, iterator, stream):
+        stream = CompressedStream(stream, "w")
+        self.serializer.dump_stream(iterator, stream)
+        stream.flush()
 
 
 class UTF8Deserializer(Serializer):
@@ -517,3 +683,8 @@ def write_int(value, stream):
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 491e445a21..a01bd8d415 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -32,6 +32,7 @@ import time
 import zipfile
 import random
 import threading
+import hashlib
 
 if sys.version_info[:2] <= (2, 6):
     try:
@@ -47,7 +48,7 @@ from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer
+    CloudPickleSerializer, SizeLimitedStream, CompressedSerializer, LargeObjectSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -236,6 +237,27 @@ class SerializationTestCase(unittest.TestCase):
         self.assertTrue("exit" in foo.func_code.co_names)
         ser.dumps(foo)
 
+    def _test_serializer(self, ser):
+        from StringIO import StringIO
+        io = StringIO()
+        ser.dump_stream(["abc", u"123", range(5)], io)
+        io.seek(0)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+        size = io.tell()
+        ser.dump_stream(range(1000), io)
+        io.seek(0)
+        first = SizeLimitedStream(io, size)
+        self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(first)))
+        self.assertEqual(range(1000), list(ser.load_stream(io)))
+
+    def test_compressed_serializer(self):
+        ser = CompressedSerializer(PickleSerializer())
+        self._test_serializer(ser)
+
+    def test_large_object_serializer(self):
+        ser = LargeObjectSerializer()
+        self._test_serializer(ser)
+
 
 class PySparkTestCase(unittest.TestCase):
 
@@ -440,7 +462,7 @@ class RDDTests(ReusedPySparkTestCase):
         subset = data.takeSample(False, 10)
         self.assertEqual(len(subset), 10)
 
-    def testAggregateByKey(self):
+    def test_aggregate_by_key(self):
         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
 
         def seqOp(x, y):
@@ -478,6 +500,32 @@ class RDDTests(ReusedPySparkTestCase):
         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
         self.assertEquals(N, m)
 
+    def test_multiple_broadcasts(self):
+        N = 1 << 21
+        b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
+        r = range(1 << 15)
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
+        random.shuffle(r)
+        s = str(r)
+        checksum = hashlib.md5(s).hexdigest()
+        b2 = self.sc.broadcast(s)
+        r = list(set(self.sc.parallelize(range(10), 10).map(
+            lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+        self.assertEqual(1, len(r))
+        size, csum = r[0]
+        self.assertEqual(N, size)
+        self.assertEqual(checksum, csum)
+
     def test_large_closure(self):
         N = 1000000
         data = [float(i) for i in xrange(N)]
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2bdccb5e93..e1552a0b0b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    CompressedSerializer
+    SizeLimitedStream, LargeObjectSerializer
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -78,11 +78,13 @@ def main(infile, outfile):
 
         # fetch names and values of broadcast variables
         num_broadcast_variables = read_int(infile)
-        ser = CompressedSerializer(pickleSer)
+        bser = LargeObjectSerializer()
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
             if bid >= 0:
-                value = ser._read_with_length(infile)
+                size = read_long(infile)
+                s = SizeLimitedStream(infile, size)
+                value = list((bser.load_stream(s)))[0]  # read out all the bytes
                 _broadcastRegistry[bid] = Broadcast(bid, value)
             else:
                 bid = - bid - 1
diff --git a/python/run-tests b/python/run-tests
index e66854b44d..9ee19ed6e6 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -56,7 +56,7 @@ function run_core_tests() {
     run_test "pyspark/conf.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
-    PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+    run_test "pyspark/serializers.py"
     run_test "pyspark/shuffle.py"
     run_test "pyspark/tests.py"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 6d4c0d82ac..ddcb5db6c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -39,7 +39,7 @@ private[sql] trait UDFRegistration {
       envVars: JMap[String, String],
       pythonIncludes: JList[String],
       pythonExec: String,
-      broadcastVars: JList[Broadcast[Array[Byte]]],
+      broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
       accumulator: Accumulator[JList[Array[Byte]]],
       stringDataType: String): Unit = {
     log.debug(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index a83cf5d441..f98cae3f17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -45,7 +45,7 @@ private[spark] case class PythonUDF(
     envVars: JMap[String, String],
     pythonIncludes: JList[String],
     pythonExec: String,
-    broadcastVars: JList[Broadcast[Array[Byte]]],
+    broadcastVars: JList[Broadcast[Array[Array[Byte]]]],
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType,
     children: Seq[Expression]) extends Expression with SparkLogging {
-- 
GitLab