Commit ae2ed294 authored by Josh Rosen

Allow PySpark's SparkFiles to be used from driver

Fix minor documentation formatting issues.
parent 35168d9c
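What this enables, in practice: a file shipped with SparkContext.addFile() can now be resolved by name in the driver process, not just inside tasks on workers. A minimal sketch of the new usage (the file path, app name, and master URL are illustrative, not from the commit):

    from pyspark.context import SparkContext
    from pyspark.files import SparkFiles

    sc = SparkContext("local", "SparkFilesDemo")
    sc.addFile("/tmp/hello.txt")           # ship a file with the job
    print SparkFiles.get("hello.txt")      # now works on the driver, too
    sc.stop()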
spark/SparkFiles.java:

@@ -3,23 +3,23 @@ package spark;

 import java.io.File;

 /**
- * Resolves paths to files added through `addFile().
+ * Resolves paths to files added through `SparkContext.addFile()`.
  */
 public class SparkFiles {

   private SparkFiles() {}

   /**
-   * Get the absolute path of a file added through `addFile()`.
+   * Get the absolute path of a file added through `SparkContext.addFile()`.
    */
   public static String get(String filename) {
     return new File(getRootDirectory(), filename).getAbsolutePath();
   }

   /**
-   * Get the root directory that contains files added through `addFile()`.
+   * Get the root directory that contains files added through `SparkContext.addFile()`.
    */
   public static String getRootDirectory() {
     return SparkEnv.get().sparkFilesDir();
   }
-}
\ No newline at end of file
+}
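PySpark's driver reaches this JVM class through the Py4J gateway; the files.py change below spells the call as cls._sc.jvm.spark.SparkFiles.getRootDirectory(). A sketch of the same lookup made directly, assuming an active SparkContext sc (not part of the commit):

    # Driver-side resolution goes through the JVM via Py4J:
    root = sc.jvm.spark.SparkFiles.getRootDirectory()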
python/pyspark/context.py:

 import os
 import atexit
 import shutil
+import sys
 import tempfile
+from threading import Lock
 from tempfile import NamedTemporaryFile

 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD

@@ -27,6 +30,8 @@ class SparkContext(object):
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
     _takePartition = jvm.PythonRDD.takePartition
     _next_accum_id = 0
+    _active_spark_context = None
+    _lock = Lock()

     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
             environment=None, batchSize=1024):

@@ -46,6 +51,11 @@ class SparkContext(object):
           Java object. Set 1 to disable batching or -1 to use an
           unlimited batch size.
         """
+        with SparkContext._lock:
+            if SparkContext._active_spark_context:
+                raise ValueError("Cannot run multiple SparkContexts at once")
+            else:
+                SparkContext._active_spark_context = self
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J

@@ -75,6 +85,8 @@ class SparkContext(object):
         # Deploy any code dependencies specified in the constructor
         for path in (pyFiles or []):
             self.addPyFile(path)
+        SparkFiles._sc = self
+        sys.path.append(SparkFiles.getRootDirectory())

     @property
     def defaultParallelism(self):

@@ -85,17 +97,20 @@ class SparkContext(object):
         return self._jsc.sc().defaultParallelism()

     def __del__(self):
-        if self._jsc:
-            self._jsc.stop()
-        if self._accumulatorServer:
-            self._accumulatorServer.shutdown()
+        self.stop()

     def stop(self):
         """
         Shut down the SparkContext.
         """
-        self._jsc.stop()
-        self._jsc = None
+        if self._jsc:
+            self._jsc.stop()
+            self._jsc = None
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
+            self._accumulatorServer = None
+        with SparkContext._lock:
+            SparkContext._active_spark_context = None

     def parallelize(self, c, numSlices=None):
         """
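The new _lock/_active_spark_context pair makes the context a per-process singleton, and the reworked stop() is now idempotent and releases the slot. A sketch of the resulting behavior (context names are illustrative):

    from pyspark.context import SparkContext

    sc = SparkContext("local", "first")
    try:
        SparkContext("local", "second")    # a second live context is rejected
    except ValueError as e:
        print e                            # Cannot run multiple SparkContexts at once
    sc.stop()                              # clears the slot under the lock,
                                           # so a fresh context may be created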
python/pyspark/files.py:

@@ -4,13 +4,15 @@ import os
 class SparkFiles(object):
     """
     Resolves paths to files added through
-    L{addFile()<pyspark.context.SparkContext.addFile>}.
+    L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}.

     SparkFiles contains only classmethods; users should not create SparkFiles
     instances.
     """

     _root_directory = None
+    _is_running_on_worker = False
+    _sc = None

     def __init__(self):
         raise NotImplementedError("Do not construct SparkFiles objects")

@@ -18,7 +20,19 @@ class SparkFiles(object):
     @classmethod
     def get(cls, filename):
         """
-        Get the absolute path of a file added through C{addFile()}.
+        Get the absolute path of a file added through C{SparkContext.addFile()}.
         """
-        path = os.path.join(SparkFiles._root_directory, filename)
+        path = os.path.join(SparkFiles.getRootDirectory(), filename)
         return os.path.abspath(path)
+
+    @classmethod
+    def getRootDirectory(cls):
+        """
+        Get the root directory that contains files added through
+        C{SparkContext.addFile()}.
+        """
+        if cls._is_running_on_worker:
+            return cls._root_directory
+        else:
+            # This will have to change if we support multiple SparkContexts:
+            return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
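getRootDirectory() now branches on where it runs: a worker returns the directory recorded by worker.py, while the driver asks the JVM through Py4J. A hedged sketch of observing both sides, assuming an active SparkContext sc (this mirrors the tests below):

    from pyspark.files import SparkFiles

    def where(_):
        # Executes on a worker, where _is_running_on_worker is True:
        from pyspark.files import SparkFiles
        return SparkFiles.getRootDirectory()

    worker_root = sc.parallelize([0]).map(where).first()   # worker-side answer
    driver_root = SparkFiles.getRootDirectory()            # driver asks the JVM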
python/pyspark/tests.py:

@@ -4,22 +4,26 @@ individual modules.
 """
 import os
 import shutil
+import sys
 from tempfile import NamedTemporaryFile
 import time
 import unittest

 from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
 from pyspark.java_gateway import SPARK_HOME


 class PySparkTestCase(unittest.TestCase):

     def setUp(self):
+        self._old_sys_path = list(sys.path)
         class_name = self.__class__.__name__
         self.sc = SparkContext('local[4]', class_name, batchSize=2)

     def tearDown(self):
         self.sc.stop()
+        sys.path = self._old_sys_path
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")

@@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase):
         res = self.sc.parallelize(range(2)).map(func).first()
         self.assertEqual("Hello World!", res)

+    def test_add_file_locally(self):
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        self.sc.addFile(path)
+        download_path = SparkFiles.get("hello.txt")
+        self.assertNotEqual(path, download_path)
+        with open(download_path) as test_file:
+            self.assertEquals("Hello World!\n", test_file.readline())
+
+    def test_add_py_file_locally(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this fails due to `userlibrary` not being on the Python path:
+        def func():
+            from userlibrary import UserClass
+        self.assertRaises(ImportError, func)
+        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+        self.sc.addFile(path)
+        from userlibrary import UserClass
+        self.assertEqual("Hello World!", UserClass().hello())

 if __name__ == "__main__":
     unittest.main()
python/pyspark/worker.py:

@@ -26,6 +26,7 @@ def main():
     split_index = read_int(sys.stdin)
     spark_files_dir = load_pickle(read_with_length(sys.stdin))
     SparkFiles._root_directory = spark_files_dir
+    SparkFiles._is_running_on_worker = True
     sys.path.append(spark_files_dir)
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
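End to end: worker.py records the files directory and raises the worker flag before any user code runs, so a task body like the one in the tests can resolve shipped files locally, without a JVM round-trip. A sketch of such a task function ("hello.txt" is the fixture name used by the tests above):

    def func(x):
        from pyspark.files import SparkFiles
        # Resolved against _root_directory, since this runs on a worker:
        with open(SparkFiles.get("hello.txt")) as f:
            return f.readline()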
python/test_support/hello.txt (the one-line fixture read by the tests): Hello World!