From 00288ea2a463180e91fd16c8e2b627e69566e1f0 Mon Sep 17 00:00:00 2001
From: Holden Karau <holden@us.ibm.com>
Date: Sun, 10 Apr 2016 02:34:54 +0100
Subject: [PATCH] [SPARK-13687][PYTHON] Cleanup PySpark parallelize temporary
 files

## What changes were proposed in this pull request?

Eagerly clean up PySpark's temporary parallelize files rather than waiting for shutdown.
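
The effect is easy to observe from user code. Below is a minimal illustrative sketch, not part of the patch, assuming a local `SparkContext`; `sc._temp_dir` is the internal scratch directory that `parallelize` writes its serialized batches to:

```python
# Illustrative only: with eager cleanup, parallelize() should leave no
# extra files behind in the context's internal scratch directory.
import os
from pyspark import SparkContext

sc = SparkContext("local", "parallelize-temp-file-check")
before = set(os.listdir(sc._temp_dir))
sc.parallelize(range(100)).count()
after = set(os.listdir(sc._temp_dir))
# Previously the serialized temp file lingered here until JVM shutdown.
assert before == after
sc.stop()
```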

## How was this patch tested?

Unit tests (a new `test_parallelize_eager_cleanup` case in `python/pyspark/tests.py`, included in this patch).

Author: Holden Karau <holden@us.ibm.com>

Closes #12233 from holdenk/SPARK-13687-cleanup-pyspark-temporary-files.
---
 python/pyspark/context.py | 22 +++++++++++++---------
 python/pyspark/tests.py   |  7 +++++++
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 529d16b480..cb15b4b91f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -428,15 +428,19 @@ class SparkContext(object):
         # because it sends O(n) Py4J commands.  As an alternative, serialized
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
-        # Make sure we distribute data evenly if it's smaller than self.batchSize
-        if "__len__" not in dir(c):
-            c = list(c)    # Make it a list so we can compute its length
-        batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
-        serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
-        serializer.dump_stream(c, tempFile)
-        tempFile.close()
-        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+        try:
+            # Make sure we distribute data evenly if it's smaller than self.batchSize
+            if "__len__" not in dir(c):
+                c = list(c)    # Make it a list so we can compute its length
+            batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+            serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
+            serializer.dump_stream(c, tempFile)
+            tempFile.close()
+            readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+            jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+        finally:
+            # readRDDFromFile eagerly reads the file, so we can delete it right after.
+            os.unlink(tempFile.name)
         return RDD(jrdd, self, serializer)
 
     def pickleFile(self, name, minPartitions=None):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 15c87e22f9..97ea39dde0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1914,6 +1914,13 @@ class ContextTests(unittest.TestCase):
         with SparkContext.getOrCreate() as sc:
             self.assertTrue(SparkContext.getOrCreate() is sc)
 
+    def test_parallelize_eager_cleanup(self):
+        with SparkContext() as sc:
+            temp_files = os.listdir(sc._temp_dir)
+            rdd = sc.parallelize([0, 1, 2])
+            post_parallelize_temp_files = os.listdir(sc._temp_dir)
+            self.assertEqual(temp_files, post_parallelize_temp_files)
+
     def test_stop(self):
         sc = SparkContext()
         self.assertNotEqual(SparkContext._active_spark_context, None)
-- 
GitLab