diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 4969d85f52b2375bd0a53f7637698c3b9e0c3b32..afd74d937a41324c3a1eb16e3b37b5c4fe51efaf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,7 +21,7 @@ import os
 import shutil
 import signal
 import sys
-from threading import Lock
+from threading import RLock
 from tempfile import NamedTemporaryFile
 
 from pyspark import accumulators
@@ -65,7 +65,7 @@ class SparkContext(object):
     _jvm = None
     _next_accum_id = 0
     _active_spark_context = None
-    _lock = Lock()
+    _lock = RLock()
     _python_includes = None  # zip and egg files that need to be added to PYTHONPATH
 
     PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
@@ -280,6 +280,18 @@ class SparkContext(object):
         """
         self.stop()
 
+    @classmethod
+    def getOrCreate(cls, conf=None):
+        """
+        Get or instantiate a SparkContext and register it as a singleton object.
+
+        :param conf: SparkConf (optional)
+        """
+        with SparkContext._lock:
+            if SparkContext._active_spark_context is None:
+                SparkContext(conf=conf or SparkConf())
+            return SparkContext._active_spark_context
+
     def setLogLevel(self, logLevel):
         """
         Control our logLevel. This overrides any user-defined log settings.
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 89c8c6e0d94f1ccc1140a2f1b7bb399e035a540e..79453658a167a04ae171d9ba78119cbe2644c875 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -75,6 +75,8 @@ class SQLContext(object):
     SQLContext in the JVM, instead we make all calls to this object.
     """
 
+    _instantiatedContext = None
+
    @ignore_unicode_prefix
    def __init__(self, sparkContext, sqlContext=None):
        """Creates a new SQLContext.
@@ -99,6 +101,8 @@ class SQLContext(object):
         self._scala_SQLContext = sqlContext
         _monkey_patch_RDD(self)
         install_exception_handler()
+        if SQLContext._instantiatedContext is None:
+            SQLContext._instantiatedContext = self
 
     @property
     def _ssql_ctx(self):
@@ -111,6 +115,29 @@ class SQLContext(object):
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
+    @classmethod
+    @since(1.6)
+    def getOrCreate(cls, sc):
+        """
+        Get the existing SQLContext or create a new one with the given SparkContext.
+
+        :param sc: SparkContext
+        """
+        if cls._instantiatedContext is None:
+            jsqlContext = sc._jvm.SQLContext.getOrCreate(sc._jsc.sc())
+            cls(sc, jsqlContext)
+        return cls._instantiatedContext
+
+    @since(1.6)
+    def newSession(self):
+        """
+        Returns a new SQLContext as a new session, with separate SQLConf,
+        registered temporary tables and UDFs, but a shared SparkContext and
+        table cache.
+        """
+        jsqlContext = self._ssql_ctx.newSession()
+        return self.__class__(self._sc, jsqlContext)
+
     @since(1.3)
     def setConf(self, key, value):
         """Sets the given Spark SQL configuration property.
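A note on the `Lock` → `RLock` change above: `SparkContext.__init__` goes through `_ensure_initialized()`, which itself acquires the class-level `SparkContext._lock`. Since the new `getOrCreate` holds that lock while it constructs the context, the same thread acquires the lock twice, which deadlocks with a plain `Lock` but is fine with a reentrant `RLock`. A minimal sketch of that pattern (toy names, not the actual Spark code):

```python
from threading import RLock

_lock = RLock()  # with threading.Lock() this example would hang

def _ensure_initialized():
    # Stand-in for SparkContext._ensure_initialized, which also
    # takes the class-level lock.
    with _lock:
        pass

def get_or_create():
    # Stand-in for SparkContext.getOrCreate: the lock is already
    # held when the constructor path re-acquires it.
    with _lock:
        _ensure_initialized()

get_or_create()  # succeeds only because RLock is reentrant
```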
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 645133b2b2d847681ab6ba80e648c88522b1efeb..f465e1fa209419478f10b6d7944a008e0f1f72ba 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -174,6 +174,20 @@ class DataTypeTests(unittest.TestCase):
         self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
 
 
+class SQLContextTests(ReusedPySparkTestCase):
+    def test_get_or_create(self):
+        sqlCtx = SQLContext.getOrCreate(self.sc)
+        self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx)
+
+    def test_new_session(self):
+        sqlCtx = SQLContext.getOrCreate(self.sc)
+        sqlCtx.setConf("test_key", "a")
+        sqlCtx2 = sqlCtx.newSession()
+        sqlCtx2.setConf("test_key", "b")
+        self.assertEqual(sqlCtx.getConf("test_key", ""), "a")
+        self.assertEqual(sqlCtx2.getConf("test_key", ""), "b")
+
+
 class SQLTests(ReusedPySparkTestCase):
 
     @classmethod
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 63cc87e0c4b2c2835a1ee43fa9a73d6f243db435..3c51809444401aa73b8f1fa84139821d97deac44 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1883,6 +1883,10 @@ class ContextTests(unittest.TestCase):
         # Regression test for SPARK-1550
         self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name"))
 
+    def test_get_or_create(self):
+        with SparkContext.getOrCreate() as sc:
+            self.assertTrue(SparkContext.getOrCreate() is sc)
+
     def test_stop(self):
         sc = SparkContext()
         self.assertNotEqual(SparkContext._active_spark_context, None)
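For reference, a minimal sketch of how the new Python API is meant to be used, based on the tests above (the app name and config key are placeholders):

```python
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext

# Returns the active SparkContext if one exists, otherwise creates one.
sc = SparkContext.getOrCreate(SparkConf().setAppName("demo"))

# Singleton per Python process: repeated calls return the same object.
sqlCtx = SQLContext.getOrCreate(sc)
assert SQLContext.getOrCreate(sc) is sqlCtx

# A new session shares the SparkContext and table cache, but keeps its
# own SQLConf, temporary tables, and registered UDFs.
session2 = sqlCtx.newSession()
session2.setConf("spark.sql.shuffle.partitions", "4")
```

Note that the singleton is tracked through the `_instantiatedContext` class attribute, which is only set when it is `None`; constructing further `SQLContext` instances directly does not replace the context that `getOrCreate` returns.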