From 9c06c723018d4ef96ff31eb947226a6273ed8080 Mon Sep 17 00:00:00 2001 From: Davies Liu <davies.liu@gmail.com> Date: Fri, 12 Sep 2014 19:05:39 -0700 Subject: [PATCH] [SPARK-3500] [SQL] use JavaSchemaRDD as SchemaRDD._jschema_rdd Currently, SchemaRDD._jschema_rdd is SchemaRDD, the Scala API (coalesce(), repartition()) can not been called in Python easily, there is no way to specify the implicit parameter `ord`. The _jrdd is an JavaRDD, so _jschema_rdd should also be JavaSchemaRDD. In this patch, change _schema_rdd to JavaSchemaRDD, also added an assert for it. If some methods are missing from JavaSchemaRDD, then it's called by _schema_rdd.baseSchemaRDD().xxx(). BTW, Do we need JavaSQLContext? Author: Davies Liu <davies.liu@gmail.com> Closes #2369 from davies/fix_schemardd and squashes the following commits: abee159 [Davies Liu] use JavaSchemaRDD as SchemaRDD._jschema_rdd (cherry picked from commit 885d1621bc06bc1f009c9707c3452eac26baf828) Signed-off-by: Josh Rosen <joshrosen@apache.org> Conflicts: python/pyspark/tests.py --- python/pyspark/sql.py | 38 ++++++++++++++++++-------------------- python/pyspark/tests.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0ff6a548a8..07b39c92b8 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1121,7 +1121,7 @@ class SQLContext: batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) jrdd = self._pythonToJava(rdd._jrdd, batched) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) - return SchemaRDD(srdd, self) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1133,8 +1133,8 @@ class SQLContext: >>> sqlCtx.registerRDDAsTable(srdd, "table1") """ if (rdd.__class__ is SchemaRDD): - jschema_rdd = rdd._jschema_rdd - self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName) + srdd = rdd._jschema_rdd.baseSchemaRDD() + self._ssql_ctx.registerRDDAsTable(srdd, tableName) else: raise ValueError("Can only register SchemaRDD as table") @@ -1150,7 +1150,7 @@ class SQLContext: >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) + jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) def jsonFile(self, path, schema=None): @@ -1206,11 +1206,11 @@ class SQLContext: [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - jschema_rdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path) else: scala_datatype = self._ssql_ctx.parseDataType(str(schema)) - jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(jschema_rdd, self) + srdd = self._ssql_ctx.jsonFile(path, scala_datatype) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def jsonRDD(self, rdd, schema=None): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. @@ -1274,11 +1274,11 @@ class SQLContext: keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: scala_datatype = self._ssql_ctx.parseDataType(str(schema)) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(jschema_rdd, self) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return SchemaRDD(srdd.toJavaSchemaRDD(), self) def sql(self, sqlQuery): """Return a L{SchemaRDD} representing the result of the given query. @@ -1289,7 +1289,7 @@ class SQLContext: >>> srdd2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self) def table(self, tableName): """Returns the specified table as a L{SchemaRDD}. @@ -1300,7 +1300,7 @@ class SQLContext: >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1352,7 +1352,7 @@ class HiveContext(SQLContext): warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" + "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'", DeprecationWarning) - return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self) + return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self) def hql(self, hqlQuery): """ @@ -1508,6 +1508,8 @@ class SchemaRDD(RDD): def __init__(self, jschema_rdd, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx._sc + clsName = jschema_rdd.getClass().getName() + assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD" self._jschema_rdd = jschema_rdd self.is_cached = False @@ -1524,7 +1526,7 @@ class SchemaRDD(RDD): L{pyspark.rdd.RDD} super class (map, filter, etc.). """ if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.javaToPython() + self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() return self._lazy_jrdd @property @@ -1580,7 +1582,7 @@ class SchemaRDD(RDD): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.schema().toString()) + return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) def schemaString(self): """Returns the output schema in the tree format.""" @@ -1631,8 +1633,6 @@ class SchemaRDD(RDD): rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) schema = self.schema() - import pickle - pickle.loads(pickle.dumps(schema)) def applySchema(_, it): cls = _create_cls(schema) @@ -1669,10 +1669,8 @@ class SchemaRDD(RDD): def getCheckpointFile(self): checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): + if checkpointFile.isPresent(): return checkpointFile.get() - else: - return None def coalesce(self, numPartitions, shuffle=False): rdd = self._jschema_rdd.coalesce(numPartitions, shuffle) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1db922f513..8f0a351b6b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -41,6 +41,8 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger +from pyspark.storagelevel import StorageLevel +from pyspark.sql import SQLContext _have_scipy = False _have_numpy = False @@ -469,6 +471,41 @@ class TestRDDFunctions(PySparkTestCase): self.assertRaises(TypeError, lambda: rdd.histogram(2)) +class TestSQL(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.sqlCtx = SQLContext(self.sc) + + def test_basic_functions(self): + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + srdd = self.sqlCtx.jsonRDD(rdd) + srdd.count() + srdd.collect() + srdd.schemaString() + srdd.schema() + + # cache and checkpoint + self.assertFalse(srdd.is_cached) + srdd.persist(StorageLevel.MEMORY_ONLY_SER) + srdd.unpersist() + srdd.cache() + self.assertTrue(srdd.is_cached) + self.assertFalse(srdd.isCheckpointed()) + self.assertEqual(None, srdd.getCheckpointFile()) + + srdd = srdd.coalesce(2, True) + srdd = srdd.repartition(3) + srdd = srdd.distinct() + srdd.intersection(srdd) + self.assertEqual(2, srdd.count()) + + srdd.registerTempTable("temp") + srdd = self.sqlCtx.sql("select foo from temp") + srdd.count() + srdd.collect() + + class TestIO(PySparkTestCase): def test_stdout_redirection(self): -- GitLab