diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ad1631fb5baa50f36d69e53983a5f87c513ab7e1..49d3a4a332fd19390dce1d67eed2b076b7ea8ef4 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -57,13 +57,25 @@ from pyspark.ml.tuning import * from pyspark.ml.wrapper import JavaParams from pyspark.mllib.common import _java2py from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector -from pyspark.sql import DataFrame, SQLContext, Row +from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand from pyspark.sql.utils import IllegalArgumentException from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +class SparkSessionTestCase(PySparkTestCase): + @classmethod + def setUpClass(cls): + PySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + PySparkTestCase.tearDownClass() + cls.spark.stop() + + class MockDataset(DataFrame): def __init__(self): @@ -350,7 +362,7 @@ class ParamTests(PySparkTestCase): self.assertEqual(model.getWindowSize(), 6) -class FeatureTests(PySparkTestCase): +class FeatureTests(SparkSessionTestCase): def test_binarizer(self): b0 = Binarizer() @@ -376,8 +388,7 @@ class FeatureTests(PySparkTestCase): self.assertEqual(b1.getOutputCol(), "output") def test_idf(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (DenseVector([1.0, 2.0]),), (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"]) @@ -390,8 +401,7 @@ class FeatureTests(PySparkTestCase): self.assertIsNotNone(output.head().idf) def test_ngram(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ Row(input=["a", "b", "c", "d", "e"])]) ngram0 = NGram(n=4, inputCol="input", outputCol="output") self.assertEqual(ngram0.getN(), 4) @@ -401,8 +411,7 @@ class FeatureTests(PySparkTestCase): self.assertEqual(transformedDF.head().output, ["a b c d", "b c d e"]) def test_stopwordsremover(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([Row(input=["a", "panda"])]) + dataset = self.spark.createDataFrame([Row(input=["a", "panda"])]) stopWordRemover = StopWordsRemover(inputCol="input", outputCol="output") # Default self.assertEqual(stopWordRemover.getInputCol(), "input") @@ -419,15 +428,14 @@ class FeatureTests(PySparkTestCase): self.assertEqual(transformedDF.head().output, ["a"]) # with language selection stopwords = StopWordsRemover.loadDefaultStopWords("turkish") - dataset = sqlContext.createDataFrame([Row(input=["acaba", "ama", "biri"])]) + dataset = self.spark.createDataFrame([Row(input=["acaba", "ama", "biri"])]) stopWordRemover.setStopWords(stopwords) self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), @@ -475,11 +483,10 @@ class InducedErrorEstimator(Estimator, HasInducedError): return model -class CrossValidatorTests(PySparkTestCase): +class CrossValidatorTests(SparkSessionTestCase): def test_copy(self): - sqlContext 
= SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -503,8 +510,7 @@ class CrossValidatorTests(PySparkTestCase): < 0.0001) def test_fit_minimize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -527,8 +533,7 @@ class CrossValidatorTests(PySparkTestCase): self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") def test_fit_maximize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -554,8 +559,7 @@ class CrossValidatorTests(PySparkTestCase): # This tests saving and loading the trained model only. # Save/load for CrossValidator will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame( + dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0), (Vectors.dense([0.5]), 0.0), @@ -576,11 +580,10 @@ class CrossValidatorTests(PySparkTestCase): self.assertEqual(loadedLrModel.intercept, lrModel.intercept) -class TrainValidationSplitTests(PySparkTestCase): +class TrainValidationSplitTests(SparkSessionTestCase): def test_fit_minimize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -603,8 +606,7 @@ class TrainValidationSplitTests(PySparkTestCase): self.assertEqual(0.0, bestModelMetric, "Best model has RMSE of 0") def test_fit_maximize_metric(self): - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame([ + dataset = self.spark.createDataFrame([ (10, 10.0), (50, 50.0), (100, 100.0), @@ -630,8 +632,7 @@ class TrainValidationSplitTests(PySparkTestCase): # This tests saving and loading the trained model only. 
# Save/load for TrainValidationSplit will be added later: SPARK-13786 temp_path = tempfile.mkdtemp() - sqlContext = SQLContext(self.sc) - dataset = sqlContext.createDataFrame( + dataset = self.spark.createDataFrame( [(Vectors.dense([0.0]), 0.0), (Vectors.dense([0.4]), 1.0), (Vectors.dense([0.5]), 0.0), @@ -652,7 +653,7 @@ class TrainValidationSplitTests(PySparkTestCase): self.assertEqual(loadedLrModel.intercept, lrModel.intercept) -class PersistenceTest(PySparkTestCase): +class PersistenceTest(SparkSessionTestCase): def test_linear_regression(self): lr = LinearRegression(maxIter=1) @@ -724,11 +725,10 @@ class PersistenceTest(PySparkTestCase): """ Pipeline[HashingTF, PCA] """ - sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() try: - df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) @@ -753,11 +753,10 @@ class PersistenceTest(PySparkTestCase): """ Pipeline[HashingTF, Pipeline[PCA]] """ - sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() try: - df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + df = self.spark.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") pca = PCA(k=2, inputCol="features", outputCol="pca_features") p0 = Pipeline(stages=[pca]) @@ -816,7 +815,7 @@ class PersistenceTest(PySparkTestCase): pass -class LDATest(PySparkTestCase): +class LDATest(SparkSessionTestCase): def _compare(self, m1, m2): """ @@ -836,8 +835,7 @@ class LDATest(PySparkTestCase): def test_persistence(self): # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. 
- sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([ + df = self.spark.createDataFrame([ [1, Vectors.dense([0.0, 1.0])], [2, Vectors.sparse(2, {0: 1.0})], ], ["id", "features"]) @@ -871,12 +869,11 @@ class LDATest(PySparkTestCase): pass -class TrainingSummaryTest(PySparkTestCase): +class TrainingSummaryTest(SparkSessionTestCase): def test_linear_regression_summary(self): from pyspark.mllib.linalg import Vectors - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", @@ -914,8 +911,7 @@ class TrainingSummaryTest(PySparkTestCase): def test_logistic_regression_summary(self): from pyspark.mllib.linalg import Vectors - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + df = self.spark.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"]) lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) @@ -942,11 +938,10 @@ class TrainingSummaryTest(PySparkTestCase): self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) -class OneVsRestTests(PySparkTestCase): +class OneVsRestTests(SparkSessionTestCase): def test_copy(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), (1.0, Vectors.sparse(2, [], [])), (2.0, Vectors.dense(0.5, 0.5))], ["label", "features"]) @@ -960,8 +955,7 @@ class OneVsRestTests(PySparkTestCase): self.assertEqual(model1.getPredictionCol(), "indexed") def test_output_columns(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), (1.0, Vectors.sparse(2, [], [])), (2.0, Vectors.dense(0.5, 0.5))], ["label", "features"]) @@ -973,8 +967,7 @@ class OneVsRestTests(PySparkTestCase): def test_save_load(self): temp_path = tempfile.mkdtemp() - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), + df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)), (1.0, Vectors.sparse(2, [], [])), (2.0, Vectors.dense(0.5, 0.5))], ["label", "features"]) @@ -994,12 +987,11 @@ class OneVsRestTests(PySparkTestCase): self.assertEqual(m.uid, n.uid) -class HashingTFTest(PySparkTestCase): +class HashingTFTest(SparkSessionTestCase): def test_apply_binary_term_freqs(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + df = self.spark.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) n = 10 hashingTF = HashingTF() hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) @@ -1011,11 +1003,10 @@ class HashingTFTest(PySparkTestCase): ": expected " + str(expected[i]) + ", got " + str(features[i])) -class ALSTest(PySparkTestCase): +class ALSTest(SparkSessionTestCase): def test_storage_levels(self): - sqlContext = SQLContext(self.sc) - df = sqlContext.createDataFrame( + df = self.spark.createDataFrame( [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"]) als = ALS().setMaxIter(1).setRank(1) diff --git 
a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 53a1d2c59cb2f5a346a43746e43174a407f67428..74cf7bb8eaf9d423754fcab1a5b4961cac6282a5 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -66,7 +66,8 @@ from pyspark.mllib.util import LinearDataGenerator from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext -from pyspark.sql import SQLContext +from pyspark.sql import SparkSession +from pyspark.sql.utils import IllegalArgumentException from pyspark.streaming import StreamingContext _have_scipy = False @@ -83,9 +84,10 @@ ser = PickleSerializer() class MLlibTestCase(unittest.TestCase): def setUp(self): self.sc = SparkContext('local[4]', "MLlib tests") + self.spark = SparkSession(self.sc) def tearDown(self): - self.sc.stop() + self.spark.stop() class MLLibStreamingTestCase(unittest.TestCase): @@ -698,7 +700,6 @@ class VectorUDTTests(MLlibTestCase): self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) def test_infer_schema(self): - sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) df = rdd.toDF() schema = df.schema @@ -731,7 +732,6 @@ class MatrixUDTTests(MLlibTestCase): self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m))) def test_infer_schema(self): - sqlCtx = SQLContext(self.sc) rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)]) df = rdd.toDF() schema = df.schema @@ -919,7 +919,7 @@ class ChiSqTestTests(MLlibTestCase): # Negative counts in observed neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_obs, expected1) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_obs, expected1) # Count = 0.0 in expected but not observed zero_expected = Vectors.dense([1.0, 0.0, 3.0]) @@ -930,7 +930,8 @@ class ChiSqTestTests(MLlibTestCase): # 0.0 in expected and observed simultaneously zero_observed = Vectors.dense([2.0, 0.0, 1.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, zero_observed, zero_expected) + self.assertRaises( + IllegalArgumentException, Statistics.chiSqTest, zero_observed, zero_expected) def test_matrix_independence(self): data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] @@ -944,15 +945,15 @@ class ChiSqTestTests(MLlibTestCase): # Negative counts neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_counts) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, neg_counts) # Row sum = 0.0 row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, row_zero) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, row_zero) # Column sum = 0.0 col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0]) - self.assertRaises(Py4JJavaError, Statistics.chiSqTest, col_zero) + self.assertRaises(IllegalArgumentException, Statistics.chiSqTest, col_zero) def test_chi_sq_pearson(self): data = [ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7e79df33e85f5dab5b03cd48002a7f68943746a2..bd728c97c82a8b4ec9b33c351bba14955ca37390 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -47,19 +47,19 @@ def to_str(value): class DataFrameReader(object): """ Interface used to load a :class:`DataFrame` from external storage systems - (e.g. 
file systems, key-value stores, etc). Use :func:`SQLContext.read` + (e.g. file systems, key-value stores, etc). Use :func:`spark.read` to access this. .. versionadded:: 1.4 """ - def __init__(self, sqlContext): - self._jreader = sqlContext._ssql_ctx.read() - self._sqlContext = sqlContext + def __init__(self, spark): + self._jreader = spark._ssql_ctx.read() + self._spark = spark def _df(self, jdf): from pyspark.sql.dataframe import DataFrame - return DataFrame(jdf, self._sqlContext) + return DataFrame(jdf, self._spark) @since(1.4) def format(self, source): @@ -67,7 +67,7 @@ class DataFrameReader(object): :param source: string, name of the data source, e.g. 'json', 'parquet'. - >>> df = sqlContext.read.format('json').load('python/test_support/sql/people.json') + >>> df = spark.read.format('json').load('python/test_support/sql/people.json') >>> df.dtypes [('age', 'bigint'), ('name', 'string')] @@ -87,7 +87,7 @@ class DataFrameReader(object): """ if not isinstance(schema, StructType): raise TypeError("schema should be StructType") - jschema = self._sqlContext._ssql_ctx.parseDataType(schema.json()) + jschema = self._spark._ssql_ctx.parseDataType(schema.json()) self._jreader = self._jreader.schema(jschema) return self @@ -115,12 +115,12 @@ class DataFrameReader(object): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options - >>> df = sqlContext.read.load('python/test_support/sql/parquet_partitioned', opt1=True, + >>> df = spark.read.load('python/test_support/sql/parquet_partitioned', opt1=True, ... opt2=1, opt3='str') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] - >>> df = sqlContext.read.format('json').load(['python/test_support/sql/people.json', + >>> df = spark.read.format('json').load(['python/test_support/sql/people.json', ... 'python/test_support/sql/people1.json']) >>> df.dtypes [('age', 'bigint'), ('aka', 'string'), ('name', 'string')] @@ -133,7 +133,7 @@ class DataFrameReader(object): if path is not None: if type(path) != list: path = [path] - return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load()) @@ -148,7 +148,7 @@ class DataFrameReader(object): :param schema: optional :class:`StructType` for the input schema. :param options: all other string options - >>> df = sqlContext.read.format('text').stream('python/test_support/sql/streaming') + >>> df = spark.read.format('text').stream('python/test_support/sql/streaming') >>> df.isStreaming True """ @@ -211,11 +211,11 @@ class DataFrameReader(object): ``spark.sql.columnNameOfCorruptRecord``. If None is set, it uses the default value ``_corrupt_record``. 
- >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') + >>> df1 = spark.read.json('python/test_support/sql/people.json') >>> df1.dtypes [('age', 'bigint'), ('name', 'string')] >>> rdd = sc.textFile('python/test_support/sql/people.json') - >>> df2 = sqlContext.read.json(rdd) + >>> df2 = spark.read.json(rdd) >>> df2.dtypes [('age', 'bigint'), ('name', 'string')] @@ -243,7 +243,7 @@ class DataFrameReader(object): if isinstance(path, basestring): path = [path] if type(path) == list: - return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): def func(iterator): for x in iterator: @@ -254,7 +254,7 @@ class DataFrameReader(object): yield x keyed = path.mapPartitions(func) keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString()) + jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) return self._df(self._jreader.json(jrdd)) else: raise TypeError("path can be only string or RDD") @@ -265,9 +265,9 @@ class DataFrameReader(object): :param tableName: string, name of the table. - >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.registerTempTable('tmpTable') - >>> sqlContext.read.table('tmpTable').dtypes + >>> spark.read.table('tmpTable').dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ return self._df(self._jreader.table(tableName)) @@ -276,11 +276,11 @@ class DataFrameReader(object): def parquet(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. - >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') + >>> df = spark.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ - return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths))) + return self._df(self._jreader.parquet(_to_seq(self._spark._sc, paths))) @ignore_unicode_prefix @since(1.6) @@ -291,13 +291,13 @@ class DataFrameReader(object): :param paths: string, or list of strings, for input path(s). - >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') + >>> df = spark.read.text('python/test_support/sql/text-test.txt') >>> df.collect() [Row(value=u'hello'), Row(value=u'this')] """ if isinstance(paths, basestring): path = [paths] - return self._df(self._jreader.text(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.text(self._spark._sc._jvm.PythonUtils.toSeq(path))) @since(2.0) def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None, @@ -356,7 +356,7 @@ class DataFrameReader(object): * ``DROPMALFORMED`` : ignores the whole corrupted records. * ``FAILFAST`` : throws an exception when it meets corrupted records. 
- >>> df = sqlContext.read.csv('python/test_support/sql/ages.csv') + >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes [('C0', 'string'), ('C1', 'string')] """ @@ -396,7 +396,7 @@ class DataFrameReader(object): self.option("mode", mode) if isinstance(path, basestring): path = [path] - return self._df(self._jreader.csv(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) + return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) @since(1.5) def orc(self, path): @@ -441,16 +441,16 @@ class DataFrameReader(object): """ if properties is None: properties = dict() - jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) if column is not None: if numPartitions is None: - numPartitions = self._sqlContext._sc.defaultParallelism + numPartitions = self._spark._sc.defaultParallelism return self._df(self._jreader.jdbc(url, table, column, int(lowerBound), int(upperBound), int(numPartitions), jprop)) if predicates is not None: - gateway = self._sqlContext._sc._gateway + gateway = self._spark._sc._gateway jpredicates = utils.toJArray(gateway, gateway.jvm.java.lang.String, predicates) return self._df(self._jreader.jdbc(url, table, jpredicates, jprop)) return self._df(self._jreader.jdbc(url, table, jprop)) @@ -466,7 +466,7 @@ class DataFrameWriter(object): """ def __init__(self, df): self._df = df - self._sqlContext = df.sql_ctx + self._spark = df.sql_ctx self._jwrite = df._jdf.write() def _cq(self, jcq): @@ -531,14 +531,14 @@ class DataFrameWriter(object): """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] - self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols)) return self @since(2.0) def queryName(self, queryName): """Specifies the name of the :class:`ContinuousQuery` that can be started with :func:`startStream`. This name must be unique among all the currently active queries - in the associated SQLContext + in the associated spark .. note:: Experimental. @@ -573,7 +573,7 @@ class DataFrameWriter(object): trigger = ProcessingTime(processingTime) if trigger is None: raise ValueError('A trigger was not provided. 
Supported triggers: processingTime.') - self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._sqlContext)) + self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark)) return self @since(1.4) @@ -854,7 +854,7 @@ class DataFrameWriter(object): """ if properties is None: properties = dict() - jprop = JavaClass("java.util.Properties", self._sqlContext._sc._gateway._gateway_client)() + jprop = JavaClass("java.util.Properties", self._spark._sc._gateway._gateway_client)() for k in properties: jprop.setProperty(k, properties[k]) self._jwrite.mode(mode).jdbc(url, table, jprop) @@ -865,7 +865,7 @@ def _test(): import os import tempfile from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext, HiveContext + from pyspark.sql import SparkSession, Row, HiveContext import pyspark.sql.readwriter os.chdir(os.environ["SPARK_HOME"]) @@ -876,11 +876,13 @@ def _test(): globs['tempfile'] = tempfile globs['os'] = os globs['sc'] = sc - globs['sqlContext'] = SQLContext(sc) + globs['spark'] = SparkSession.builder\ + .enableHiveSupport()\ + .getOrCreate() globs['hiveContext'] = HiveContext._createForTesting(sc) - globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') + globs['df'] = globs['spark'].read.parquet('python/test_support/sql/parquet_partitioned') globs['sdf'] = \ - globs['sqlContext'].read.format('text').stream('python/test_support/sql/streaming') + globs['spark'].read.format('text').stream('python/test_support/sql/streaming') (failure_count, test_count) = doctest.testmod( pyspark.sql.readwriter, globs=globs, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index cd5c4a7b3e9f84ce662a177cb816365de69eddf9..0c73f58c3b246b7ed1cfda7516aa5175991a7e85 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -45,7 +45,7 @@ if sys.version_info[:2] <= (2, 6): else: import unittest -from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row +from pyspark.sql import SparkSession, HiveContext, Column, Row from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase @@ -178,20 +178,6 @@ class DataTypeTests(unittest.TestCase): self.assertEqual(dt.fromInternal(0), datetime.date(1970, 1, 1)) -class SQLContextTests(ReusedPySparkTestCase): - def test_get_or_create(self): - sqlCtx = SQLContext.getOrCreate(self.sc) - self.assertTrue(SQLContext.getOrCreate(self.sc) is sqlCtx) - - def test_new_session(self): - sqlCtx = SQLContext.getOrCreate(self.sc) - sqlCtx.setConf("test_key", "a") - sqlCtx2 = sqlCtx.newSession() - sqlCtx2.setConf("test_key", "b") - self.assertEqual(sqlCtx.getConf("test_key", ""), "a") - self.assertEqual(sqlCtx2.getConf("test_key", ""), "b") - - class SQLTests(ReusedPySparkTestCase): @classmethod @@ -199,15 +185,14 @@ class SQLTests(ReusedPySparkTestCase): ReusedPySparkTestCase.setUpClass() cls.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(cls.tempdir.name) - cls.sparkSession = SparkSession(cls.sc) - cls.sqlCtx = cls.sparkSession._wrapped + cls.spark = SparkSession(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] - rdd = cls.sc.parallelize(cls.testData, 2) - cls.df = rdd.toDF() + cls.df = cls.spark.createDataFrame(cls.testData) @classmethod def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() shutil.rmtree(cls.tempdir.name, ignore_errors=True) def test_row_should_be_read_only(self): @@ -218,7 +203,7 @@ class 
SQLTests(ReusedPySparkTestCase): row.a = 3 self.assertRaises(Exception, foo) - row2 = self.sqlCtx.range(10).first() + row2 = self.spark.range(10).first() self.assertEqual(0, row2.id) def foo2(): @@ -226,14 +211,14 @@ class SQLTests(ReusedPySparkTestCase): self.assertRaises(Exception, foo2) def test_range(self): - self.assertEqual(self.sqlCtx.range(1, 1).count(), 0) - self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1) - self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2) - self.assertEqual(self.sqlCtx.range(-2).count(), 0) - self.assertEqual(self.sqlCtx.range(3).count(), 3) + self.assertEqual(self.spark.range(1, 1).count(), 0) + self.assertEqual(self.spark.range(1, 0, -1).count(), 1) + self.assertEqual(self.spark.range(0, 1 << 40, 1 << 39).count(), 2) + self.assertEqual(self.spark.range(-2).count(), 0) + self.assertEqual(self.spark.range(3).count(), 3) def test_duplicated_column_names(self): - df = self.sqlCtx.createDataFrame([(1, 2)], ["c", "c"]) + df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) row = df.select('*').first() self.assertEqual(1, row[0]) self.assertEqual(2, row[1]) @@ -247,7 +232,7 @@ class SQLTests(ReusedPySparkTestCase): from pyspark.sql.functions import explode d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})] rdd = self.sc.parallelize(d) - data = self.sqlCtx.createDataFrame(rdd) + data = self.spark.createDataFrame(rdd) result = data.select(explode(data.intlist).alias("a")).select("a").collect() self.assertEqual(result[0][0], 1) @@ -269,7 +254,7 @@ class SQLTests(ReusedPySparkTestCase): def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) - data = self.sqlCtx.createDataFrame(rdd) + data = self.spark.createDataFrame(rdd) class PlusFour: def __call__(self, col): @@ -284,7 +269,7 @@ class SQLTests(ReusedPySparkTestCase): def test_udf_with_partial_function(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) - data = self.sqlCtx.createDataFrame(rdd) + data = self.spark.createDataFrame(rdd) def some_func(col, param): if col is not None: @@ -296,56 +281,56 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) def test_udf(self): - self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) def test_udf2(self): - self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) - self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") - [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.spark.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) def test_chained_udf(self): - self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) - [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.spark.sql("SELECT double(1)").collect() self.assertEqual(row[0], 2) - [row] = self.sqlCtx.sql("SELECT 
double(double(1))").collect() + [row] = self.spark.sql("SELECT double(double(1))").collect() self.assertEqual(row[0], 4) - [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + [row] = self.spark.sql("SELECT double(double(1) + 1)").collect() self.assertEqual(row[0], 6) def test_multiple_udfs(self): - self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType()) - [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect() + self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.spark.sql("SELECT double(1), double(2)").collect() self.assertEqual(tuple(row), (2, 4)) - [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect() + [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect() self.assertEqual(tuple(row), (4, 12)) - self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType()) - [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() + self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() self.assertEqual(tuple(row), (6, 5)) def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) - self.sqlCtx.createDataFrame(rdd).registerTempTable("test") - self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) - self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) - [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.spark.createDataFrame(rdd).registerTempTable("test") + self.spark.catalog.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect() self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) - self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') - [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.spark.sql("SELECT MYUDF('c')").collect() self.assertEqual("abc", res[0]) - [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + [res] = self.spark.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) def test_udf_with_aggregate_function(self): - df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql.functions import udf, col from pyspark.sql.types import BooleanType @@ -355,7 +340,7 @@ class SQLTests(ReusedPySparkTestCase): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.read.json(rdd) + df = self.spark.read.json(rdd) df.count() df.collect() df.schema @@ -369,41 +354,41 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(2, df.count()) df.registerTempTable("temp") - df = self.sqlCtx.sql("select foo from temp") + df = self.spark.sql("select foo from temp") df.count() df.collect() def test_apply_schema_to_row(self): - df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""])) - df2 = self.sqlCtx.createDataFrame(df.rdd.map(lambda x: x), df.schema) + df = 
self.spark.read.json(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.spark.createDataFrame(df.rdd.map(lambda x: x), df.schema) self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + df3 = self.spark.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) def test_infer_schema_to_local(self): input = [{"a": 1}, {"b": "coffee"}] rdd = self.sc.parallelize(input) - df = self.sqlCtx.createDataFrame(input) - df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + df = self.spark.createDataFrame(input) + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) - df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + df3 = self.spark.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) def test_create_dataframe_schema_mismatch(self): input = [Row(a=1)] rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i)) schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())]) - df = self.sqlCtx.createDataFrame(rdd, schema) + df = self.spark.createDataFrame(rdd, schema) self.assertRaises(Exception, lambda: df.show()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - df = self.sqlCtx.createDataFrame(rdd) + df = self.spark.createDataFrame(rdd) row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) @@ -425,31 +410,31 @@ class SQLTests(ReusedPySparkTestCase): d = [Row(l=[], d={}, s=None), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - df = self.sqlCtx.createDataFrame(rdd) + df = self.spark.createDataFrame(rdd) self.assertEqual([], df.rdd.map(lambda r: r.l).first()) self.assertEqual([None, ""], df.rdd.map(lambda r: r.s).collect()) df.registerTempTable("test") - result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + result = self.spark.sql("SELECT l[0].a from test where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) - df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + df2 = self.spark.createDataFrame(rdd, samplingRatio=1.0) self.assertEqual(df.schema, df2.schema) self.assertEqual({}, df2.rdd.map(lambda r: r.d).first()) self.assertEqual([None, ""], df2.rdd.map(lambda r: r.s).collect()) df2.registerTempTable("test2") - result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]) - df = self.sqlCtx.createDataFrame(nestedRdd1) + df = self.spark.createDataFrame(nestedRdd1) self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.sqlCtx.createDataFrame(nestedRdd2) + df = self.spark.createDataFrame(nestedRdd2) self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) from collections import namedtuple @@ -457,17 +442,17 @@ class SQLTests(ReusedPySparkTestCase): rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) - df = 
self.sqlCtx.createDataFrame(rdd) + df = self.spark.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): data = [MyObject(1, "1"), MyObject(2, "2")] - df = self.sqlCtx.createDataFrame(data) + df = self.spark.createDataFrame(data) self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) self.assertEqual(df.first(), Row(key=1, value="1")) def test_select_null_literal(self): - df = self.sqlCtx.sql("select null as col") + df = self.spark.sql("select null as col") self.assertEqual(Row(col=None), df.first()) def test_apply_schema(self): @@ -488,7 +473,7 @@ class SQLTests(ReusedPySparkTestCase): StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), StructField("list1", ArrayType(ByteType(), False), False), StructField("null1", DoubleType(), True)]) - df = self.sqlCtx.createDataFrame(rdd, schema) + df = self.spark.createDataFrame(rdd, schema) results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), @@ -496,9 +481,9 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(r, results.first()) df.registerTempTable("table2") - r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + - "float1 + 1.5 as float1 FROM table2").first() + r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) @@ -508,7 +493,7 @@ class SQLTests(ReusedPySparkTestCase): abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" schema = _parse_schema_abstract(abstract) typedSchema = _infer_schema_type(rdd.first(), schema) - df = self.sqlCtx.createDataFrame(rdd, typedSchema) + df = self.spark.createDataFrame(rdd, typedSchema) r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) self.assertEqual(r, tuple(df.first())) @@ -524,7 +509,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(1, row.asDict()['l'][0].a) df = self.sc.parallelize([row]).toDF() df.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").head() + row = self.spark.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -535,7 +520,7 @@ class SQLTests(ReusedPySparkTestCase): def check_datatype(datatype): pickled = pickle.loads(pickle.dumps(datatype)) assert datatype == pickled - scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + scala_datatype = self.spark._wrapped._ssql_ctx.parseDataType(datatype.json()) python_datatype = _parse_datatype_json_string(scala_datatype.json()) assert datatype == python_datatype @@ -560,21 +545,21 @@ class SQLTests(ReusedPySparkTestCase): def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + point = 
self.spark.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), PythonOnlyUDT) df.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + point = self.spark.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_apply_schema_with_udt(self): @@ -582,21 +567,21 @@ class SQLTests(ReusedPySparkTestCase): row = (1.0, ExamplePoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = self.sqlCtx.createDataFrame([row], schema) + df = self.spark.createDataFrame([row], schema) point = df.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = (1.0, PythonOnlyPoint(1.0, 2.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", PythonOnlyUDT(), False)]) - df = self.sqlCtx.createDataFrame([row], schema) + df = self.spark.createDataFrame([row], schema) point = df.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) @@ -604,7 +589,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df = self.sqlCtx.createDataFrame([row]) + df = self.spark.createDataFrame([row]) self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) @@ -614,17 +599,17 @@ class SQLTests(ReusedPySparkTestCase): def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sqlCtx.createDataFrame([row]) + df0 = self.spark.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.write.parquet(output_dir) - df1 = self.sqlCtx.read.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) - df0 = self.sqlCtx.createDataFrame([row]) + df0 = self.spark.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') - df1 = self.sqlCtx.read.parquet(output_dir) + df1 = self.spark.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) @@ -634,8 +619,8 @@ class SQLTests(ReusedPySparkTestCase): row2 = (2.0, ExamplePoint(3.0, 4.0)) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df1 = self.sqlCtx.createDataFrame([row1], schema) - df2 = self.sqlCtx.createDataFrame([row2], schema) + df1 = self.spark.createDataFrame([row1], schema) + df2 = self.spark.createDataFrame([row2], schema) result = 
df1.union(df2).orderBy("label").collect() self.assertEqual( @@ -688,7 +673,7 @@ class SQLTests(ReusedPySparkTestCase): def test_first_last_ignorenulls(self): from pyspark.sql import functions - df = self.sqlCtx.range(0, 100) + df = self.spark.range(0, 100) df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) df3 = df2.select(functions.first(df2.id, False).alias('a'), functions.first(df2.id, True).alias('b'), @@ -829,36 +814,36 @@ class SQLTests(ReusedPySparkTestCase): schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) - self.sqlCtx.createDataFrame(rdd, schema) + self.spark.createDataFrame(rdd, schema) def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) df.write.json(tmpPath) - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.read.json(tmpPath, schema) + actual = self.spark.read.json(tmpPath, schema) self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) df.write.json(tmpPath, "overwrite") - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) df.write.save(format="json", mode="overwrite", path=tmpPath, noUse="this options will not be used in save.") - actual = self.sqlCtx.read.load(format="json", path=tmpPath, - noUse="this options will not be used in load.") + actual = self.spark.read.load(format="json", path=tmpPath, + noUse="this options will not be used in load.") self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", "org.apache.spark.sql.parquet") - self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.read.load(path=tmpPath) + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.spark.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) csvpath = os.path.join(tempfile.mkdtemp(), 'data') df.write.option('quote', None).format('csv').save(csvpath) @@ -870,36 +855,36 @@ class SQLTests(ReusedPySparkTestCase): tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) df.write.json(tmpPath) - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) schema = StructType([StructField("value", StringType(), True)]) - actual = self.sqlCtx.read.json(tmpPath, schema) + actual = self.spark.read.json(tmpPath, schema) self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) df.write.mode("overwrite").json(tmpPath) - actual = self.sqlCtx.read.json(tmpPath) + actual = self.spark.read.json(tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ .option("noUse", "this option will not be used in save.")\ .format("json").save(path=tmpPath) actual =\ - self.sqlCtx.read.format("json")\ - 
.load(path=tmpPath, noUse="this options will not be used in load.") + self.spark.read.format("json")\ + .load(path=tmpPath, noUse="this options will not be used in load.") self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", + defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default", "org.apache.spark.sql.parquet") - self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.read.load(path=tmpPath) + self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") + actual = self.spark.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) - self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) + self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName) shutil.rmtree(tmpPath) def test_stream_trigger_takes_keyword_args(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') + df = self.spark.read.format('text').stream('python/test_support/sql/streaming') try: df.write.trigger('5 seconds') self.fail("Should have thrown an exception") @@ -909,7 +894,7 @@ class SQLTests(ReusedPySparkTestCase): def test_stream_read_options(self): schema = StructType([StructField("data", StringType(), False)]) - df = self.sqlCtx.read.format('text').option('path', 'python/test_support/sql/streaming')\ + df = self.spark.read.format('text').option('path', 'python/test_support/sql/streaming')\ .schema(schema).stream() self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct<data:string>") @@ -917,15 +902,15 @@ class SQLTests(ReusedPySparkTestCase): def test_stream_read_options_overwrite(self): bad_schema = StructType([StructField("test", IntegerType(), False)]) schema = StructType([StructField("data", StringType(), False)]) - df = self.sqlCtx.read.format('csv').option('path', 'python/test_support/sql/fake') \ + df = self.spark.read.format('csv').option('path', 'python/test_support/sql/fake') \ .schema(bad_schema).stream(path='python/test_support/sql/streaming', schema=schema, format='text') self.assertTrue(df.isStreaming) self.assertEqual(df.schema.simpleString(), "struct<data:string>") def test_stream_save_options(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: + df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -948,8 +933,8 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_stream_save_options_overwrite(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: + df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -977,8 +962,8 @@ class SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_stream_await_termination(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: + df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -1005,8 +990,8 @@ class 
SQLTests(ReusedPySparkTestCase): shutil.rmtree(tmpPath) def test_query_manager_await_termination(self): - df = self.sqlCtx.read.format('text').stream('python/test_support/sql/streaming') - for cq in self.sqlCtx.streams.active: + df = self.spark.read.format('text').stream('python/test_support/sql/streaming') + for cq in self.spark._wrapped.streams.active: cq.stop() tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) @@ -1018,13 +1003,13 @@ class SQLTests(ReusedPySparkTestCase): try: self.assertTrue(cq.isActive) try: - self.sqlCtx.streams.awaitAnyTermination("hello") + self.spark._wrapped.streams.awaitAnyTermination("hello") self.fail("Expected a value exception") except ValueError: pass now = time.time() # test should take at least 2 seconds - res = self.sqlCtx.streams.awaitAnyTermination(2.6) + res = self.spark._wrapped.streams.awaitAnyTermination(2.6) duration = time.time() - now self.assertTrue(duration >= 2) self.assertFalse(res) @@ -1035,7 +1020,7 @@ class SQLTests(ReusedPySparkTestCase): def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.read.json(rdd) + df = self.spark.read.json(rdd) # render_doc() reproduces the help() exception without printing output pydoc.render_doc(df) pydoc.render_doc(df.foo) @@ -1051,7 +1036,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertRaises(TypeError, lambda: df[{}]) def test_column_name_with_non_ascii(self): - df = self.sqlCtx.createDataFrame([(1,)], ["æ•°é‡"]) + df = self.spark.createDataFrame([(1,)], ["æ•°é‡"]) self.assertEqual(StructType([StructField("æ•°é‡", LongType(), True)]), df.schema) self.assertEqual("DataFrame[æ•°é‡: bigint]", str(df)) self.assertEqual([("æ•°é‡", 'bigint')], df.dtypes) @@ -1084,7 +1069,7 @@ class SQLTests(ReusedPySparkTestCase): # this saving as Parquet caused issues as well. 
         output_dir = os.path.join(self.tempdir.name, "infer_long_type")
         df.write.parquet(output_dir)
-        df1 = self.sqlCtx.read.parquet(output_dir)
+        df1 = self.spark.read.parquet(output_dir)
         self.assertEqual('a', df1.first().f1)
         self.assertEqual(100000000000000, df1.first().f2)
 
@@ -1100,7 +1085,7 @@ class SQLTests(ReusedPySparkTestCase):
         time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
         date = time.date()
         row = Row(date=date, time=time)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         self.assertEqual(1, df.filter(df.date == date).count())
         self.assertEqual(1, df.filter(df.time == time).count())
         self.assertEqual(0, df.filter(df.date > date).count())
@@ -1110,7 +1095,7 @@ class SQLTests(ReusedPySparkTestCase):
         dt1 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(0))
         dt2 = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000, tzinfo=UTCOffsetTimezone(1))
         row = Row(date=dt1)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
        self.assertEqual(0, df.filter(df.date == dt2).count())
        self.assertEqual(1, df.filter(df.date > dt2).count())
        self.assertEqual(0, df.filter(df.date < dt2).count())
