diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index de4c335ad2752a6891ed4e04cd42f3eefcad7140..c22f4b87e1a785db498ccafd27c9977b599f9390 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -73,7 +73,7 @@ class SQLContext(object):
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         if sparkSession is None:
-            sparkSession = SparkSession(sparkContext)
+            sparkSession = SparkSession.builder.getOrCreate()
         if jsqlContext is None:
             jsqlContext = sparkSession._jwrapped
         self.sparkSession = sparkSession
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d1782857e6d0defb488386689c2efd5bae32d0ba..a8250281dab351b827d18cbf1738e7fc4f06477f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -47,7 +47,7 @@ else:
     import unittest
 
 from pyspark import SparkContext
-from pyspark.sql import SparkSession, HiveContext, Column, Row
+from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
@@ -206,6 +206,11 @@ class SQLTests(ReusedPySparkTestCase):
         cls.spark.stop()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_sqlcontext_reuses_sparksession(self):
+        sqlContext1 = SQLContext(self.sc)
+        sqlContext2 = SQLContext(self.sc)
+        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
+
     def test_row_should_be_read_only(self):
         row = Row(a=1, b=2)
         self.assertEqual(1, row.a)
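
A minimal sketch (not part of the patch) of the behavior the new test asserts, assuming an active PySpark installation with the patched modules: because SQLContext.__init__ now calls SparkSession.builder.getOrCreate() instead of constructing a fresh SparkSession, every SQLContext built from the same SparkContext shares one SparkSession.

    from pyspark import SparkContext
    from pyspark.sql import SparkSession, SQLContext

    # Illustrative only; getOrCreate() returns the existing context/session
    # if one is already active, otherwise it creates one.
    sc = SparkContext.getOrCreate()
    spark = SparkSession.builder.getOrCreate()

    sqlContext1 = SQLContext(sc)
    sqlContext2 = SQLContext(sc)

    # Both wrappers reuse the session returned by builder.getOrCreate(),
    # so the underlying SparkSession object is identical.
    assert sqlContext1.sparkSession is sqlContext2.sparkSession
    # It is the already-active session rather than a newly constructed one.
    assert sqlContext1.sparkSession is spark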