diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index ad1631fb5baa50f36d69e53983a5f87c513ab7e1..49d3a4a332fd19390dce1d67eed2b076b7ea8ef4 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -57,13 +57,25 @@ from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams
 from pyspark.mllib.common import _java2py
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
-from pyspark.sql import DataFrame, SQLContext, Row
+from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
 from pyspark.sql.utils import IllegalArgumentException
 from pyspark.storagelevel import *
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 
 
+class SparkSessionTestCase(PySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        PySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        PySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+
 class MockDataset(DataFrame):
 
     def __init__(self):
@@ -350,7 +362,7 @@ class ParamTests(PySparkTestCase):
         self.assertEqual(model.getWindowSize(), 6)
 
 
-class FeatureTests(PySparkTestCase):
+class FeatureTests(SparkSessionTestCase):
 
     def test_binarizer(self):
         b0 = Binarizer()
@@ -376,8 +388,7 @@ class FeatureTests(PySparkTestCase):
         self.assertEqual(b1.getOutputCol(), "output")
 
     def test_idf(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (DenseVector([1.0, 2.0]),),
             (DenseVector([0.0, 1.0]),),
             (DenseVector([3.0, 0.2]),)], ["tf"])
@@ -390,8 +401,7 @@ class FeatureTests(PySparkTestCase):
         self.assertIsNotNone(output.head().idf)
 
     def test_ngram(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             Row(input=["a", "b", "c", "d", "e"])])
         ngram0 = NGram(n=4, inputCol="input", outputCol="output")
         self.assertEqual(ngram0.getN(), 4)
@@ -401,8 +411,7 @@ class FeatureTests(PySparkTestCase):
         self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"])
 
     def test_stopwordsremover(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])])
+        dataset = self.spark.createDataFrame([Row(input=["a", "panda"])])
         stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output")
         # Default
         self.assertEqual(stopWordRemover.getInputCol(), "input")
@@ -419,15 +428,14 @@ class FeatureTests(PySparkTestCase):
         self.assertEqual(transformedDF.head().output, ["a"])
         # with language selection
         stopwords = StopWordsRemover.loadDefaultStopWords("turkish")
-        dataset = sqlContext.createDataFrame([Row(input=["acaba", "ama", "biri"])])
+        dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])])
         stopWordRemover.setStopWords(stopwords)
         self.assertEqual(stopWordRemover.getStopWords(), stopwords)
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, [])
 
     def test_count_vectorizer_with_binary(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
             (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
             (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
@@ -475,11 +483,10 @@ class InducedErrorEstimator(Estimator, HasInducedError):
         return model
 
 
-class CrossValidatorTests(PySparkTestCase):
+class CrossValidatorTests(SparkSessionTestCase):
 
     def test_copy(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (10, 10.0),
             (50, 50.0),
             (100, 100.0),
@@ -503,8 +510,7 @@ class CrossValidatorTests(PySparkTestCase):
                             < 0.0001)
 
     def test_fit_minimize_metric(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (10, 10.0),
             (50, 50.0),
             (100, 100.0),
@@ -527,8 +533,7 @@ class CrossValidatorTests(PySparkTestCase):
         self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
 
     def test_fit_maximize_metric(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (10, 10.0),
             (50, 50.0),
             (100, 100.0),
@@ -554,8 +559,7 @@ class CrossValidatorTests(PySparkTestCase):
         # This tests saving and loading the trained model only.
         # Save/load for CrossValidator will be added later: SPARK-13786
         temp_path = tempfile.mkdtemp()
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame(
+        dataset = self.spark.createDataFrame(
             [(Vectors.dense([0.0]), 0.0),
              (Vectors.dense([0.4]), 1.0),
              (Vectors.dense([0.5]), 0.0),
@@ -576,11 +580,10 @@ class CrossValidatorTests(PySparkTestCase):
         self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
 
 
-class TrainValidationSplitTests(PySparkTestCase):
+class TrainValidationSplitTests(SparkSessionTestCase):
 
     def test_fit_minimize_metric(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (10, 10.0),
             (50, 50.0),
             (100, 100.0),
@@ -603,8 +606,7 @@ class TrainValidationSplitTests(PySparkTestCase):
         self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0")
 
     def test_fit_maximize_metric(self):
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame([
+        dataset = self.spark.createDataFrame([
             (10, 10.0),
             (50, 50.0),
             (100, 100.0),
@@ -630,8 +632,7 @@ class TrainValidationSplitTests(PySparkTestCase):
         # This tests saving and loading the trained model only.
         # Save/load for TrainValidationSplit will be added later: SPARK-13786
         temp_path = tempfile.mkdtemp()
-        sqlContext = SQLContext(self.sc)
-        dataset = sqlContext.createDataFrame(
+        dataset = self.spark.createDataFrame(
             [(Vectors.dense([0.0]), 0.0),
              (Vectors.dense([0.4]), 1.0),
              (Vectors.dense([0.5]), 0.0),
@@ -652,7 +653,7 @@ class TrainValidationSplitTests(PySparkTestCase):
         self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
 
 
-class PersistenceTest(PySparkTestCase):
+class PersistenceTest(SparkSessionTestCase):
 
     def test_linear_regression(self):
         lr = LinearRegression(maxIter=1)
@@ -724,11 +725,10 @@ class PersistenceTest(PySparkTestCase):
         """
         Pipeline[HashingTF, PCA]
         """
-        sqlContext = SQLContext(self.sc)
         temp_path = tempfile.mkdtemp()
 
         try:
-            df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+            df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
             tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
             pca = PCA(k=2, inputCol="features", outputCol="pca_features")
             pl = Pipeline(stages=[tf, pca])
@@ -753,11 +753,10 @@ class PersistenceTest(PySparkTestCase):
         """
         Pipeline[HashingTF, Pipeline[PCA]]
         """
-        sqlContext = SQLContext(self.sc)
         temp_path = tempfile.mkdtemp()
 
         try:
-            df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+            df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
             tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
             pca = PCA(k=2, inputCol="features", outputCol="pca_features")
             p0 = Pipeline(stages=[pca])
@@ -816,7 +815,7 @@ class PersistenceTest(PySparkTestCase):
             pass
 
 
-class LDATest(PySparkTestCase):
+class LDATest(SparkSessionTestCase):
 
     def _compare(self, m1, m2):
         """
@@ -836,8 +835,7 @@ class LDATest(PySparkTestCase):
 
     def test_persistence(self):
         # Test save/load for LDA, LocalLDAModel, DistributedLDAModel.
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame([
+        df = self.spark.createDataFrame([
             [1, Vectors.dense([0.0, 1.0])],
             [2, Vectors.sparse(2, {0: 1.0})],
         ], ["id", "features"])
@@ -871,12 +869,11 @@ class LDATest(PySparkTestCase):
             pass
 
 
-class TrainingSummaryTest(PySparkTestCase):
+class TrainingSummaryTest(SparkSessionTestCase):
 
     def test_linear_regression_summary(self):
         from pyspark.mllib.linalg import Vectors
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                          (0.0, 2.0, Vectors.sparse(1, [], []))],
                                         ["label", "weight", "features"])
         lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight",
@@ -914,8 +911,7 @@ class TrainingSummaryTest(PySparkTestCase):
 
     def test_logistic_regression_summary(self):
         from pyspark.mllib.linalg import Vectors
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
+        df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)),
                                          (0.0, 2.0, Vectors.sparse(1, [], []))],
                                         ["label", "weight", "features"])
         lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False)
@@ -942,11 +938,10 @@ class TrainingSummaryTest(PySparkTestCase):
         self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
 
 
-class OneVsRestTests(PySparkTestCase):
+class OneVsRestTests(SparkSessionTestCase):
 
     def test_copy(self):
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
                                          (1.0, Vectors.sparse(2, [], [])),
                                          (2.0, Vectors.dense(0.5, 0.5))],
                                         ["label", "features"])
@@ -960,8 +955,7 @@ class OneVsRestTests(PySparkTestCase):
         self.assertEqual(model1.getPredictionCol(), "indexed")
 
     def test_output_columns(self):
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
                                          (1.0, Vectors.sparse(2, [], [])),
                                          (2.0, Vectors.dense(0.5, 0.5))],
                                         ["label", "features"])
@@ -973,8 +967,7 @@ class OneVsRestTests(PySparkTestCase):
 
     def test_save_load(self):
         temp_path = tempfile.mkdtemp()
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+        df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
                                          (1.0, Vectors.sparse(2, [], [])),
                                          (2.0, Vectors.dense(0.5, 0.5))],
                                         ["label", "features"])
@@ -994,12 +987,11 @@ class OneVsRestTests(PySparkTestCase):
             self.assertEqual(m.uid, n.uid)
 
 
-class HashingTFTest(PySparkTestCase):
+class HashingTFTest(SparkSessionTestCase):
 
     def test_apply_binary_term_freqs(self):
-        sqlContext = SQLContext(self.sc)
 
-        df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
+        df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
         n = 10
         hashingTF = HashingTF()
         hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
@@ -1011,11 +1003,10 @@ class HashingTFTest(PySparkTestCase):
                                    ": expected " + str(expected[i]) + ", got " + str(features[i]))
 
 
-class ALSTest(PySparkTestCase):
+class ALSTest(SparkSessionTestCase):
 
     def test_storage_levels(self):
-        sqlContext = SQLContext(self.sc)
-        df = sqlContext.createDataFrame(
+        df = self.spark.createDataFrame(
             [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
             ["user", "item", "rating"])
         als = ALS().setMaxIter(1).setRank(1)
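
The ml/tests.py hunks above replace per-test SQLContext(self.sc) construction with a single SparkSession that the new SparkSessionTestCase base class builds in setUpClass. A minimal sketch of a test written against that base class, assuming the class exactly as defined in this patch (the test class, threshold, and data below are illustrative and not part of the change):

    # Sketch only: SparkSessionTestCase from the hunk above provides
    # cls.sc and cls.spark shared by every test in the class.
    from pyspark.ml.feature import Binarizer

    class BinarizerSmokeTest(SparkSessionTestCase):
        def test_transform(self):
            df = self.spark.createDataFrame([(0.1,), (0.8,)], ["feature"])
            binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized")
            out = binarizer.transform(df)
            self.assertEqual([r.binarized for r in out.collect()], [0.0, 1.0])

Sharing one session per test class avoids re-creating a SQLContext in every test method and matches the SparkSession-first API these tests migrate to.
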
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 53a1d2c59cb2f5a346a43746e43174a407f67428..74cf7bb8eaf9d423754fcab1a5b4961cac6282a5 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -66,7 +66,8 @@ from pyspark.mllib.util import LinearDataGenerator
 from pyspark.mllib.util import MLUtils
 from pyspark.serializers import PickleSerializer
 from pyspark.streaming import StreamingContext
-from pyspark.sql import SQLContext
+from pyspark.sql import SparkSession
+from pyspark.sql.utils import IllegalArgumentException
 from pyspark.streaming import StreamingContext
 
 _have_scipy = False
@@ -83,9 +84,10 @@ ser = PickleSerializer()
 class MLlibTestCase(unittest.TestCase):
     def setUp(self):
         self.sc = SparkContext('local[4]', "MLlib tests")
+        self.spark = SparkSession(self.sc)
 
     def tearDown(self):
-        self.sc.stop()
+        self.spark.stop()
 
 
 class MLLibStreamingTestCase(unittest.TestCase):
@@ -698,7 +700,6 @@ class VectorUDTTests(MLlibTestCase):
             self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
 
     def test_infer_schema(self):
-        sqlCtx = SQLContext(self.sc)
         rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
         df = rdd.toDF()
         schema = df.schema
@@ -731,7 +732,6 @@ class MatrixUDTTests(MLlibTestCase):
             self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
 
     def test_infer_schema(self):
-        sqlCtx = SQLContext(self.sc)
         rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
         df = rdd.toDF()
         schema = df.schema
@@ -919,7 +919,7 @@ class ChiSqTestTests(MLlibTestCase):
 
         # Negative counts in observed
         neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
-        self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_obs, expected1)
+        self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1)
 
         # Count = 0.0 in expected but not observed
         zero_expected = Vectors.dense([1.0, 0.0, 3.0])
@@ -930,7 +930,8 @@ class ChiSqTestTests(MLlibTestCase):
 
         # 0.0 in expected and observed simultaneously
         zero_observed = Vectors.dense([2.0, 0.0, 1.0])
-        self.assertRaises(Py4JJavaError, Statistics.chiSqTest, zero_observed, zero_expected)
+        self.assertRaises(
+            IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected)
 
     def test_matrix_independence(self):
         data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
@@ -944,15 +945,15 @@ class ChiSqTestTests(MLlibTestCase):
 
         # Negative counts
         neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0])
-        self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_counts)
+        self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts)
 
         # Row sum = 0.0
         row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])
-        self.assertRaises(Py4JJavaError, Statistics.chiSqTest, row_zero)
+        self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero)
 
         # Column sum = 0.0
         col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0])
-        self.assertRaises(Py4JJavaError, Statistics.chiSqTest, col_zero)
+        self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero)
 
     def test_chi_sq_pearson(self):
         data = [
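
The mllib/tests.py hunks switch the expected error type from Py4JJavaError to the Python-side IllegalArgumentException wrapper imported from pyspark.sql.utils. A sketch of the assertion pattern this relies on, assuming an already-running SparkContext such as the one MLlibTestCase sets up (the vectors below are illustrative):

    # Sketch only: invalid chi-square input is expected to surface as the
    # Python IllegalArgumentException wrapper rather than a raw Py4JJavaError.
    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.stat import Statistics
    from pyspark.sql.utils import IllegalArgumentException

    observed = Vectors.dense([1.0, 2.0, 3.0, -4.0])   # negative count is invalid
    expected = Vectors.dense([1.5, 2.5, 3.5, 4.5])
    try:
        Statistics.chiSqTest(observed, expected)
    except IllegalArgumentException as e:
        print("rejected invalid input:", e)
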
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 7e79df33e85f5dab5b03cd48002a7f68943746a2..bd728c97c82a8b4ec9b33c351bba14955ca37390 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -47,19 +47,19 @@ def to_str(value):
 class DataFrameReader(object):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
-    (e.g. file systems, key-value stores, etc). Use :func:`SQLContext.read`
+    (e.g. file systems, key-value stores, etc). Use :func:`SparkSession.read`
     to access this.
 
     .. versionadded:: 1.4
     """
 
-    def __init__(self, sqlContext):
-        self._jreader = sqlContext._ssql_ctx.read()
-        self._sqlContext = sqlContext
+    def __init__(self, spark):
+        self._jreader = spark._ssql_ctx.read()
+        self._spark = spark
 
     def _df(self, jdf):
         from pyspark.sql.dataframe import DataFrame
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     @since(1.4)
     def format(self, source):
@@ -67,7 +67,7 @@ class DataFrameReader(object):
 
         :param source: string, name of the data source, e.g. 'json', 'parquet'.
 
-        >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json')
+        >>> df = spark.read.format('json').load('python/test_support/sql/people.json')
         >>> df.dtypes
         [('age', 'bigint'), ('name', 'string')]
 
@@ -87,7 +87,7 @@ class DataFrameReader(object):
         """
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType")
-        jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json())
+        jschema = self._spark._ssql_ctx.parseDataType(schema.json())
         self._jreader = self._jreader.schema(jschema)
         return self
 
@@ -115,12 +115,12 @@ class DataFrameReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
 
-        >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
+        >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True,
         ...     opt2=1, opt3='str')
         >>> df.dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
 
-        >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json',
+        >>> df = spark.read.format('json').load(['python/test_support/sql/people.json',
         ...     'python/test_support/sql/people1.json'])
         >>> df.dtypes
         [('age', 'bigint'), ('aka', 'string'), ('name', 'string')]
@@ -133,7 +133,7 @@ class DataFrameReader(object):
         if path is not None:
             if type(path) != list:
                 path = [path]
-            return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path)))
         else:
             return self._df(self._jreader.load())
 
@@ -148,7 +148,7 @@ class DataFrameReader(object):
         :param schema: optional :class:`StructType` for the input schema.
         :param options: all other string options
 
-        >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming')
+        >>> df = spark.read.format('text').stream('python/test_support/sql/streaming')
         >>> df.isStreaming
         True
         """
@@ -211,11 +211,11 @@ class DataFrameReader(object):
                                           ``spark.sql.columnNameOfCorruptRecord``. If None is set,
                                           it uses the default value ``_corrupt_record``.
 
-        >>> df1 = sqlContext.read.json('python/test_support/sql/people.json')
+        >>> df1 = spark.read.json('python/test_support/sql/people.json')
         >>> df1.dtypes
         [('age', 'bigint'), ('name', 'string')]
         >>> rdd = sc.textFile('python/test_support/sql/people.json')
-        >>> df2 = sqlContext.read.json(rdd)
+        >>> df2 = spark.read.json(rdd)
         >>> df2.dtypes
         [('age', 'bigint'), ('name', 'string')]
 
@@ -243,7 +243,7 @@ class DataFrameReader(object):
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:
-            return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+            return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path)))
         elif isinstance(path, RDD):
             def func(iterator):
                 for x in iterator:
@@ -254,7 +254,7 @@ class DataFrameReader(object):
                     yield x
             keyed = path.mapPartitions(func)
             keyed._bypass_serializer = True
-            jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString())
+            jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
             return self._df(self._jreader.json(jrdd))
         else:
             raise TypeError("path can be only string or RDD")
@@ -265,9 +265,9 @@ class DataFrameReader(object):
 
         :param tableName: string, name of the table.
 
-        >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+        >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
         >>> df.registerTempTable('tmpTable')
-        >>> sqlContext.read.table('tmpTable').dtypes
+        >>> spark.read.table('tmpTable').dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
         """
         return self._df(self._jreader.table(tableName))
@@ -276,11 +276,11 @@ class DataFrameReader(object):
     def parquet(self, *paths):
         """Loads a Parquet file, returning the result as a :class:`DataFrame`.
 
-        >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned')
+        >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned')
         >>> df.dtypes
         [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')]
         """
-        return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths)))
+        return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths)))
 
     @ignore_unicode_prefix
     @since(1.6)
@@ -291,13 +291,13 @@ class DataFrameReader(object):
 
         :param paths: string, or list of strings, for input path(s).
 
-        >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt')
+        >>> df = spark.read.text('python/test_support/sql/text-test.txt')
         >>> df.collect()
         [Row(value=u'hello'), Row(value=u'this')]
         """
         if isinstance(paths, basestring):
             path = [paths]
-        return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+        return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path)))
 
     @since(2.0)
     def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
@@ -356,7 +356,7 @@ class DataFrameReader(object):
                 * ``DROPMALFORMED`` : ignores the whole corrupted records.
                 * ``FAILFAST`` : throws an exception when it meets corrupted records.
 
-        >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv')
+        >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
         [('C0', 'string'), ('C1', 'string')]
         """
@@ -396,7 +396,7 @@ class DataFrameReader(object):
             self.option("mode", mode)
         if isinstance(path, basestring):
             path = [path]
-        return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(path)))
+        return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path)))
 
     @since(1.5)
     def orc(self, path):
@@ -441,16 +441,16 @@ class DataFrameReader(object):
         """
         if properties is None:
             properties = dict()
-        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+        jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
         if column is not None:
             if numPartitions is None:
-                numPartitions = self._sqlContext._sc.defaultParallelism
+                numPartitions = self._spark._sc.defaultParallelism
             return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound),
                                                int(numPartitions), jprop))
         if predicates is not None:
-            gateway = self._sqlContext._sc._gateway
+            gateway = self._spark._sc._gateway
             jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates)
             return self._df(self._jreader.jdbc(url, table, jpredicates, jprop))
         return self._df(self._jreader.jdbc(url, table, jprop))
@@ -466,7 +466,7 @@ class DataFrameWriter(object):
     """
     def __init__(self, df):
         self._df = df
-        self._sqlContext = df.sql_ctx
+        self._spark = df.sql_ctx
         self._jwrite = df._jdf.write()
 
     def _cq(self, jcq):
@@ -531,14 +531,14 @@ class DataFrameWriter(object):
         """
         if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
             cols = cols[0]
-        self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols))
+        self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
         return self
 
     @since(2.0)
     def queryName(self, queryName):
         """Specifies the name of the :class:`ContinuousQuery` that can be started with
         :func:`startStream`. This name must be unique among all the currently active queries
-        in the associated SQLContext
+        in the associated SparkSession
 
         .. note:: Experimental.
 
@@ -573,7 +573,7 @@ class DataFrameWriter(object):
             trigger = ProcessingTime(processingTime)
         if trigger is None:
             raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
-        self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext))
+        self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark))
         return self
 
     @since(1.4)
@@ -854,7 +854,7 @@ class DataFrameWriter(object):
         """
         if properties is None:
             properties = dict()
-        jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)()
+        jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)()
         for k in properties:
             jprop.setProperty(k, properties[k])
         self._jwrite.mode(mode).jdbc(url, table, jprop)
@@ -865,7 +865,7 @@ def _test():
     import os
     import tempfile
     from pyspark.context import SparkContext
-    from pyspark.sql import Row, SQLContext, HiveContext
+    from pyspark.sql import SparkSession, Row, HiveContext
     import pyspark.sql.readwriter
 
     os.chdir(os.environ["SPARK_HOME"])
@@ -876,11 +876,13 @@ def _test():
     globs['tempfile'] = tempfile
     globs['os'] = os
     globs['sc'] = sc
-    globs['sqlContext'] = SQLContext(sc)
+    globs['spark'] = SparkSession.builder\
+        .enableHiveSupport()\
+        .getOrCreate()
     globs['hiveContext'] = HiveContext._createForTesting(sc)
-    globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned')
+    globs['df'] = globs['spark'].read.parquet('python/test_support/sql/parquet_partitioned')
     globs['sdf'] = \
-        globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming')
+        globs['spark'].read.format('text').stream('python/test_support/sql/streaming')
 
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.readwriter, globs=globs,
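
The readwriter.py changes rename the reader's stored handle from _sqlContext to _spark and point the docstring examples at a spark global, which the _test() hunk above now builds with SparkSession.builder. A minimal sketch of that bootstrap outside the doctest harness, assuming the same test-support paths; the appName value is an illustrative addition:

    # Sketch only: mirrors the globs['spark'] setup in the _test() hunk above.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .appName("readwriter-doctests") \
        .enableHiveSupport() \
        .getOrCreate()

    # spark.read returns a DataFrameReader bound to this session.
    df = spark.read.format('json').load('python/test_support/sql/people.json')
    print(df.dtypes)   # [('age', 'bigint'), ('name', 'string')]
    spark.stop()
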
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index cd5c4a7b3e9f84ce662a177cb816365de69eddf9..0c73f58c3b246b7ed1cfda7516aa5175991a7e85 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest
 
-from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
+from pyspark.sql import SparkSession, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
@@ -178,20 +178,6 @@ class DataTypeTests(unittest.TestCase):
         self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1))
 
 
-class SQLContextTests(ReusedPySparkTestCase):
-    def test_get_or_create(self):
-        sqlCtx = SQLContext.getOrCreate(self.sc)
-        self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx)
-
-    def test_new_session(self):
-        sqlCtx = SQLContext.getOrCreate(self.sc)
-        sqlCtx.setConf("test_key", "a")
-        sqlCtx2 = sqlCtx.newSession()
-        sqlCtx2.setConf("test_key", "b")
-        self.assertEqual(sqlCtx.getConf("test_key", ""), "a")
-        self.assertEqual(sqlCtx2.getConf("test_key", ""), "b")
-
-
 class SQLTests(ReusedPySparkTestCase):
 
     @classmethod
@@ -199,15 +185,14 @@ class SQLTests(ReusedPySparkTestCase):
         ReusedPySparkTestCase.setUpClass()
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
         os.unlink(cls.tempdir.name)
-        cls.sparkSession = SparkSession(cls.sc)
-        cls.sqlCtx = cls.sparkSession._wrapped
+        cls.spark = SparkSession(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
-        rdd = cls.sc.parallelize(cls.testData, 2)
-        cls.df = rdd.toDF()
+        cls.df = cls.spark.createDataFrame(cls.testData)
 
     @classmethod
     def tearDownClass(cls):
         ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
     def test_row_should_be_read_only(self):
@@ -218,7 +203,7 @@ class SQLTests(ReusedPySparkTestCase):
             row.a = 3
         self.assertRaises(Exception, foo)
 
-        row2 = self.sqlCtx.range(10).first()
+        row2 = self.spark.range(10).first()
         self.assertEqual(0, row2.id)
 
         def foo2():
@@ -226,14 +211,14 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertRaises(Exception, foo2)
 
     def test_range(self):
-        self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
-        self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
-        self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
-        self.assertEqual(self.sqlCtx.range(-2).count(), 0)
-        self.assertEqual(self.sqlCtx.range(3).count(), 3)
+        self.assertEqual(self.spark.range(1, 1).count(), 0)
+        self.assertEqual(self.spark.range(1, 0, -1).count(), 1)
+        self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2)
+        self.assertEqual(self.spark.range(-2).count(), 0)
+        self.assertEqual(self.spark.range(3).count(), 3)
 
     def test_duplicated_column_names(self):
-        df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"])
+        df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
         row = df.select('*').first()
         self.assertEqual(1, row[0])
         self.assertEqual(2, row[1])
@@ -247,7 +232,7 @@ class SQLTests(ReusedPySparkTestCase):
         from pyspark.sql.functions import explode
         d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
         rdd = self.sc.parallelize(d)
-        data = self.sqlCtx.createDataFrame(rdd)
+        data = self.spark.createDataFrame(rdd)
 
         result = data.select(explode(data.intlist).alias("a")).select("a").collect()
         self.assertEqual(result[0][0], 1)
@@ -269,7 +254,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_udf_with_callable(self):
         d = [Row(number=i, squared=i**2) for i in range(10)]
         rdd = self.sc.parallelize(d)
-        data = self.sqlCtx.createDataFrame(rdd)
+        data = self.spark.createDataFrame(rdd)
 
         class PlusFour:
             def __call__(self, col):
@@ -284,7 +269,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_udf_with_partial_function(self):
         d = [Row(number=i, squared=i**2) for i in range(10)]
         rdd = self.sc.parallelize(d)
-        data = self.sqlCtx.createDataFrame(rdd)
+        data = self.spark.createDataFrame(rdd)
 
         def some_func(col, param):
             if col is not None:
@@ -296,56 +281,56 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
 
     def test_udf(self):
-        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
         self.assertEqual(row[0], 5)
 
     def test_udf2(self):
-        self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
-        self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
-        [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+        self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
+        self.spark.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+        [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
         self.assertEqual(4, res[0])
 
     def test_chained_udf(self):
-        self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT double(1)").collect()
+        self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
+        [row] = self.spark.sql("SELECT double(1)").collect()
         self.assertEqual(row[0], 2)
-        [row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
+        [row] = self.spark.sql("SELECT double(double(1))").collect()
         self.assertEqual(row[0], 4)
-        [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
+        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
     def test_multiple_udfs(self):
-        self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
+        self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
+        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
         self.assertEqual(tuple(row), (2, 4))
-        [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
+        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
         self.assertEqual(tuple(row), (4, 12))
-        self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
         self.assertEqual(tuple(row), (6, 5))
 
     def test_udf_with_array_type(self):
         d = [Row(l=list(range(3)), d={"key": list(range(5))})]
         rdd = self.sc.parallelize(d)
-        self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
-        self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
-        self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
-        [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+        self.spark.createDataFrame(rdd).registerTempTable("test")
+        self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+        self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
+        [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
         self.assertEqual(list(range(3)), l1)
         self.assertEqual(1, l2)
 
     def test_broadcast_in_udf(self):
         bar = {"a": "aa", "b": "bb", "c": "abc"}
         foo = self.sc.broadcast(bar)
-        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
-        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
         self.assertEqual("abc", res[0])
-        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+        [res] = self.spark.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
     def test_udf_with_aggregate_function(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql.functions import udf, col
         from pyspark.sql.types import BooleanType
 
@@ -355,7 +340,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.sqlCtx.read.json(rdd)
+        df = self.spark.read.json(rdd)
         df.count()
         df.collect()
         df.schema
@@ -369,41 +354,41 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(2, df.count())
 
         df.registerTempTable("temp")
-        df = self.sqlCtx.sql("select foo from temp")
+        df = self.spark.sql("select foo from temp")
         df.count()
         df.collect()
 
     def test_apply_schema_to_row(self):
-        df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""]))
-        df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema)
+        df = self.spark.read.json(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema)
         self.assertEqual(df.collect(), df2.collect())
 
         rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
-        df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+        df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
     def test_infer_schema_to_local(self):
         input = [{"a": 1}, {"b": "coffee"}]
         rdd = self.sc.parallelize(input)
-        df = self.sqlCtx.createDataFrame(input)
-        df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+        df = self.spark.createDataFrame(input)
+        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
         self.assertEqual(df.schema, df2.schema)
 
         rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
-        df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
+        df3 = self.spark.createDataFrame(rdd, df.schema)
         self.assertEqual(10, df3.count())
 
     def test_create_dataframe_schema_mismatch(self):
         input = [Row(a=1)]
         rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
         schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
-        df = self.sqlCtx.createDataFrame(rdd, schema)
+        df = self.spark.createDataFrame(rdd, schema)
         self.assertRaises(Exception, lambda: df.show())
 
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
         rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.createDataFrame(rdd)
+        df = self.spark.createDataFrame(rdd)
         row = df.head()
         self.assertEqual(1, len(row.l))
         self.assertEqual(1, row.l[0].a)
@@ -425,31 +410,31 @@ class SQLTests(ReusedPySparkTestCase):
         d = [Row(l=[], d={}, s=None),
              Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
-        df = self.sqlCtx.createDataFrame(rdd)
+        df = self.spark.createDataFrame(rdd)
         self.assertEqual([], df.rdd.map(lambda r: r.l).first())
         self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect())
         df.registerTempTable("test")
-        result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+        result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
-        df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+        df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0)
         self.assertEqual(df.schema, df2.schema)
         self.assertEqual({}, df2.rdd.map(lambda r: r.d).first())
         self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect())
         df2.registerTempTable("test2")
-        result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+        result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
     def test_infer_nested_schema(self):
         NestedRow = Row("f1", "f2")
         nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
                                           NestedRow([2, 3], {"row2": 2.0})])
-        df = self.sqlCtx.createDataFrame(nestedRdd1)
+        df = self.spark.createDataFrame(nestedRdd1)
         self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
 
         nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
                                           NestedRow([[2, 3], [3, 4]], [2, 3])])
-        df = self.sqlCtx.createDataFrame(nestedRdd2)
+        df = self.spark.createDataFrame(nestedRdd2)
         self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
 
         from collections import namedtuple
@@ -457,17 +442,17 @@ class SQLTests(ReusedPySparkTestCase):
         rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
                                    CustomRow(field1=2, field2="row2"),
                                    CustomRow(field1=3, field2="row3")])
-        df = self.sqlCtx.createDataFrame(rdd)
+        df = self.spark.createDataFrame(rdd)
         self.assertEqual(Row(field1=1, field2=u'row1'), df.first())
 
     def test_create_dataframe_from_objects(self):
         data = [MyObject(1, "1"), MyObject(2, "2")]
-        df = self.sqlCtx.createDataFrame(data)
+        df = self.spark.createDataFrame(data)
         self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
         self.assertEqual(df.first(), Row(key=1, value="1"))
 
     def test_select_null_literal(self):
-        df = self.sqlCtx.sql("select null as col")
+        df = self.spark.sql("select null as col")
         self.assertEqual(Row(col=None), df.first())
 
     def test_apply_schema(self):
@@ -488,7 +473,7 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
             StructField("list1", ArrayType(ByteType(), False), False),
             StructField("null1", DoubleType(), True)])
-        df = self.sqlCtx.createDataFrame(rdd, schema)
+        df = self.spark.createDataFrame(rdd, schema)
         results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
                              x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
         r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
@@ -496,9 +481,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(r, results.first())
 
         df.registerTempTable("table2")
-        r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
-                            "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
-                            "float1 + 1.5 as float1 FROM table2").first()
+        r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+                           "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+                           "float1 + 1.5 as float1 FROM table2").first()
 
         self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
 
@@ -508,7 +493,7 @@ class SQLTests(ReusedPySparkTestCase):
         abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
         schema = _parse_schema_abstract(abstract)
         typedSchema = _infer_schema_type(rdd.first(), schema)
-        df = self.sqlCtx.createDataFrame(rdd, typedSchema)
+        df = self.spark.createDataFrame(rdd, typedSchema)
         r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
         self.assertEqual(r, tuple(df.first()))
 
@@ -524,7 +509,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(1, row.asDict()['l'][0].a)
         df = self.sc.parallelize([row]).toDF()
         df.registerTempTable("test")
-        row = self.sqlCtx.sql("select l, d from test").head()
+        row = self.spark.sql("select l, d from test").head()
         self.assertEqual(1, row.asDict()["l"][0].a)
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
@@ -535,7 +520,7 @@ class SQLTests(ReusedPySparkTestCase):
         def check_datatype(datatype):
             pickled = pickle.loads(pickle.dumps(datatype))
             assert datatype == pickled
-            scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json())
+            scala_datatype = self.spark._wrapped._ssql_ctx.parseDataType(datatype.json())
             python_datatype = _parse_datatype_json_string(scala_datatype.json())
             assert datatype == python_datatype
 
@@ -560,21 +545,21 @@ class SQLTests(ReusedPySparkTestCase):
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
         df.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        point = self.spark.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         schema = df.schema
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), PythonOnlyUDT)
         df.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+        point = self.spark.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_apply_schema_with_udt(self):
@@ -582,21 +567,21 @@ class SQLTests(ReusedPySparkTestCase):
         row = (1.0, ExamplePoint(1.0, 2.0))
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        df = self.sqlCtx.createDataFrame([row], schema)
+        df = self.spark.createDataFrame([row], schema)
         point = df.head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = (1.0, PythonOnlyPoint(1.0, 2.0))
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", PythonOnlyUDT(), False)])
-        df = self.sqlCtx.createDataFrame([row], schema)
+        df = self.spark.createDataFrame([row], schema)
         point = df.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
     def test_udf_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
@@ -604,7 +589,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
@@ -614,17 +599,17 @@ class SQLTests(ReusedPySparkTestCase):
     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
-        df0 = self.sqlCtx.createDataFrame([row])
+        df0 = self.spark.createDataFrame([row])
         output_dir = os.path.join(self.tempdir.name, "labeled_point")
         df0.write.parquet(output_dir)
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         point = df1.head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
-        df0 = self.sqlCtx.createDataFrame([row])
+        df0 = self.spark.createDataFrame([row])
         df0.write.parquet(output_dir, mode='overwrite')
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         point = df1.head().point
         self.assertEqual(point, PythonOnlyPoint(1.0, 2.0))
 
@@ -634,8 +619,8 @@ class SQLTests(ReusedPySparkTestCase):
         row2 = (2.0, ExamplePoint(3.0, 4.0))
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        df1 = self.sqlCtx.createDataFrame([row1], schema)
-        df2 = self.sqlCtx.createDataFrame([row2], schema)
+        df1 = self.spark.createDataFrame([row1], schema)
+        df2 = self.spark.createDataFrame([row2], schema)
 
         result = df1.union(df2).orderBy("label").collect()
         self.assertEqual(
@@ -688,7 +673,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_first_last_ignorenulls(self):
         from pyspark.sql import functions
-        df = self.sqlCtx.range(0, 100)
+        df = self.spark.range(0, 100)
         df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
         df3 = df2.select(functions.first(df2.id, False).alias('a'),
                          functions.first(df2.id, True).alias('b'),
@@ -829,36 +814,36 @@ class SQLTests(ReusedPySparkTestCase):
         schema = StructType([StructField("f1", StringType(), True, None),
                              StructField("f2", StringType(), True, {'a': None})])
         rdd = self.sc.parallelize([["a", "b"], ["c", "d"]])
-        self.sqlCtx.createDataFrame(rdd, schema)
+        self.spark.createDataFrame(rdd, schema)
 
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.json(tmpPath)
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.read.json(tmpPath, schema)
+        actual = self.spark.read.json(tmpPath, schema)
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
         df.write.json(tmpPath, "overwrite")
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         df.write.save(format="json", mode="overwrite", path=tmpPath,
                       noUse="this options will not be used in save.")
-        actual = self.sqlCtx.read.load(format="json", path=tmpPath,
-                                       noUse="this options will not be used in load.")
+        actual = self.spark.read.load(format="json", path=tmpPath,
+                                      noUse="this options will not be used in load.")
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.sqlCtx.read.load(path=tmpPath)
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.spark.read.load(path=tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         csvpath = os.path.join(tempfile.mkdtemp(), 'data')
         df.write.option('quote', None).format('csv').save(csvpath)
@@ -870,36 +855,36 @@ class SQLTests(ReusedPySparkTestCase):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.json(tmpPath)
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.read.json(tmpPath, schema)
+        actual = self.spark.read.json(tmpPath, schema)
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
 
         df.write.mode("overwrite").json(tmpPath)
-        actual = self.sqlCtx.read.json(tmpPath)
+        actual = self.spark.read.json(tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
         df.write.mode("overwrite").options(noUse="this options will not be used in save.")\
                 .option("noUse", "this option will not be used in save.")\
                 .format("json").save(path=tmpPath)
         actual =\
-            self.sqlCtx.read.format("json")\
-                            .load(path=tmpPath, noUse="this options will not be used in load.")
+            self.spark.read.format("json")\
+                           .load(path=tmpPath, noUse="this options will not be used in load.")
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                     "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.sqlCtx.read.load(path=tmpPath)
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.spark.read.load(path=tmpPath)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
 
     def test_stream_trigger_takes_keyword_args(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
         try:
             df.write.trigger('5 seconds')
             self.fail("Should have thrown an exception")
@@ -909,7 +894,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_stream_read_options(self):
         schema = StructType([StructField("data", StringType(), False)])
-        df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\
+        df = self.spark.read.format('text').option('path', 'python/test_support/sql/streaming')\
             .schema(schema).stream()
         self.assertTrue(df.isStreaming)
         self.assertEqual(df.schema.simpleString(), "struct<data:string>")
@@ -917,15 +902,15 @@ class SQLTests(ReusedPySparkTestCase):
     def test_stream_read_options_overwrite(self):
         bad_schema = StructType([StructField("test", IntegerType(), False)])
         schema = StructType([StructField("data", StringType(), False)])
-        df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \
+        df = self.spark.read.format('csv').option('path', 'python/test_support/sql/fake') \
             .schema(bad_schema).stream(path='python/test_support/sql/streaming',
                                        schema=schema, format='text')
         self.assertTrue(df.isStreaming)
         self.assertEqual(df.schema.simpleString(), "struct<data:string>")
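
The streaming tests above go through DataFrameReader.stream(), the entry point as it exists at the time of this patch, and show that keyword arguments to stream() override options set earlier on the reader. A hedged sketch of the basic reader pattern, assuming the same python/test_support/sql/streaming text fixtures and this snapshot of the API:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType

    spark = SparkSession.builder.master("local[2]").appName("stream-sketch").getOrCreate()
    schema = StructType([StructField("data", StringType(), False)])
    # Reader-level options feed the streaming source; nothing runs until a sink starts.
    sdf = (spark.read.format("text")
           .option("path", "python/test_support/sql/streaming")
           .schema(schema)
           .stream())
    assert sdf.isStreaming
    assert sdf.schema.simpleString() == "struct<data:string>"
    spark.stop()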
 
     def test_stream_save_options(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
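
SparkSession does not yet expose a streams property, so the converted tests reach the ContinuousQueryManager through the session's private _wrapped SQLContext. A small sketch of that stop-everything preamble (it leans on a private attribute, so treat it as an implementation detail of this patch rather than stable API):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("streams-sketch").getOrCreate()
    # _wrapped is the SQLContext backing the session; its streams manager tracks
    # every continuous query started through this session.
    for cq in spark._wrapped.streams.active:
        cq.stop()
    spark.stop()
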
@@ -948,8 +933,8 @@ class SQLTests(ReusedPySparkTestCase):
             shutil.rmtree(tmpPath)
 
     def test_stream_save_options_overwrite(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -977,8 +962,8 @@ class SQLTests(ReusedPySparkTestCase):
             shutil.rmtree(tmpPath)
 
     def test_stream_await_termination(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -1005,8 +990,8 @@ class SQLTests(ReusedPySparkTestCase):
             shutil.rmtree(tmpPath)
 
     def test_query_manager_await_termination(self):
-        df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming')
-        for cq in self.sqlCtx.streams.active:
+        df = self.spark.read.format('text').stream('python/test_support/sql/streaming')
+        for cq in self.spark._wrapped.streams.active:
             cq.stop()
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
@@ -1018,13 +1003,13 @@ class SQLTests(ReusedPySparkTestCase):
         try:
             self.assertTrue(cq.isActive)
             try:
-                self.sqlCtx.streams.awaitAnyTermination("hello")
+                self.spark._wrapped.streams.awaitAnyTermination("hello")
                 self.fail("Expected a value exception")
             except ValueError:
                 pass
             now = time.time()
             # test should take at least 2 seconds
-            res = self.sqlCtx.streams.awaitAnyTermination(2.6)
+            res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
             duration = time.time() - now
             self.assertTrue(duration >= 2)
             self.assertFalse(res)
@@ -1035,7 +1020,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        df = self.sqlCtx.read.json(rdd)
+        df = self.spark.read.json(rdd)
         # render_doc() reproduces the help() exception without printing output
         pydoc.render_doc(df)
         pydoc.render_doc(df.foo)
@@ -1051,7 +1036,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertRaises(TypeError, lambda: df[{}])
 
     def test_column_name_with_non_ascii(self):
-        df = self.sqlCtx.createDataFrame([(1,)], ["数量"])
+        df = self.spark.createDataFrame([(1,)], ["数量"])
         self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema)
         self.assertEqual("DataFrame[数量: bigint]", str(df))
         self.assertEqual([("数量", 'bigint')], df.dtypes)
@@ -1084,7 +1069,7 @@ class SQLTests(ReusedPySparkTestCase):
         # this saving as Parquet caused issues as well.
         output_dir = os.path.join(self.tempdir.name, "infer_long_type")
         df.write.parquet(output_dir)
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         self.assertEqual('a', df1.first().f1)
         self.assertEqual(100000000000000, df1.first().f2)
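
The regression test above checks that a Python int needing 64 bits survives a Parquet round trip as a LongType column. A short sketch of the same write/read cycle with an explicit temporary directory (paths and names are illustrative):

    import os
    import shutil
    import tempfile

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("parquet-sketch").getOrCreate()
    out = os.path.join(tempfile.mkdtemp(), "infer_long_type")
    df = spark.createDataFrame([("a", 100000000000000)], ["f1", "f2"])  # f2 is inferred as bigint
    df.write.parquet(out)
    row = spark.read.parquet(out).first()
    assert row.f1 == "a" and row.f2 == 100000000000000
    shutil.rmtree(os.path.dirname(out))
    spark.stop()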
 
@@ -1100,7 +1085,7 @@ class SQLTests(ReusedPySparkTestCase):
         time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
         date = time.date()
         row = Row(date=date, time=time)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1, df.filter(df.date == date).count())
         self.assertEqual(1, df.filter(df.time == time).count())
         self.assertEqual(0, df.filter(df.date > date).count())
@@ -1110,7 +1095,7 @@ class SQLTests(ReusedPySparkTestCase):
         dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
         dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
         row = Row(date=dt1)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(0, df.filter(df.date == dt2).count())
         self.assertEqual(1, df.filter(df.date > dt2).count())
         self.assertEqual(0, df.filter(df.date < dt2).count())
@@ -1125,7 +1110,7 @@ class SQLTests(ReusedPySparkTestCase):
         utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
         # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
         utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
-        df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
+        df = self.spark.createDataFrame([(day, now, utcnow)])
         day1, now1, utcnow1 = df.first()
         self.assertEqual(day1, day)
         self.assertEqual(now, now1)
@@ -1134,13 +1119,13 @@ class SQLTests(ReusedPySparkTestCase):
     def test_decimal(self):
         from decimal import Decimal
         schema = StructType([StructField("decimal", DecimalType(10, 5))])
-        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+        df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
         row = df.select(df.decimal + 1).first()
         self.assertEqual(row[0], Decimal("4.14159"))
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.parquet(tmpPath)
-        df2 = self.sqlCtx.read.parquet(tmpPath)
+        df2 = self.spark.read.parquet(tmpPath)
         row = df2.first()
         self.assertEqual(row[0], Decimal("3.14159"))
 
@@ -1151,52 +1136,52 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("height", DoubleType(), True)])
 
         # shouldn't drop a non-null row
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, 80.1)], schema).dropna().count(),
             1)
 
         # dropping rows with a single null value
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna().count(),
             0)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
             0)
 
         # if how = 'all', only drop rows if all values are null
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(None, None, None)], schema).dropna(how='all').count(),
             0)
 
         # how and subset
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
             0)
 
         # threshold
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
             0)
 
         # threshold and subset
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
             0)
 
         # thresh should take precedence over how
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(
                 how='any', thresh=2, subset=['name', 'age']).count(),
             1)
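
The block above pins down dropna's contract: how='any' drops a row on the first null, how='all' only drops fully-null rows, thresh (a minimum count of non-null values) takes precedence over how, and subset restricts which columns are inspected. A compact sketch of those knobs on a single row, reusing the test's schema:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

    spark = SparkSession.builder.master("local[2]").appName("dropna-sketch").getOrCreate()
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True),
                         StructField("height", DoubleType(), True)])
    df = spark.createDataFrame([(u"Alice", None, 80.1)], schema)

    assert df.dropna(how="any").count() == 0                  # one null is enough to drop
    assert df.dropna(how="all").count() == 1                  # only all-null rows are dropped
    assert df.dropna(thresh=2).count() == 1                   # two non-nulls present, so keep
    assert df.dropna(how="any", subset=["name", "height"]).count() == 1  # null column not in subset
    spark.stop()
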
@@ -1208,33 +1193,33 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("height", DoubleType(), True)])
 
         # fillna shouldn't change non-null values
-        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
         self.assertEqual(row.age, 10)
 
         # fillna with int
-        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.0)
 
         # fillna with double
-        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.1)
 
         # fillna with string
-        row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first()
+        row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
         self.assertEqual(row.name, u"hello")
         self.assertEqual(row.age, None)
 
         # fillna with subset specified for numeric cols
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
         self.assertEqual(row.name, None)
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, None)
 
         # fillna with subset specified for numeric cols
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
         self.assertEqual(row.name, "haha")
         self.assertEqual(row.age, None)
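
fillna dispatches on the type of the replacement value: an int or float fills only numeric columns, a string fills only string columns, and subset further narrows the candidates (a column whose type does not match the value is silently left alone). A short sketch under the same schema:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

    spark = SparkSession.builder.master("local[2]").appName("fillna-sketch").getOrCreate()
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True),
                         StructField("height", DoubleType(), True)])
    df = spark.createDataFrame([(None, None, None)], schema)

    row = df.fillna(50).first()                       # numeric columns only
    assert (row.name, row.age, row.height) == (None, 50, 50.0)

    row = df.fillna("hello").first()                  # string columns only
    assert (row.name, row.age) == (u"hello", None)

    row = df.fillna(50, subset=["name", "age"]).first()   # "name" is not numeric, so it stays null
    assert (row.name, row.age, row.height) == (None, 50, None)
    spark.stop()
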
@@ -1243,7 +1228,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_bitwise_operations(self):
         from pyspark.sql import functions
         row = Row(a=170, b=75)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
         self.assertEqual(170 & 75, result['(a & b)'])
         result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
@@ -1256,7 +1241,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_expr(self):
         from pyspark.sql import functions
         row = Row(a="length string", b=75)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         result = df.select(functions.expr("length(a)")).collect()[0].asDict()
         self.assertEqual(13, result["length(a)"])
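
functions.expr parses a SQL expression string into a Column, and the bitwise Column methods (bitwiseAND, bitwiseOR, bitwiseXOR) behave like Python's &, | and ^ on the underlying ints. A tiny sketch mirroring the two tests above (session details are illustrative):

    from pyspark.sql import Row, SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[2]").appName("expr-sketch").getOrCreate()

    df = spark.createDataFrame([Row(a=170, b=75)])
    anded = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
    assert anded["(a & b)"] == (170 & 75)

    df2 = spark.createDataFrame([Row(a="length string", b=75)])
    assert df2.select(F.expr("length(a)")).collect()[0].asDict()["length(a)"] == 13
    spark.stop()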
 
@@ -1267,58 +1252,58 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("height", DoubleType(), True)])
 
         # replace with int
-        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
         self.assertEqual(row.age, 20)
         self.assertEqual(row.height, 20.0)
 
         # replace with double
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
         self.assertEqual(row.age, 82)
         self.assertEqual(row.height, 82.1)
 
         # replace with string
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
         self.assertEqual(row.name, u"Ann")
         self.assertEqual(row.age, 10)
 
         # replace with subset specified by a string of a column name w/ actual change
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
         self.assertEqual(row.age, 20)
 
         # replace with subset specified by a string of a column name w/o actual change
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
         self.assertEqual(row.age, 10)
 
         # replace with subset specified with one column replaced, another column not in subset
         # stays unchanged.
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
         self.assertEqual(row.name, u'Alice')
         self.assertEqual(row.age, 20)
         self.assertEqual(row.height, 10.0)
 
         # replace with subset specified but no column will be replaced
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
         self.assertEqual(row.name, u'Alice')
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
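
replace also matches by value type across compatible columns: replacing the int 10 with 20 rewrites a DoubleType column holding 10.0 as well, while subset limits which columns are considered at all. A brief sketch mirroring the cases above:

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

    spark = SparkSession.builder.master("local[2]").appName("replace-sketch").getOrCreate()
    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", IntegerType(), True),
                         StructField("height", DoubleType(), True)])

    row = spark.createDataFrame([(u"Alice", 10, 10.0)], schema).replace(10, 20).first()
    assert (row.age, row.height) == (20, 20.0)        # both numeric columns are rewritten

    row = spark.createDataFrame([(u"Alice", 10, 10.0)], schema) \
               .replace(10, 20, subset=["name", "age"]).first()
    assert (row.age, row.height) == (20, 10.0)        # height sits outside the subset
    spark.stop()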
 
     def test_capture_analysis_exception(self):
-        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
+        self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
 
     def test_capture_parse_exception(self):
-        self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
+        self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
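
Both exception classes live in pyspark.sql.utils and wrap the corresponding JVM exceptions, so they can be caught directly around spark.sql calls. A minimal sketch:

    from pyspark.sql import SparkSession
    from pyspark.sql.utils import AnalysisException, ParseException

    spark = SparkSession.builder.master("local[2]").appName("exc-sketch").getOrCreate()
    try:
        spark.sql("select abc")        # valid syntax, unresolved column -> AnalysisException
    except AnalysisException as e:
        print("analysis error: %s" % e)
    try:
        spark.sql("abc")               # not valid SQL at all -> ParseException
    except ParseException as e:
        print("parse error: %s" % e)
    spark.stop()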
 
     def test_capture_illegalargument_exception(self):
         self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
-                                lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
-        df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
+                                lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
+        df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
         self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
                                 lambda: df.select(sha2(df.a, 1024)).collect())
         try:
@@ -1345,8 +1330,8 @@ class SQLTests(ReusedPySparkTestCase):
     def test_functions_broadcast(self):
         from pyspark.sql.functions import broadcast
 
-        df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
-        df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
 
         # equijoin - should be converted into broadcast join
         plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
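
functions.broadcast marks one side of an equijoin as small enough to ship to every executor; the test then verifies the hint took effect by walking the executed plan through the private _jdf handle. A sketch of just the user-facing half:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import broadcast

    spark = SparkSession.builder.master("local[2]").appName("broadcast-sketch").getOrCreate()
    df1 = spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
    df2 = spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
    # Hint that df2 should be broadcast rather than shuffled for the join.
    joined = df1.join(broadcast(df2), "key")
    assert joined.count() == 2
    spark.stop()
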
@@ -1396,9 +1381,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
 
     def test_conf(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.conf.set("bogo", "sipeo")
-        self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+        self.assertEqual(spark.conf.get("bogo"), "sipeo")
         spark.conf.set("bogo", "ta")
         self.assertEqual(spark.conf.get("bogo"), "ta")
         self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
@@ -1408,7 +1393,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
 
     def test_current_database(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         self.assertEquals(spark.catalog.currentDatabase(), "default")
         spark.sql("CREATE DATABASE some_db")
@@ -1420,7 +1405,7 @@ class SQLTests(ReusedPySparkTestCase):
             lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
 
     def test_list_databases(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         databases = [db.name for db in spark.catalog.listDatabases()]
         self.assertEquals(databases, ["default"])
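
These catalog tests use spark.catalog, the SparkSession-level replacement for the old SQLContext table helpers; the _reset() call they rely on is a private test hook. A hedged sketch of the public surface only (temp-table and column names are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("catalog-sketch").getOrCreate()
    assert spark.catalog.currentDatabase() == "default"
    assert "default" in [db.name for db in spark.catalog.listDatabases()]

    # Temporary tables show up in listTables() alongside any metastore tables.
    spark.createDataFrame([(1, "a")], ["id", "val"]).registerTempTable("tmp_tab")
    assert "tmp_tab" in [t.name for t in spark.catalog.listTables()]
    spark.stop()
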
@@ -1430,7 +1415,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_tables(self):
         from pyspark.sql.catalog import Table
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         self.assertEquals(spark.catalog.listTables(), [])
@@ -1475,7 +1460,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_functions(self):
         from pyspark.sql.catalog import Function
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         functions = dict((f.name, f) for f in spark.catalog.listFunctions())
@@ -1512,7 +1497,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_columns(self):
         from pyspark.sql.catalog import Column
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
@@ -1561,7 +1546,7 @@ class SQLTests(ReusedPySparkTestCase):
             lambda: spark.catalog.listColumns("does_not_exist"))
 
     def test_cache(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
         spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
         self.assertFalse(spark.catalog.isCached("tab1"))
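
Cache state is managed through the catalog as well: cacheTable, isCached, and uncacheTable all take a table name, including names registered as temporary tables. A short sketch of that round trip:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("cache-sketch").getOrCreate()
    spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
    assert not spark.catalog.isCached("tab1")
    spark.catalog.cacheTable("tab1")
    assert spark.catalog.isCached("tab1")
    spark.catalog.uncacheTable("tab1")
    assert not spark.catalog.isCached("tab1")
    spark.stop()
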
@@ -1605,7 +1590,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             cls.tearDownClass()
             raise unittest.SkipTest("Hive is not available")
         os.unlink(cls.tempdir.name)
-        cls.sqlCtx = HiveContext._createForTesting(cls.sc)
+        cls.spark = HiveContext._createForTesting(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         cls.df = cls.sc.parallelize(cls.testData).toDF()
 
@@ -1619,45 +1604,45 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
+        actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
 
         df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
-                                                 schema=schema, path=tmpPath,
-                                                 noUse="this options will not be used")
+        actual = self.spark.createExternalTable("externalJsonTable", source="json",
+                                                schema=schema, path=tmpPath,
+                                                noUse="this options will not be used")
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.select("value").collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE savedJsonTable")
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
+                                                   "org.apache.spark.sql.parquet")
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
         df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+        actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE savedJsonTable")
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
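
In HiveContextSQLTests the spark attribute set up above is still a HiveContext, which is why this test keeps the context-style getConf and createExternalTable calls instead of the SparkSession ones. A hedged sketch of the save/re-register round trip, assuming a Hive-enabled build such as the class-level skip requires (names and paths are illustrative):

    import shutil
    import tempfile

    from pyspark import SparkContext
    from pyspark.sql import HiveContext, Row

    sc = SparkContext("local[2]", "external-table-sketch")
    hive = HiveContext(sc)
    df = hive.createDataFrame([Row(key=i, value=str(i)) for i in range(5)])

    path = tempfile.mkdtemp()
    shutil.rmtree(path)
    # Save as a JSON table backed by an explicit path, then expose the same
    # files again as an external table and compare the contents.
    df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=path)
    external = hive.createExternalTable("externalJsonTable", path, "json")
    assert sorted(df.collect()) == sorted(external.collect())
    hive.sql("DROP TABLE externalJsonTable")
    hive.sql("DROP TABLE savedJsonTable")
    shutil.rmtree(path)
    sc.stop()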
 
     def test_window_functions(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.partitionBy("value").orderBy("key")
         from pyspark.sql import functions as F
         sel = df.select(df.value, df.key,
@@ -1679,7 +1664,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             self.assertEqual(tuple(r), ex[:len(r)])
 
     def test_window_functions_without_partitionBy(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.orderBy("key", df.value)
         from pyspark.sql import functions as F
         sel = df.select(df.value, df.key,
@@ -1701,7 +1686,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             self.assertEqual(tuple(r), ex[:len(r)])
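
A window expression pairs an aggregate or ranking function from pyspark.sql.functions with a WindowSpec; without partitionBy the whole data set is treated as one ordered partition, which is what the second test checks. A compact sketch of the ranking functions used above:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    from pyspark.sql.window import Window

    spark = SparkSession.builder.master("local[2]").appName("window-sketch").getOrCreate()
    df = spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
    w = Window.partitionBy("value").orderBy("key")
    out = df.select(df.value, df.key,
                    F.row_number().over(w).alias("rn"),
                    F.rank().over(w).alias("rk"),
                    F.dense_rank().over(w).alias("drk")).collect()
    assert len(out) == 4
    spark.stop()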
 
     def test_collect_functions(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql import functions
 
         self.assertEqual(