diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f9ff4ea6ca157b5e8bd82db47c089162f3d29108..924141475383db7bf6b2c07c6c05ffa7051b6ac1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -339,26 +339,34 @@ private[spark] object PythonRDD extends Logging { def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) - val objs = new collection.mutable.ArrayBuffer[Array[Byte]] try { - while (true) { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - objs.append(obj) + val objs = new collection.mutable.ArrayBuffer[Array[Byte]] + try { + while (true) { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + objs.append(obj) + } + } catch { + case eof: EOFException => {} } - } catch { - case eof: EOFException => {} + JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) + } finally { + file.close() } - JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - sc.broadcast(obj) + try { + val length = file.readInt() + val obj = new Array[Byte](length) + file.readFully(obj) + sc.broadcast(obj) + } finally { + file.close() + } } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ed89e2f9769f8695051aeb2651f3bc0aeabec4f..dc6497772e50294e413b14b6c08139b749d5ba56 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2073,6 +2073,12 @@ class PipelinedRDD(RDD): 
self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None + self._broadcast = None + + def __del__(self): + if self._broadcast: + self._broadcast.unpersist() + self._broadcast = None @property def _jrdd(self): @@ -2087,9 +2093,9 @@ class PipelinedRDD(RDD): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() pickled_command = ser.dumps(command) - if pickled_command > (1 << 20): # 1M - broadcast = self.ctx.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) + if len(pickled_command) > (1 << 20): # 1M + self._broadcast = self.ctx.broadcast(pickled_command) + pickled_command = ser.dumps(self._broadcast) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index d8bdf22355ec8b538951d14b9d7168bbd6527d21..974b5e287bc0063f2389e19aa573933559855507 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -965,7 +965,7 @@ class SQLContext(object): BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) - if pickled_command > (1 << 20): # 1M + if len(pickled_command) > (1 << 20): # 1M broadcast = self._sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) broadcast_vars = ListConverter().convert( diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7e2bbc9cb617ffe91acffad8b206dee576a6c5d1..6fb6bc998c752e8c217ab402e5bade52dc4634ac 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -467,8 +467,12 @@ class TestRDDFunctions(PySparkTestCase): def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] - m = self.sc.parallelize(range(1), 1).map(lambda x: len(data)).sum() - self.assertEquals(N, m) + rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data)) 
+ self.assertEqual(N, rdd.first()) + self.assertTrue(rdd._broadcast is not None) + rdd = self.sc.parallelize(range(1), 1).map(lambda x: 1) + self.assertEqual(1, rdd.first()) + self.assertTrue(rdd._broadcast is None) def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5))