@@ -1125,7 +1110,7 @@ class SQLTests(ReusedPySparkTestCase):
         utcnow = datetime.datetime.utcfromtimestamp(ts)  # without microseconds
         # add microseconds to utcnow (keeping year,month,day,hour,minute,second)
         utcnow = datetime.datetime(*(utcnow.timetuple()[:6] + (now.microsecond, utc)))
-        df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
+        df = self.spark.createDataFrame([(day, now, utcnow)])
         day1, now1, utcnow1 = df.first()
         self.assertEqual(day1, day)
         self.assertEqual(now, now1)
@@ -1134,13 +1119,13 @@ class SQLTests(ReusedPySparkTestCase):
     def test_decimal(self):
         from decimal import Decimal
         schema = StructType([StructField("decimal", DecimalType(10, 5))])
-        df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+        df = self.spark.createDataFrame([(Decimal("3.14159"),)], schema)
         row = df.select(df.decimal + 1).first()
         self.assertEqual(row[0], Decimal("4.14159"))
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.parquet(tmpPath)
-        df2 = self.sqlCtx.read.parquet(tmpPath)
+        df2 = self.spark.read.parquet(tmpPath)
         row = df2.first()
         self.assertEqual(row[0], Decimal("3.14159"))
 
@@ -1151,52 +1136,52 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("height", DoubleType(), True)])
 
         # shouldn't drop a non-null row
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, 80.1)], schema).dropna().count(),
             1)
 
         # dropping rows with a single null value
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna().count(),
             0)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna(how='any').count(),
             0)
 
         # if how = 'all', only drop rows if all values are null
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna(how='all').count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(None, None, None)], schema).dropna(how='all').count(),
             0)
 
         # how and subset
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, None)], schema).dropna(how='any', subset=['name', 'age']).count(),
             0)
 
         # threshold
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 80.1)], schema).dropna(thresh=2).count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, None)], schema).dropna(thresh=2).count(),
             0)
 
         # threshold and subset
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
             1)
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', None, 180.9)], schema).dropna(thresh=2, subset=['name', 'age']).count(),
             0)
 
         # thresh should take precedence over how
