Commit ae2ed294 authored by Josh Rosen

Allow PySpark's SparkFiles to be used from driver

Fix minor documentation formatting issues.
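Before this commit, `SparkFiles.get()` only resolved correctly inside tasks running on workers; driver code had no way to locate files registered with `SparkContext.addFile()`. The changes below make the same call work on the driver by routing it through the JVM. A minimal usage sketch of the intended behavior (the temporary file and its name are illustrative, not part of the commit):

```python
import os
import tempfile
from pyspark.context import SparkContext
from pyspark.files import SparkFiles

# Write a small local file to distribute (stand-in for any real dependency).
path = os.path.join(tempfile.mkdtemp(), "config.txt")
with open(path, "w") as f:
    f.write("Hello World!\n")

sc = SparkContext("local", "SparkFilesExample")
sc.addFile(path)

# Driver side: previously this only worked inside tasks; now it resolves via the JVM.
print(SparkFiles.get("config.txt"))

# Worker side: the same call still resolves against the directory shipped to the task.
def first_line_length(x):
    with open(SparkFiles.get("config.txt")) as f:
        return len(f.readline())

print(sc.parallelize(range(2)).map(first_line_length).collect())
sc.stop()
```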
parent 35168d9c
@@ -3,23 +3,23 @@ package spark;

 import java.io.File;

 /**
- * Resolves paths to files added through `addFile().
+ * Resolves paths to files added through `SparkContext.addFile()`.
  */
 public class SparkFiles {

   private SparkFiles() {}

   /**
-   * Get the absolute path of a file added through `addFile()`.
+   * Get the absolute path of a file added through `SparkContext.addFile()`.
    */
   public static String get(String filename) {
     return new File(getRootDirectory(), filename).getAbsolutePath();
   }

   /**
-   * Get the root directory that contains files added through `addFile()`.
+   * Get the root directory that contains files added through `SparkContext.addFile()`.
    */
   public static String getRootDirectory() {
     return SparkEnv.get().sparkFilesDir();
   }
 }
\ No newline at end of file
 import os
 import atexit
 import shutil
+import sys
 import tempfile
+from threading import Lock
 from tempfile import NamedTemporaryFile

 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
+from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD
@@ -27,6 +30,8 @@ class SparkContext(object):
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
     _takePartition = jvm.PythonRDD.takePartition
     _next_accum_id = 0
+    _active_spark_context = None
+    _lock = Lock()

     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -46,6 +51,11 @@ class SparkContext(object):
         Java object. Set 1 to disable batching or -1 to use an
         unlimited batch size.
         """
+        with SparkContext._lock:
+            if SparkContext._active_spark_context:
+                raise ValueError("Cannot run multiple SparkContexts at once")
+            else:
+                SparkContext._active_spark_context = self
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
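The new `_lock`/`_active_spark_context` pair makes the Python `SparkContext` effectively a singleton, which is what lets `SparkFiles` hold a single `_sc` reference for driver-side lookups. A small sketch of the behavior this hunk introduces (variable names are mine):

```python
from pyspark.context import SparkContext

sc = SparkContext("local", "First")
try:
    # A second active context is rejected while the first one is alive.
    sc2 = SparkContext("local", "Second")
except ValueError as e:
    print(e)   # "Cannot run multiple SparkContexts at once"
sc.stop()      # releases the active-context slot (see the stop() hunk below)
```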
@@ -75,6 +85,8 @@ class SparkContext(object):
         # Deploy any code dependencies specified in the constructor
         for path in (pyFiles or []):
             self.addPyFile(path)
+        SparkFiles._sc = self
+        sys.path.append(SparkFiles.getRootDirectory())

     @property
     def defaultParallelism(self):
@@ -85,17 +97,20 @@ class SparkContext(object):
         return self._jsc.sc().defaultParallelism()

     def __del__(self):
-        if self._jsc:
-            self._jsc.stop()
-        if self._accumulatorServer:
-            self._accumulatorServer.shutdown()
+        self.stop()

     def stop(self):
         """
         Shut down the SparkContext.
         """
-        self._jsc.stop()
-        self._jsc = None
+        if self._jsc:
+            self._jsc.stop()
+            self._jsc = None
+        if self._accumulatorServer:
+            self._accumulatorServer.shutdown()
+            self._accumulatorServer = None
+        with SparkContext._lock:
+            SparkContext._active_spark_context = None

     def parallelize(self, c, numSlices=None):
         """
...
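With this hunk, `stop()` is safe to call more than once (the guards reset `_jsc` and `_accumulatorServer` to `None`) and it clears `_active_spark_context`, so a fresh context can be created afterwards. A minimal sketch, assuming a local master and borrowing the port-clearing step from the test suite's `tearDown()`:

```python
from pyspark.context import SparkContext

sc = SparkContext("local", "StopExample")
sc.stop()
sc.stop()   # safe: _jsc and _accumulatorServer are already None

# As in tearDown() below, clear the master port so Akka can rebind immediately.
sc.jvm.System.clearProperty("spark.master.port")

sc = SparkContext("local", "StopExample")   # allowed again; the active slot was cleared
sc.stop()
```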
@@ -4,13 +4,15 @@ import os
 class SparkFiles(object):
     """
     Resolves paths to files added through
-    L{addFile()<pyspark.context.SparkContext.addFile>}.
+    L{SparkContext.addFile()<pyspark.context.SparkContext.addFile>}.

     SparkFiles contains only classmethods; users should not create SparkFiles
     instances.
     """

     _root_directory = None
+    _is_running_on_worker = False
+    _sc = None

     def __init__(self):
         raise NotImplementedError("Do not construct SparkFiles objects")
@@ -18,7 +20,19 @@ class SparkFiles(object):
     @classmethod
     def get(cls, filename):
         """
-        Get the absolute path of a file added through C{addFile()}.
+        Get the absolute path of a file added through C{SparkContext.addFile()}.
         """
-        path = os.path.join(SparkFiles._root_directory, filename)
+        path = os.path.join(SparkFiles.getRootDirectory(), filename)
         return os.path.abspath(path)
+
+    @classmethod
+    def getRootDirectory(cls):
+        """
+        Get the root directory that contains files added through
+        C{SparkContext.addFile()}.
+        """
+        if cls._is_running_on_worker:
+            return cls._root_directory
+        else:
+            # This will have to change if we support multiple SparkContexts:
+            return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
@@ -4,22 +4,26 @@ individual modules.
 """
 import os
 import shutil
+import sys
 from tempfile import NamedTemporaryFile
 import time
 import unittest

 from pyspark.context import SparkContext
+from pyspark.files import SparkFiles
 from pyspark.java_gateway import SPARK_HOME
 class PySparkTestCase(unittest.TestCase):

     def setUp(self):
+        self._old_sys_path = list(sys.path)
         class_name = self.__class__.__name__
         self.sc = SparkContext('local[4]', class_name , batchSize=2)

     def tearDown(self):
         self.sc.stop()
+        sys.path = self._old_sys_path
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
         self.sc.jvm.System.clearProperty("spark.master.port")
@@ -84,6 +88,25 @@ class TestAddFile(PySparkTestCase):
         res = self.sc.parallelize(range(2)).map(func).first()
         self.assertEqual("Hello World!", res)

+    def test_add_file_locally(self):
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        self.sc.addFile(path)
+        download_path = SparkFiles.get("hello.txt")
+        self.assertNotEqual(path, download_path)
+        with open(download_path) as test_file:
+            self.assertEquals("Hello World!\n", test_file.readline())
+
+    def test_add_py_file_locally(self):
+        # To ensure that we're actually testing addPyFile's effects, check that
+        # this fails due to `userlibrary` not being on the Python path:
+        def func():
+            from userlibrary import UserClass
+        self.assertRaises(ImportError, func)
+        path = os.path.join(SPARK_HOME, "python/test_support/userlibrary.py")
+        self.sc.addFile(path)
+        from userlibrary import UserClass
+        self.assertEqual("Hello World!", UserClass().hello())
+

 if __name__ == "__main__":
     unittest.main()
@@ -26,6 +26,7 @@ def main():
     split_index = read_int(sys.stdin)
     spark_files_dir = load_pickle(read_with_length(sys.stdin))
     SparkFiles._root_directory = spark_files_dir
+    SparkFiles._is_running_on_worker = True
     sys.path.append(spark_files_dir)
     num_broadcast_variables = read_int(sys.stdin)
     for _ in range(num_broadcast_variables):
...
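On the worker side (the `def main():` hunk above is from pyspark/worker.py), the driver sends the files directory over stdin; setting `_is_running_on_worker = True` is what keeps `SparkFiles.getRootDirectory()` on the local-directory branch instead of calling back into the JVM. A condensed sketch of that setup, where `read_files_dir` is a hypothetical stand-in for `load_pickle(read_with_length(sys.stdin))`:

```python
import sys
from pyspark.files import SparkFiles

def init_spark_files(read_files_dir):
    spark_files_dir = read_files_dir()
    SparkFiles._root_directory = spark_files_dir
    SparkFiles._is_running_on_worker = True   # keep getRootDirectory() on the local branch
    sys.path.append(spark_files_dir)          # makes addPyFile() modules importable in tasks
```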
Hello World!