diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 1e40b9c39fc4facdae512ff11a3749b0c177f45e..9f4772eec9f2ac3b88f725c51f87d77c8f684829 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -161,8 +161,8 @@ class SparkSession(object):
             with self._lock:
                 from pyspark.context import SparkContext
                 from pyspark.conf import SparkConf
-                session = SparkSession._instantiatedContext
-                if session is None:
+                session = SparkSession._instantiatedSession
+                if session is None or session._sc._jsc is None:
                     sparkConf = SparkConf()
                     for key, value in self._options.items():
                         sparkConf.set(key, value)
@@ -183,7 +183,7 @@ class SparkSession(object):
 
     builder = Builder()
 
-    _instantiatedContext = None
+    _instantiatedSession = None
 
     @ignore_unicode_prefix
    def __init__(self, sparkContext, jsparkSession=None):
@@ -214,8 +214,12 @@ class SparkSession(object):
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
         _monkey_patch_RDD(self)
         install_exception_handler()
-        if SparkSession._instantiatedContext is None:
-            SparkSession._instantiatedContext = self
+        # If the cached instantiated SparkSession is attached to a SparkContext
+        # that has since been stopped, replace it with this new session. Otherwise
+        # Builder.getOrCreate would keep returning an invalid SparkSession.
+        if SparkSession._instantiatedSession is None \
+                or SparkSession._instantiatedSession._sc._jsc is None:
+            SparkSession._instantiatedSession = self
 
     @since(2.0)
     def newSession(self):
@@ -595,7 +599,7 @@ class SparkSession(object):
         """Stop the underlying :class:`SparkContext`.
         """
         self._sc.stop()
-        SparkSession._instantiatedContext = None
+        SparkSession._instantiatedSession = None
 
     @since(2.0)
     def __enter__(self):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6de63e649325c327754de8b2d46352e25f535636..fe034bc0a4a76287ac797d5b550b49c961442329 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -46,6 +46,7 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest
 
+from pyspark import SparkContext
 from pyspark.sql import SparkSession, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
@@ -1877,6 +1878,28 @@ class HiveSparkSubmitTests(SparkSubmitTests):
         self.assertTrue(os.path.exists(metastore_path))
 
 
+class SQLTests2(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+    # This test can't live in SQLTests because it stops the class-level
+    # SparkContext, which would make the other tests fail.
+    def test_sparksession_with_stopped_sparkcontext(self):
+        self.sc.stop()
+        sc = SparkContext('local[4]', self.sc.appName)
+        spark = SparkSession.builder.getOrCreate()
+        df = spark.createDataFrame([(1, 2)], ["c", "c"])
+        df.collect()
+
+
 class HiveContextSQLTests(ReusedPySparkTestCase):
 
     @classmethod
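
For context, here is a minimal standalone sketch of the failure mode this patch addresses, mirroring the new regression test above. It assumes a local PySpark installation; the master URL and app name are illustrative.

from pyspark import SparkContext
from pyspark.sql import SparkSession

sc = SparkContext('local[4]', 'stopped-context-demo')
spark = SparkSession(sc)   # cached in SparkSession._instantiatedSession

sc.stop()                  # the cached session now wraps a dead SparkContext
sc = SparkContext('local[4]', 'stopped-context-demo')

# Before this patch, getOrCreate() returned the stale cached session, so any
# DataFrame operation failed. With the `_sc._jsc is None` check, a fresh
# SparkSession is built against the live SparkContext instead.
spark = SparkSession.builder.getOrCreate()
spark.createDataFrame([(1, 2)], ['a', 'b']).collect()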