-        self.assertEqual(self.sqlCtx.createDataFrame(
+        self.assertEqual(self.spark.createDataFrame(
             [(u'Alice', 50, None)], schema).dropna(
                 how='any', thresh=2, subset=['name', 'age']).count(),
             1)
@@ -1208,33 +1193,33 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("height", DoubleType(), True)])
 
         # fillna shouldn't change non-null values
-        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 80.1)], schema).fillna(50).first()
         self.assertEqual(row.age, 10)
 
         # fillna with int
-        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.0)
 
         # fillna with double
-        row = self.sqlCtx.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
+        row = self.spark.createDataFrame([(u'Alice', None, None)], schema).fillna(50.1).first()
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, 50.1)
 
         # fillna with string
-        row = self.sqlCtx.createDataFrame([(None, None, None)], schema).fillna("hello").first()
+        row = self.spark.createDataFrame([(None, None, None)], schema).fillna("hello").first()
         self.assertEqual(row.name, u"hello")
         self.assertEqual(row.age, None)
 
         # fillna with subset specified for numeric cols
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(None, None, None)], schema).fillna(50, subset=['name', 'age']).first()
         self.assertEqual(row.name, None)
         self.assertEqual(row.age, 50)
         self.assertEqual(row.height, None)
 
         # fillna with subset specified for numeric cols
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(None, None, None)], schema).fillna("haha", subset=['name', 'age']).first()
         self.assertEqual(row.name, "haha")
         self.assertEqual(row.age, None)
