Commit e211f405 authored by Josh Rosen

Use spark.local.dir for PySpark temp files (SPARK-580).

parent b6a60921
 import os
-import atexit
 import shutil
 import sys
-import tempfile
 from threading import Lock
 from tempfile import NamedTemporaryFile

@@ -94,6 +92,11 @@ class SparkContext(object):
         SparkFiles._sc = self
         sys.path.append(SparkFiles.getRootDirectory())
+        # Create a temporary directory inside spark.local.dir:
+        local_dir = self._jvm.spark.Utils.getLocalDir()
+        self._temp_dir = \
+            self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()

     @property
     def defaultParallelism(self):
         """
@@ -126,8 +129,7 @@ class SparkContext(object):
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False)
-        atexit.register(lambda: os.unlink(tempFile.name))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
         if self.batchSize != 1:
             c = batched(c, self.batchSize)
         for x in c:
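The comment above explains the trick: rather than issuing one Py4J command per element, parallelize() batches the Python objects, serializes them to a file in the context's temp directory, and has the JVM read that file. A rough, self-contained sketch of the write path, using plain pickle and a simplified batched() helper in place of PySpark's own serializers (temp_dir stands in for self._temp_dir):

    import pickle
    import tempfile

    def batched(iterator, batch_size):
        # Group elements so each serialized record carries many objects,
        # keeping the number of records small.
        batch = []
        for item in iterator:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    temp_dir = tempfile.mkdtemp()  # stand-in for the context's _temp_dir
    with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as tempFile:
        for batch in batched(range(10), 4):
            pickle.dump(batch, tempFile)  # one record per batch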
@@ -247,7 +249,9 @@ class SparkContext(object):
 def _test():
+    import atexit
     import doctest
+    import tempfile
     globs = globals().copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
     globs['tempdir'] = tempfile.mkdtemp()
...
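_test() now imports atexit and tempfile locally, since the module no longer needs them at import time. The doctest scratch directory created with mkdtemp() still has to be deleted eventually; a plausible cleanup sketch (the actual lines are elided above) would be:

    import atexit
    import shutil
    import tempfile

    tempdir = tempfile.mkdtemp()
    # Remove the doctest scratch directory when the interpreter exits.
    atexit.register(lambda: shutil.rmtree(tempdir, ignore_errors=True))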
-import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict

@@ -264,12 +263,8 @@ class RDD(object):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient. Instead, we'll dump the data to a
         # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False)
+        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        def clean_up_file():
-            try: os.unlink(tempFile.name)
-            except: pass
-        atexit.register(clean_up_file)
         self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
...
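collect() uses the same file-based shortcut in the other direction: the JVM writes the partition contents into a file under the context's temp directory and Python reads them back, avoiding a slow per-line Py4J socket read. A simplified sketch of the read-back loop, assuming plain pickle records rather than PySpark's exact on-disk framing:

    import pickle

    def read_records(path):
        # Repeatedly unpickle until EOF, flattening batched records into
        # individual elements.
        with open(path, 'rb') as f:
            while True:
                try:
                    obj = pickle.load(f)
                except EOFError:
                    return
                if isinstance(obj, list):
                    for item in obj:
                        yield item
                else:
                    yield obj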