Commit ac32447c authored by Josh Rosen

Use addFile() to ship code to cluster in PySpark.

Add options to pyspark.SparkContext constructor.
parent 85b8f2c6
pyspark/context.py
@@ -22,20 +22,54 @@ class SparkContext(object):
     readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
 
-    def __init__(self, master, name, defaultParallelism=None, batchSize=-1):
+    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
+        environment=None, batchSize=1024):
+        """
+        Create a new SparkContext.
+
+        @param master: Cluster URL to connect to
+               (e.g. mesos://host:port, spark://host:port, local[4]).
+        @param jobName: A name for your job, to display on the cluster web UI.
+        @param sparkHome: Location where Spark is installed on cluster nodes.
+        @param pyFiles: Collection of .zip or .py files to send to the cluster
+               and add to PYTHONPATH.  These can be paths on the local file
+               system or HDFS, HTTP, HTTPS, or FTP URLs.
+        @param environment: A dictionary of environment variables to set on
+               worker nodes.
+        @param batchSize: The number of Python objects represented as a single
+               Java object.  Set to 1 to disable batching or -1 to use an
+               unlimited batch size.
+        """
         self.master = master
-        self.name = name
-        self._jsc = self.jvm.JavaSparkContext(master, name)
-        self.defaultParallelism = \
-            defaultParallelism or self._jsc.sc().defaultParallelism()
-        self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
+        self.jobName = jobName
+        self.sparkHome = sparkHome or None  # None becomes null in Py4J
+        self.environment = environment or {}
+        self.batchSize = batchSize  # -1 represents an unlimited batch size
+
+        # Create the Java SparkContext through Py4J
+        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
+        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+                                              empty_string_array)
+
+        self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
         self._pickled_broadcast_vars = set()
 
+        # Deploy any code dependencies specified in the constructor
+        for path in (pyFiles or []):
+            self.addPyFile(path)
+
+    @property
+    def defaultParallelism(self):
+        """
+        Default level of parallelism to use when not given by user (e.g. for
+        reduce tasks).
+        """
+        return self._jsc.sc().defaultParallelism()
+
     def __del__(self):
         if self._jsc:
             self._jsc.stop()
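
For context, a minimal sketch of the new constructor in use; the master URL, job name, file name, and environment variable below are illustrative values, not part of the commit:

    from pyspark.context import SparkContext

    sc = SparkContext("local[4]", "ExampleJob",
                      pyFiles=["deps.zip"],         # shipped to workers, added to PYTHONPATH
                      environment={"MY_VAR": "1"},  # set on worker nodes
                      batchSize=1)                  # 1 disables batching
    print sc.defaultParallelism                     # now a read-only property backed by the JVM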
@@ -75,7 +109,7 @@ class SparkContext(object):
     def union(self, rdds):
         """
-        Build the union of a list of RDDs
+        Build the union of a list of RDDs.
         """
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
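
For reference, union() can be exercised like this (assuming the sc from the sketch above):

    rdd1 = sc.parallelize([1, 2])
    rdd2 = sc.parallelize([3, 4])
    print sc.union([rdd1, rdd2]).collect()  # [1, 2, 3, 4]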
@@ -91,3 +125,32 @@ class SparkContext(object):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
+
+    def addFile(self, path):
+        """
+        Add a file to be downloaded into the working directory of this Spark
+        job on every node.  The C{path} passed can be either a local file,
+        a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
+        HTTPS or FTP URI.
+        """
+        self._jsc.sc().addFile(path)
+
+    def clearFiles(self):
+        """
+        Clear the job's list of files added by L{addFile} or L{addPyFile} so
+        that they do not get downloaded to any new nodes.
+        """
+        # TODO: remove added .py or .zip files from the PYTHONPATH?
+        self._jsc.sc().clearFiles()
+
+    def addPyFile(self, path):
+        """
+        Add a .py or .zip dependency for all tasks to be executed on this
+        SparkContext in the future.  The C{path} passed can be either a local
+        file, a file in HDFS (or other Hadoop-supported filesystems), or an
+        HTTP, HTTPS or FTP URI.
+        """
+        self.addFile(path)
+        filename = path.split("/")[-1]
+        os.environ["PYTHONPATH"] = \
+            "%s:%s" % (filename, os.environ["PYTHONPATH"])
pyspark/rdd.py
@@ -1,5 +1,6 @@
 import atexit
 from base64 import standard_b64encode as b64enc
+import copy
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
@@ -673,9 +674,9 @@ class PipelinedRDD(RDD):
             self.ctx.gateway._gateway_client)
         self.ctx._pickled_broadcast_vars.clear()
         class_manifest = self._prev_jrdd.classManifest()
-        env = MapConverter().convert(
-            {'PYTHONPATH' : os.environ.get("PYTHONPATH", "")},
-            self.ctx.gateway._gateway_client)
+        env = copy.copy(self.ctx.environment)
+        env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
+        env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
         python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
             pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
             broadcast_vars, class_manifest)
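
A short sketch of what this change buys: entries passed as SparkContext(environment=...) now reach the Python workers through this env map, so tasks can read them (MY_VAR is an arbitrary example variable):

    import os
    from pyspark.context import SparkContext

    sc = SparkContext("local", "EnvDemo", environment={"MY_VAR": "42"})
    print sc.parallelize([1, 2]).map(
        lambda x: os.environ.get("MY_VAR")).collect()  # expected: ['42', '42']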