@@ -1243,7 +1228,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_bitwise_operations(self):
         from pyspark.sql import functions
         row = Row(a=170, b=75)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         result = df.select(df.a.bitwiseAND(df.b)).collect()[0].asDict()
         self.assertEqual(170 & 75, result['(a & b)'])
         result = df.select(df.a.bitwiseOR(df.b)).collect()[0].asDict()
@@ -1256,7 +1241,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_expr(self):
         from pyspark.sql import functions
         row = Row(a="length string", b=75)
-        df = self.sqlCtx.createDataFrame([row])
+        df = self.spark.createDataFrame([row])
         result = df.select(functions.expr("length(a)")).collect()[0].asDict()
         self.assertEqual(13, result["length(a)"])
 
@@ -1267,58 +1252,58 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("height", DoubleType(), True)])
 
         # replace with int
-        row = self.sqlCtx.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
+        row = self.spark.createDataFrame([(u'Alice', 10, 10.0)], schema).replace(10, 20).first()
         self.assertEqual(row.age, 20)
         self.assertEqual(row.height, 20.0)
 
         # replace with double
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 80, 80.0)], schema).replace(80.0, 82.1).first()
         self.assertEqual(row.age, 82)
         self.assertEqual(row.height, 82.1)
 
         # replace with string
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(u'Alice', u'Ann').first()
         self.assertEqual(row.name, u"Ann")
         self.assertEqual(row.age, 10)
 
         # replace with subset specified by a string of a column name w/ actual change
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='age').first()
         self.assertEqual(row.age, 20)
 
         # replace with subset specified by a string of a column name w/o actual change
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 80.1)], schema).replace(10, 20, subset='height').first()
         self.assertEqual(row.age, 10)
 
         # replace with subset specified with one column replaced, another column not in subset
         # stays unchanged.
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, 10.0)], schema).replace(10, 20, subset=['name', 'age']).first()
         self.assertEqual(row.name, u'Alice')
         self.assertEqual(row.age, 20)
         self.assertEqual(row.height, 10.0)
 
         # replace with subset specified but no column will be replaced
