diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index 924da3eecf214cc5abb45642f339dc1bf927ea07..64b6f238e9c32cd36d1bcdda8fc5ba0a84b7e59e 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -52,6 +52,14 @@ spark.home=/path >>> sorted(conf.getAll(), key=lambda p: p[0]) [(u'spark.executorEnv.VAR1', u'value1'), (u'spark.executorEnv.VAR3', u'value3'), \ (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] +>>> conf._jconf.setExecutorEnv("VAR5", "value5") +JavaObject id... +>>> print(conf.toDebugString()) +spark.executorEnv.VAR1=value1 +spark.executorEnv.VAR3=value3 +spark.executorEnv.VAR4=value4 +spark.executorEnv.VAR5=value5 +spark.home=/path """ __all__ = ['SparkConf'] @@ -101,13 +109,24 @@ class SparkConf(object): self._jconf = _jconf else: from pyspark.context import SparkContext - SparkContext._ensure_initialized() _jvm = _jvm or SparkContext._jvm - self._jconf = _jvm.SparkConf(loadDefaults) + + if _jvm is not None: + # JVM is created, so create self._jconf directly through JVM + self._jconf = _jvm.SparkConf(loadDefaults) + self._conf = None + else: + # JVM is not created, so store data in self._conf first + self._jconf = None + self._conf = {} def set(self, key, value): """Set a configuration property.""" - self._jconf.set(key, unicode(value)) + # Try to set self._jconf first if JVM is created, set self._conf if JVM is not created yet. + if self._jconf is not None: + self._jconf.set(key, unicode(value)) + else: + self._conf[key] = unicode(value) return self def setIfMissing(self, key, value): @@ -118,17 +137,17 @@ class SparkConf(object): def setMaster(self, value): """Set master URL to connect to.""" - self._jconf.setMaster(value) + self.set("spark.master", value) return self def setAppName(self, value): """Set application name.""" - self._jconf.setAppName(value) + self.set("spark.app.name", value) return self def setSparkHome(self, value): """Set path where Spark is installed on worker nodes.""" - self._jconf.setSparkHome(value) + self.set("spark.home", value) return self def setExecutorEnv(self, key=None, value=None, pairs=None): @@ -136,10 +155,10 @@ class SparkConf(object): if (key is not None and pairs is not None) or (key is None and pairs is None): raise Exception("Either pass one key-value pair or a list of pairs") elif key is not None: - self._jconf.setExecutorEnv(key, value) + self.set("spark.executorEnv." + key, value) elif pairs is not None: for (k, v) in pairs: - self._jconf.setExecutorEnv(k, v) + self.set("spark.executorEnv." + k, v) return self def setAll(self, pairs): @@ -149,35 +168,49 @@ class SparkConf(object): :param pairs: list of key-value pairs to set """ for (k, v) in pairs: - self._jconf.set(k, v) + self.set(k, v) return self def get(self, key, defaultValue=None): """Get the configured value for some key, or return a default otherwise.""" if defaultValue is None: # Py4J doesn't call the right get() if we pass None - if not self._jconf.contains(key): - return None - return self._jconf.get(key) + if self._jconf is not None: + if not self._jconf.contains(key): + return None + return self._jconf.get(key) + else: + if key not in self._conf: + return None + return self._conf[key] else: - return self._jconf.get(key, defaultValue) + if self._jconf is not None: + return self._jconf.get(key, defaultValue) + else: + return self._conf.get(key, defaultValue) def getAll(self): """Get all values as a list of key-value pairs.""" - pairs = [] - for elem in self._jconf.getAll(): - pairs.append((elem._1(), elem._2())) - return pairs + if self._jconf is not None: + return [(elem._1(), elem._2()) for elem in self._jconf.getAll()] + else: + return self._conf.items() def contains(self, key): """Does this configuration contain a given key?""" - return self._jconf.contains(key) + if self._jconf is not None: + return self._jconf.contains(key) + else: + return key in self._conf def toDebugString(self): """ Returns a printable version of the configuration, as a list of key=value pairs, one per line. """ - return self._jconf.toDebugString() + if self._jconf is not None: + return self._jconf.toDebugString() + else: + return '\n'.join('%s=%s' % (k, v) for k, v in self._conf.items()) def _test(): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a3dd1950a522f2c96539ba30c245546a364d35a5..1b2e199c395be889ec39c2b07f0dc692ee2e8eb7 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -109,7 +109,7 @@ class SparkContext(object): ValueError:... """ self._callsite = first_spark_call() or CallSite(None, None, None) - SparkContext._ensure_initialized(self, gateway=gateway) + SparkContext._ensure_initialized(self, gateway=gateway, conf=conf) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, conf, jsc, profiler_cls) @@ -121,7 +121,15 @@ class SparkContext(object): def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, conf, jsc, profiler_cls): self.environment = environment or {} - self._conf = conf or SparkConf(_jvm=self._jvm) + # java gateway must have been launched at this point. + if conf is not None and conf._jconf is not None: + # conf has been initialized in JVM properly, so use conf directly. This represent the + # scenario that JVM has been launched before SparkConf is created (e.g. SparkContext is + # created and then stopped, and we create a new SparkConf and new SparkContext again) + self._conf = conf + else: + self._conf = SparkConf(_jvm=SparkContext._jvm) + self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer if batchSize == 0: @@ -232,14 +240,14 @@ class SparkContext(object): return self._jvm.JavaSparkContext(jconf) @classmethod - def _ensure_initialized(cls, instance=None, gateway=None): + def _ensure_initialized(cls, instance=None, gateway=None, conf=None): """ Checks whether a SparkContext is initialized or not. Throws error if a SparkContext is already running. """ with SparkContext._lock: if not SparkContext._gateway: - SparkContext._gateway = gateway or launch_gateway() + SparkContext._gateway = gateway or launch_gateway(conf) SparkContext._jvm = SparkContext._gateway.jvm if instance: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index f76cadcf624382328b055dedd6937056caa06908..c1cf843d84388584ec064fd25380ea4f63c09f86 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -32,7 +32,12 @@ from py4j.java_gateway import java_import, JavaGateway, GatewayClient from pyspark.serializers import read_int -def launch_gateway(): +def launch_gateway(conf=None): + """ + launch jvm gateway + :param conf: spark configuration passed to spark-submit + :return: + """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) else: @@ -41,13 +46,17 @@ def launch_gateway(): # proper classpath and settings from spark-env.sh on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" + command = [os.path.join(SPARK_HOME, script)] + if conf: + for k, v in conf.getAll(): + command += ['--conf', '%s=%s' % (k, v)] submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): submit_args = ' '.join([ "--conf spark.ui.enabled=false", submit_args ]) - command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) + command = command + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)