From ebc124d4c44d4c84f7868f390f778c0ff5cd66cb Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Tue, 11 Jul 2017 22:03:10 +0800
Subject: [PATCH] [SPARK-21365][PYTHON] Deduplicate logics parsing DDL type/schema definition

## What changes were proposed in this pull request?

This PR addresses the following four points:

- Reuse the existing DDL parser APIs rather than reimplementing the parsing within PySpark.
- Support DDL-formatted schema strings, e.g. `field type, field type`.
- Support case-insensitive parsing.
- Support nested data types, as shown below:

**Before**

```
>>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
...
ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
```

```
>>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
...
ValueError: The strcut field string format is: 'field_name:field_type', but got: a: struct<b: int>
```

```
>>> spark.createDataFrame([[1]], "a int").show()
...
ValueError: Could not parse datatype: a int
```

**After**

```
>>> spark.createDataFrame([[[1]]], "struct<a: struct<b: int>>").show()
+---+
|  a|
+---+
|[1]|
+---+
```

```
>>> spark.createDataFrame([[[1]]], "a: struct<b: int>").show()
+---+
|  a|
+---+
|[1]|
+---+
```

```
>>> spark.createDataFrame([[1]], "a int").show()
+---+
|  a|
+---+
|  1|
+---+
```

## How was this patch tested?

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18590 from HyukjinKwon/deduplicate-python-ddl.
---
 python/pyspark/sql/functions.py               | 16 +++-
 python/pyspark/sql/tests.py                   | 25 ++++++
 python/pyspark/sql/types.py                   | 88 +++++++------------
 .../spark/sql/api/python/PythonSQLUtils.scala | 25 ++++++
 4 files changed, 97 insertions(+), 57 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f203d85dd9..d45ff63355 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2037,15 +2037,25 @@ class UserDefinedFunction(object):
                 "{0}".format(type(func)))
         self.func = func
-        self.returnType = (
-            returnType if isinstance(returnType, DataType)
-            else _parse_datatype_string(returnType))
+        self._returnType = returnType
         # Stores UserDefinedPythonFunctions jobj, once initialized
+        self._returnType_placeholder = None
         self._judf_placeholder = None
         self._name = name or (
             func.__name__
             if hasattr(func, '__name__')
             else func.__class__.__name__)
 
+    @property
+    def returnType(self):
+        # This makes sure this is called after SparkContext is initialized.
+        # ``_parse_datatype_string`` accesses to JVM for parsing a DDL formatted string.
+        if self._returnType_placeholder is None:
+            if isinstance(self._returnType, DataType):
+                self._returnType_placeholder = self._returnType
+            else:
+                self._returnType_placeholder = _parse_datatype_string(self._returnType)
+        return self._returnType_placeholder
+
     @property
     def _judf(self):
         # It is possible that concurrent access, to newly created UDF,
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index bd8477e35f..29e48a6ccf 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1255,6 +1255,31 @@ class SQLTests(ReusedPySparkTestCase):
         with self.assertRaises(TypeError):
             not_a_field = struct1[9.9]
 