-        row = self.sqlCtx.createDataFrame(
+        row = self.spark.createDataFrame(
             [(u'Alice', 10, None)], schema).replace(10, 20, subset=['name', 'height']).first()
         self.assertEqual(row.name, u'Alice')
         self.assertEqual(row.age, 10)
         self.assertEqual(row.height, None)
 
     def test_capture_analysis_exception(self):
-        self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
+        self.assertRaises(AnalysisException, lambda: self.spark.sql("select abc"))
         self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
 
     def test_capture_parse_exception(self):
-        self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
+        self.assertRaises(ParseException, lambda: self.spark.sql("abc"))
 
     def test_capture_illegalargument_exception(self):
         self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
-                                lambda: self.sqlCtx.sql("SET mapred.reduce.tasks=-1"))
-        df = self.sqlCtx.createDataFrame([(1, 2)], ["a", "b"])
+                                lambda: self.spark.sql("SET mapred.reduce.tasks=-1"))
+        df = self.spark.createDataFrame([(1, 2)], ["a", "b"])
         self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
                                 lambda: df.select(sha2(df.a, 1024)).collect())
         try:
@@ -1345,8 +1330,8 @@ class SQLTests(ReusedPySparkTestCase):
     def test_functions_broadcast(self):
         from pyspark.sql.functions import broadcast
 
