Commit e211f405 authored by Josh Rosen

Use spark.local.dir for PySpark temp files (SPARK-580).

parent b6a60921
import os
import atexit
import shutil
import sys
import tempfile
from threading import Lock
from tempfile import NamedTemporaryFile
@@ -94,6 +92,11 @@ class SparkContext(object):
SparkFiles._sc = self
sys.path.append(SparkFiles.getRootDirectory())
+# Create a temporary directory inside spark.local.dir:
+local_dir = self._jvm.spark.Utils.getLocalDir()
+self._temp_dir = \
+    self._jvm.spark.Utils.createTempDir(local_dir).getAbsolutePath()
@property
def defaultParallelism(self):
"""
@@ -126,8 +129,7 @@ class SparkContext(object):
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
-tempFile = NamedTemporaryFile(delete=False)
-atexit.register(lambda: os.unlink(tempFile.name))
+tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
if self.batchSize != 1:
c = batched(c, self.batchSize)
for x in c:
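The comment in this hunk describes the hand-off used by parallelize(): rather than pushing each element through Py4J, the elements are serialized into a temp file and the JVM reads that file back. A minimal sketch of the Python half of the pattern, using plain pickle as a stand-in for PySpark's serializer and a throwaway directory in place of self._temp_dir (names here are illustrative, not the actual API):

import os
import pickle
import tempfile

def dump_to_temp_file(iterator, temp_dir):
    # delete=False so the file outlives this function and can be handed to
    # another process; dir=temp_dir keeps it inside the scratch directory.
    with tempfile.NamedTemporaryFile(delete=False, dir=temp_dir) as f:
        for obj in iterator:
            pickle.dump(obj, f)
    return f.name

scratch = tempfile.mkdtemp()
path = dump_to_temp_file(range(10), scratch)
print(path)
os.unlink(path)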
@@ -247,7 +249,9 @@ class SparkContext(object):
def _test():
import atexit
import doctest
import tempfile
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
globs['tempdir'] = tempfile.mkdtemp()
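The doctest harness above creates a scratch directory with tempfile.mkdtemp(); the rest of the hunk is truncated here, but a common way to dispose of such a directory when the test process exits is to register the removal with atexit. A sketch of that pattern (not necessarily the exact cleanup code in this commit):

import atexit
import shutil
import tempfile

tempdir = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(tempdir, ignore_errors=True))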
......

pyspark/rdd.py
import atexit
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
@@ -264,12 +263,8 @@ class RDD(object):
# Transferring lots of data through Py4J can be slow because
# socket.readline() is inefficient. Instead, we'll dump the data to a
# file and read it back.
-tempFile = NamedTemporaryFile(delete=False)
+tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
-def clean_up_file():
-    try: os.unlink(tempFile.name)
-    except: pass
-atexit.register(clean_up_file)
self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
......
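On the read side, RDD.collect() has the JVM write its pickled output into a temp file (now placed under self.ctx._temp_dir) and then deserializes it in Python; the per-file atexit cleanup shown above is dropped, presumably because the whole scratch directory under spark.local.dir is cleaned up at once. A small sketch of reading framed pickle records back out of such a file, again using plain pickle as a stand-in for PySpark's serializer (illustrative names only):

import os
import pickle
import tempfile

def read_from_temp_file(path):
    # Yield one record at a time until the file is exhausted, then delete it.
    with open(path, 'rb') as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                break
    os.unlink(path)

# Round-trip example: write a few records, then read them back.
with tempfile.NamedTemporaryFile(delete=False) as f:
    for x in [1, 2, 3]:
        pickle.dump(x, f)
print(list(read_from_temp_file(f.name)))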