From afae9766f28d2e58297405c39862d20a04267b62 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Mon, 29 Jun 2015 13:20:55 -0700
Subject: [PATCH] [SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame

Avoid unnecessary Spark jobs when inferring the schema from a list.

cc yhuai mengxr

Author: Davies Liu <davies@databricks.com>

Closes #6606 from davies/improve_create and squashes the following commits:

a5928bf [Davies Liu] Update MimaExcludes.scala
62da911 [Davies Liu] fix mima
bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create
eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
8d9292d [Davies Liu] Update context.py
eb24531 [Davies Liu] Update context.py
c969997 [Davies Liu] bug fix
d5a8ab0 [Davies Liu] fix tests
8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create
6ea5925 [Davies Liu] address comments
6ceaeff [Davies Liu] avoid spark jobs in createDataFrame
---
 python/pyspark/sql/context.py | 64 +++++++++++++++++++++++++----------
 python/pyspark/sql/types.py   | 48 +++++++++++++++-----------
 2 files changed, 75 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index dc239226e6..4dda3b430c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -203,7 +203,37 @@ class SQLContext(object):
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
+    def _inferSchemaFromList(self, data):
+        """
+        Infer schema from list of Row or tuple.
+
+        :param data: list of Row or tuple
+        :return: StructType
+        """
+        if not data:
+            raise ValueError("can not infer schema from empty dataset")
+        first = data[0]
+        if type(first) is dict:
+            warnings.warn("inferring schema from dict is deprecated,"
+                          "please use pyspark.sql.Row instead")
+        schema = _infer_schema(first)
+        if _has_nulltype(schema):
+            for r in data:
+                schema = _merge_type(schema, _infer_schema(r))
+                if not _has_nulltype(schema):
+                    break
+            else:
+                raise ValueError("Some of types cannot be determined after inferring")
+        return schema
+
     def _inferSchema(self, rdd, samplingRatio=None):
+        """
+        Infer schema from an RDD of Row or tuple.
+
+        :param rdd: an RDD of Row or tuple
+        :param samplingRatio: sampling ratio, or no sampling (default)
+        :return: StructType
+        """
         first = rdd.first()
         if not first:
             raise ValueError("The first row in RDD is empty, "
@@ -322,6 +352,8 @@ class SQLContext(object):
             data = [r.tolist() for r in data.to_records(index=False)]
 
         if not isinstance(data, RDD):
+            if not isinstance(data, list):
+                data = list(data)
             try:
                 # data could be list, tuple, generator ...
                 rdd = self._sc.parallelize(data)
@@ -330,28 +362,26 @@ class SQLContext(object):
         else:
             rdd = data
 
-        if schema is None:
-            schema = self._inferSchema(rdd, samplingRatio)
+        if schema is None or isinstance(schema, (list, tuple)):
+            if isinstance(data, RDD):
+                struct = self._inferSchema(rdd, samplingRatio)
+            else:
+                struct = self._inferSchemaFromList(data)
+            if isinstance(schema, (list, tuple)):
+                for i, name in enumerate(schema):
+                    struct.fields[i].name = name
+            schema = struct
             converter = _create_converter(schema)
             rdd = rdd.map(converter)
 
-        if isinstance(schema, (list, tuple)):
-            first = rdd.first()
-            if not isinstance(first, (list, tuple)):
-                raise TypeError("each row in `rdd` should be list or tuple, "
-                                "but got %r" % type(first))
-            row_cls = Row(*schema)
-            schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio)
-
-        # take the first few rows to verify schema
-        rows = rdd.take(10)
-        # Row() cannot been deserialized by Pyrolite
-        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
-            rdd = rdd.map(tuple)
+        elif isinstance(schema, StructType):
+            # take the first few rows to verify schema
             rows = rdd.take(10)
+            for row in rows:
+                _verify_type(row, schema)
 
-        for row in rows:
-            _verify_type(row, schema)
+        else:
+            raise TypeError("schema should be StructType or list or None")
 
         # convert python objects to sql data
         converter = _python_to_sql_converter(schema)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 23d9adb0da..932686e5e4 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType):
     >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
     ...                       StructField("values", ArrayType(DoubleType(), False), False)])
     >>> _need_python_to_sql_conversion(schema0)
-    False
+    True
     >>> _need_python_to_sql_conversion(ExamplePointUDT())
     True
     >>> schema1 = ArrayType(ExamplePointUDT(), False)
@@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType):
     True
     """
     if isinstance(dataType, StructType):
-        return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+        # convert namedtuple or Row into tuple
+        return True
     elif isinstance(dataType, ArrayType):
         return _need_python_to_sql_conversion(dataType.elementType)
     elif isinstance(dataType, MapType):
@@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType):
 
     if isinstance(dataType, StructType):
         names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
-        converters = [_python_to_sql_converter(t) for t in types]
-
-        def converter(obj):
-            if isinstance(obj, dict):
-                return tuple(c(obj.get(n)) for n, c in zip(names, converters))
-            elif isinstance(obj, tuple):
-                if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
-                    return tuple(c(v) for c, v in zip(converters, obj))
-                elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
-                    d = dict(obj)
-                    return tuple(c(d.get(n)) for n, c in zip(names, converters))
+        if any(_need_python_to_sql_conversion(t) for t in types):
+            converters = [_python_to_sql_converter(t) for t in types]
+
+            def converter(obj):
+                if isinstance(obj, dict):
+                    return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+                elif isinstance(obj, tuple):
+                    if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
+                        return tuple(c(v) for c, v in zip(converters, obj))
+                    else:
+                        return tuple(c(v) for c, v in zip(converters, obj))
+                elif obj is not None:
+                    raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+        else:
+            def converter(obj):
+                if isinstance(obj, dict):
+                    return tuple(obj.get(n) for n in names)
                 else:
-                    return tuple(c(v) for c, v in zip(converters, obj))
-            elif obj is not None:
-                raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+                    return tuple(obj)
         return converter
     elif isinstance(dataType, ArrayType):
         element_converter = _python_to_sql_converter(dataType.elementType)
@@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType):
     _type = type(dataType)
     assert _type in _acceptable_types, "unknown datatype: %s" % dataType
 
-    # subclass of them can not be deserialized in JVM
-    if type(obj) not in _acceptable_types[_type]:
-        raise TypeError("%s can not accept object in type %s"
-                        % (dataType, type(obj)))
+    if _type is StructType:
+        if not isinstance(obj, (tuple, list)):
+            raise TypeError("StructType can not accept object in type %s" % type(obj))
+    else:
+        # subclass of them can not be deserialized in JVM
+        if type(obj) not in _acceptable_types[_type]:
+            raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
 
     if isinstance(dataType, ArrayType):
         for i in obj:
-- 
GitLab
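
For context, a minimal usage sketch of the entry point this patch changes. It is not part of the patch: the local master, app name, and sample rows are arbitrary assumptions, and it presumes a 1.5-era PySpark (SQLContext API) built with this change applied.

# Minimal sketch, assuming a PySpark build that includes this patch.
# Master, app name, and data below are illustrative only.
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local[2]", "create-dataframe-sketch")
sqlContext = SQLContext(sc)

# createDataFrame from a plain Python list: with this patch the schema is
# inferred driver-side by _inferSchemaFromList, so no Spark job is launched
# just to sample the data before the DataFrame is built.
df = sqlContext.createDataFrame([Row(name="Alice", age=1), Row(name="Bob", age=2)])
df.printSchema()

# A list of tuples plus column names: the inferred StructType's fields are
# renamed in place (struct.fields[i].name = name) rather than re-inferred
# from a mapped RDD as the old code did.
df2 = sqlContext.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
df2.show()

sc.stop()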