-        df1 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
-        df2 = self.sqlCtx.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+        df1 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
+        df2 = self.spark.createDataFrame([(1, "1"), (2, "2")], ("key", "value"))
 
         # equijoin - should be converted into broadcast join
         plan1 = df1.join(broadcast(df2), "key")._jdf.queryExecution().executedPlan()
@@ -1396,9 +1381,9 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(df.collect(), [Row(key=i) for i in range(100)])
 
     def test_conf(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.conf.set("bogo", "sipeo")
-        self.assertEqual(self.sparkSession.conf.get("bogo"), "sipeo")
+        self.assertEqual(spark.conf.get("bogo"), "sipeo")
         spark.conf.set("bogo", "ta")
         self.assertEqual(spark.conf.get("bogo"), "ta")
         self.assertEqual(spark.conf.get("bogo", "not.read"), "ta")
@@ -1408,7 +1393,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(spark.conf.get("bogo", "colombia"), "colombia")
 
     def test_current_database(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         self.assertEquals(spark.catalog.currentDatabase(), "default")
         spark.sql("CREATE DATABASE some_db")
@@ -1420,7 +1405,7 @@ class SQLTests(ReusedPySparkTestCase):
             lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
 
     def test_list_databases(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         databases = [db.name for db in spark.catalog.listDatabases()]
         self.assertEquals(databases, ["default"])
@@ -1430,7 +1415,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_tables(self):
         from pyspark.sql.catalog import Table
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         self.assertEquals(spark.catalog.listTables(), [])
@@ -1475,7 +1460,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_functions(self):
         from pyspark.sql.catalog import Function
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         functions = dict((f.name, f) for f in spark.catalog.listFunctions())
@@ -1512,7 +1497,7 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_list_columns(self):
         from pyspark.sql.catalog import Column
-        spark = self.sparkSession
+        spark = self.spark
         spark.catalog._reset()
         spark.sql("CREATE DATABASE some_db")
         spark.sql("CREATE TABLE tab1 (name STRING, age INT)")
@@ -1561,7 +1546,7 @@ class SQLTests(ReusedPySparkTestCase):
             lambda: spark.catalog.listColumns("does_not_exist"))
 
     def test_cache(self):
-        spark = self.sparkSession
+        spark = self.spark
         spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab1")
         spark.createDataFrame([(2, 2), (3, 3)]).registerTempTable("tab2")
         self.assertFalse(spark.catalog.isCached("tab1"))
@@ -1605,7 +1590,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             cls.tearDownClass()
             raise unittest.SkipTest("Hive is not available")
         os.unlink(cls.tempdir.name)
-        cls.sqlCtx = HiveContext._createForTesting(cls.sc)
+        cls.spark = HiveContext._createForTesting(cls.sc)
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         cls.df = cls.sc.parallelize(cls.testData).toDF()
 
@@ -1619,45 +1604,45 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
         df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath, "json")
+        actual = self.spark.createExternalTable("externalJsonTable", tmpPath, "json")
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
 
         df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
         schema = StructType([StructField("value", StringType(), True)])
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", source="json",
-                                                 schema=schema, path=tmpPath,
-                                                 noUse="this options will not be used")
+        actual = self.spark.createExternalTable("externalJsonTable", source="json",
+                                                schema=schema, path=tmpPath,
+                                                noUse="this options will not be used")
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.select("value").collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE savedJsonTable")
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
 