+    def test_parse_datatype_string(self):
+        from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
+        for k, t in _all_atomic_types.items():
+            if t != NullType:
+                self.assertEqual(t(), _parse_datatype_string(k))
+        self.assertEqual(IntegerType(), _parse_datatype_string("int"))
+        self.assertEqual(DecimalType(1, 1), _parse_datatype_string("decimal(1 ,1)"))
+        self.assertEqual(DecimalType(10, 1), _parse_datatype_string("decimal( 10,1 )"))
+        self.assertEqual(DecimalType(11, 1), _parse_datatype_string("decimal(11,1)"))
+        self.assertEqual(
+            ArrayType(IntegerType()),
+            _parse_datatype_string("array<int >"))
+        self.assertEqual(
+            MapType(IntegerType(), DoubleType()),
+            _parse_datatype_string("map< int, double >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("struct<a:int, c:double >"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("a:int, c:double"))
+        self.assertEqual(
+            StructType([StructField("a", IntegerType()), StructField("c", DoubleType())]),
+            _parse_datatype_string("a INT, c DOUBLE"))
+
     def test_metadata_null(self):
         from pyspark.sql.types import StructType, StringType, StructField
         schema = StructType([StructField("f1", StringType(), True, None),
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f5505ed472..22fa273fc1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -32,6 +32,7 @@ if sys.version >= "3":
 from py4j.protocol import register_input_converter
 from py4j.java_gateway import JavaClass
 
+from pyspark import SparkContext
 from pyspark.serializers import CloudPickleSerializer
 
 __all__ = [
@@ -727,18 +728,6 @@ _FIXED_DECIMAL = re.compile("decimal\\(\\s*(\\d+)\\s*,\\s*(\\d+)\\s*\\)")
 _BRACKETS = {'(': ')', '[': ']', '{': '}'}
 
 
-def _parse_basic_datatype_string(s):
-    if s in _all_atomic_types.keys():
-        return _all_atomic_types[s]()
-    elif s == "int":
-        return IntegerType()
-    elif _FIXED_DECIMAL.match(s):
-        m = _FIXED_DECIMAL.match(s)
-        return DecimalType(int(m.group(1)), int(m.group(2)))
-    else:
-        raise ValueError("Could not parse datatype: %s" % s)
-
-
 def _ignore_brackets_split(s, separator):
     """
     Splits the given string by given separator, but ignore separators inside brackets pairs, e.g.
@@ -771,32 +760,23 @@ def _ignore_brackets_split(s, separator):
     return parts
 
 
-def _parse_struct_fields_string(s):
-    parts = _ignore_brackets_split(s, ",")
-    fields = []
-    for part in parts:
-        name_and_type = _ignore_brackets_split(part, ":")
-        if len(name_and_type) != 2:
-            raise ValueError("The strcut field string format is: 'field_name:field_type', " +
-                             "but got: %s" % part)
-        field_name = name_and_type[0].strip()
-        field_type = _parse_datatype_string(name_and_type[1])
-        fields.append(StructField(field_name, field_type))
-    return StructType(fields)
-
-
 def _parse_datatype_string(s):
     """
     Parses the given data type string to a :class:`DataType`. The data type string format equals
     to :class:`DataType.simpleString`, except that top level struct type can omit
     the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use ``byte`` instead
     of ``tinyint`` for :class:`ByteType`. We can also use ``int`` as a short name
-    for :class:`IntegerType`.
+    for :class:`IntegerType`. Since Spark 2.3, this also supports a schema in a DDL-formatted
+    string and case-insensitive strings.
 
     >>> _parse_datatype_string("int ")
     IntegerType
+    >>> _parse_datatype_string("INT ")
+    IntegerType
     >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ")
     StructType(List(StructField(a,ByteType,true),StructField(b,DecimalType(16,8),true)))
+    >>> _parse_datatype_string("a DOUBLE, b STRING")
+    StructType(List(StructField(a,DoubleType,true),StructField(b,StringType,true)))
     >>> _parse_datatype_string("a: array< short>")
     StructType(List(StructField(a,ArrayType(ShortType,true),true)))
     >>> _parse_datatype_string(" map<string , string > ")
@@ -806,43 +786,43 @@ def _parse_datatype_string(s):
     >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    ValueError:...
+    ParseException:...
     """
-    s = s.strip()
-    if s.startswith("array<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        return ArrayType(_parse_datatype_string(s[6:-1]))
-    elif s.startswith("map<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        parts = _ignore_brackets_split(s[4:-1], ",")
-        if len(parts) != 2:
-            raise ValueError("The map type string format is: 'map<key_type,value_type>', " +
-                             "but got: %s" % s)
-        kt = _parse_datatype_string(parts[0])
-        vt = _parse_datatype_string(parts[1])
-        return MapType(kt, vt)
-    elif s.startswith("struct<"):
-        if s[-1] != ">":
-            raise ValueError("'>' should be the last char, but got: %s" % s)
-        return _parse_struct_fields_string(s[7:-1])
-    elif ":" in s:
-        return _parse_struct_fields_string(s)
-    else:
-        return _parse_basic_datatype_string(s)
+    sc = SparkContext._active_spark_context
+
+    def from_ddl_schema(type_str):
+        return _parse_datatype_json_string(
+            sc._jvm.org.apache.spark.sql.types.StructType.fromDDL(type_str).json())
+
+    def from_ddl_datatype(type_str):
+        return _parse_datatype_json_string(
+            sc._jvm.org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str).json())
+
+    try:
+        # DDL format, "fieldname datatype, fieldname datatype".
+        return from_ddl_schema(s)
+    except Exception as e:
+        try:
+            # For backwards compatibility, "integer", "struct<fieldname: datatype>" and etc.
+            return from_ddl_datatype(s)
+        except:
+            try:
+                # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case.
+                return from_ddl_datatype("struct<%s>" % s.strip())
+            except:
+                raise e
 
 
 def _parse_datatype_json_string(json_string):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
new file mode 100644
index 0000000000..731feb914d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.python
+
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.types.DataType
+
+private[sql] object PythonSQLUtils {
+  def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)
+}
-- 
GitLab