diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index 7758d3e375e1d050a1703f8ba41641add25e8eab..988c81cd5d9acae5faddfc4727e2bf483e91024d 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -22,20 +22,54 @@ class SparkContext(object):
     readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
 
-    def __init__(self, master, name, defaultParallelism=None, batchSize=-1):
+    def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
+        environment=None, batchSize=1024):
+        """
+        Create a new SparkContext.
+
+        @param master: Cluster URL to connect to
+               (e.g. mesos://host:port, spark://host:port, local[4]).
+        @param jobName: A name for your job, to display on the cluster web UI
+        @param sparkHome: Location where Spark is installed on cluster nodes.
+        @param pyFiles: Collection of .zip or .py files to send to the cluster
+               and add to PYTHONPATH. These can be paths on the local file
+               system or HDFS, HTTP, HTTPS, or FTP URLs.
+        @param environment: A dictionary of environment variables to set on
+               worker nodes.
+        @param batchSize: The number of Python objects represented as a single
+               Java object. Set 1 to disable batching or -1 to use an
+               unlimited batch size.
+        """
         self.master = master
-        self.name = name
-        self._jsc = self.jvm.JavaSparkContext(master, name)
-        self.defaultParallelism = \
-            defaultParallelism or self._jsc.sc().defaultParallelism()
-        self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
+        self.jobName = jobName
+        self.sparkHome = sparkHome or None # None becomes null in Py4J
+        self.environment = environment or {}
         self.batchSize = batchSize  # -1 represents a unlimited batch size
+
+        # Create the Java SparkContext through Py4J
+        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
+        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+                                              empty_string_array)
+
+        self.pythonExec = os.environ.get("PYSPARK_PYTHON_EXEC", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
         self._pickled_broadcast_vars = set()
 
+        # Deploy any code dependencies specified in the constructor
+        for path in (pyFiles or []):
+            self.addPyFile(path)
+
+    @property
+    def defaultParallelism(self):
+        """
+        Default level of parallelism to use when not given by user (e.g. for
+        reduce tasks)
+        """
+        return self._jsc.sc().defaultParallelism()
+
     def __del__(self):
         if self._jsc:
             self._jsc.stop()
@@ -75,7 +109,7 @@ class SparkContext(object):
 
     def union(self, rdds):
         """
-        Build the union of a list of RDDs
+        Build the union of a list of RDDs.
         """
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
@@ -91,3 +125,32 @@ class SparkContext(object):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
         return Broadcast(jbroadcast.id(), value, jbroadcast,
                          self._pickled_broadcast_vars)
+
+    def addFile(self, path):
+        """
+        Add a file to be downloaded into the working directory of this Spark
+        job on every node. The C{path} passed can be either a local file,
+        a file in HDFS (or other Hadoop-supported filesystems), or an HTTP,
+        HTTPS or FTP URI.
+        """
+        self._jsc.sc().addFile(path)
+
+    def clearFiles(self):
+        """
+        Clear the job's list of files added by L{addFile} or L{addPyFile} so
+        that they do not get downloaded to any new nodes.
+        """
+        # TODO: remove added .py or .zip files from the PYTHONPATH?
+        self._jsc.sc().clearFiles()
+
+    def addPyFile(self, path):
+        """
+        Add a .py or .zip dependency for all tasks to be executed on this
+        SparkContext in the future. The C{path} passed can be either a local
+        file, a file in HDFS (or other Hadoop-supported filesystems), or an
+        HTTP, HTTPS or FTP URI.
+        """
+        self.addFile(path)
+        filename = path.split("/")[-1]
+        os.environ["PYTHONPATH"] = \
+            "%s:%s" % (filename, os.environ["PYTHONPATH"])
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 5af105ef62277b165d123ae10ea1fd286029661d..bf32472d2572e9b7f7e4f328f4e32de4d555835f 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,5 +1,6 @@
 import atexit
 from base64 import standard_b64encode as b64enc
+import copy
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
@@ -673,9 +674,9 @@ class PipelinedRDD(RDD):
             self.ctx.gateway._gateway_client)
         self.ctx._pickled_broadcast_vars.clear()
         class_manifest = self._prev_jrdd.classManifest()
-        env = MapConverter().convert(
-            {'PYTHONPATH' : os.environ.get("PYTHONPATH", "")},
-            self.ctx.gateway._gateway_client)
+        env = copy.copy(self.ctx.environment)
+        env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
+        env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
         python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
             pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
             broadcast_vars, class_manifest)
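
Usage note for reviewers: a minimal sketch of how the new constructor arguments and dependency-shipping methods might be used from a driver script. The file paths, sparkHome location, and environment values below are hypothetical placeholders; only the SparkContext(master, jobName, sparkHome, pyFiles, environment, batchSize) signature, addFile, addPyFile, clearFiles, and the defaultParallelism property come from this diff, and the import path assumes the pyspark/pyspark package layout shown above.

    from pyspark.context import SparkContext

    # Hypothetical paths and values; the keyword arguments match the new signature.
    sc = SparkContext("local[4]", "DepsExample",
                      sparkHome="/opt/spark",             # hypothetical install location on workers
                      pyFiles=["helpers.py", "libs.zip"], # shipped to the cluster and added to PYTHONPATH
                      environment={"LOG_LEVEL": "INFO"},  # forwarded to worker processes via PythonRDD's env map
                      batchSize=1024)

    # Extra dependencies can also be registered after the context is created
    sc.addFile("lookup.txt")        # plain data file, downloaded to each node's working directory
    sc.addPyFile("extra_udfs.py")   # code dependency, also prepended to PYTHONPATH

    # defaultParallelism is now a read-only property backed by the JavaSparkContext
    print(sc.defaultParallelism)

    # Stop shipping the registered files to any new nodes
    sc.clearFiles()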