-        defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
-                                                    "org.apache.spark.sql.parquet")
-        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        defaultDataSourceName = self.spark.getConf("spark.sql.sources.default",
+                                                   "org.apache.spark.sql.parquet")
+        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
         df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
-        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+        actual = self.spark.createExternalTable("externalJsonTable", path=tmpPath)
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
         self.assertEqual(sorted(df.collect()),
-                         sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.sqlCtx.sql("DROP TABLE savedJsonTable")
-        self.sqlCtx.sql("DROP TABLE externalJsonTable")
-        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        self.spark.sql("DROP TABLE savedJsonTable")
+        self.spark.sql("DROP TABLE externalJsonTable")
+        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
 
         shutil.rmtree(tmpPath)
 
     def test_window_functions(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.partitionBy("value").orderBy("key")
         from pyspark.sql import functions as F
         sel = df.select(df.value, df.key,
@@ -1679,7 +1664,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             self.assertEqual(tuple(r), ex[:len(r)])
 
     def test_window_functions_without_partitionBy(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         w = Window.orderBy("key", df.value)
         from pyspark.sql import functions as F
         sel = df.select(df.value, df.key,
@@ -1701,7 +1686,7 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
             self.assertEqual(tuple(r), ex[:len(r)])
 
     def test_collect_functions(self):
-        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql import functions
         self.assertEqual(