diff --git a/docs/_config.yml b/docs/_config.yml
index e2db274e1f619f135c66e1d588b5e759a58a1585..0652927a8ce9bef00e72803ed18b2a5c43cf6ced 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -10,6 +10,7 @@ kramdown:
 
 include:
   - _static
+  - _modules
 
 # These allow the documentation to be updated with newer releases
 # of Spark, Scala, and Mesos.
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index e03379e521a071856f64a1714509041ed5d4500a..2e3f69b9a562a3c51c30fb087a1fc101a1603237 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -7,7 +7,6 @@ Module Context
 .. automodule:: pyspark.sql
     :members:
     :undoc-members:
-    :show-inheritance:
 
 
 pyspark.sql.types module
@@ -15,7 +14,6 @@ pyspark.sql.types module
 .. automodule:: pyspark.sql.types
     :members:
     :undoc-members:
-    :show-inheritance:
 
 
 pyspark.sql.functions module
@@ -23,4 +21,3 @@ pyspark.sql.functions module
 .. automodule:: pyspark.sql.functions
     :members:
     :undoc-members:
-    :show-inheritance:
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 125933c9d3ae0004bbf7ad67ba5f539157faed0b..5d7aeb664cadfd1a3aec477895914d441fd37e1d 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -129,6 +129,7 @@ class SQLContext(object):
         >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
         >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
         [Row(c0=u'4')]
 
+        >>> from pyspark.sql.types import IntegerType
         >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
         >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
@@ -197,31 +198,6 @@ class SQLContext(object):
         >>> df = sqlCtx.inferSchema(rdd)
         >>> df.collect()[0]
         Row(field1=1, field2=u'row1')
-
-        >>> NestedRow = Row("f1", "f2")
-        >>> nestedRdd1 = sc.parallelize([
-        ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}),
-        ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})])
-        >>> df = sqlCtx.inferSchema(nestedRdd1)
-        >>> df.collect()
-        [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
-
-        >>> nestedRdd2 = sc.parallelize([
-        ...     NestedRow([[1, 2], [2, 3]], [1, 2]),
-        ...     NestedRow([[2, 3], [3, 4]], [2, 3])])
-        >>> df = sqlCtx.inferSchema(nestedRdd2)
-        >>> df.collect()
-        [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
-
-        >>> from collections import namedtuple
-        >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
-        >>> rdd = sc.parallelize(
-        ...     [CustomRow(field1=1, field2="row1"),
-        ...      CustomRow(field1=2, field2="row2"),
-        ...      CustomRow(field1=3, field2="row3")])
-        >>> df = sqlCtx.inferSchema(rdd)
-        >>> df.collect()[0]
-        Row(field1=1, field2=u'row1')
         """
 
         if isinstance(rdd, DataFrame):
@@ -252,56 +228,8 @@ class SQLContext(object):
         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...     StructField("field2", StringType(), False)])
         >>> df = sqlCtx.applySchema(rdd2, schema)
-        >>> sqlCtx.registerDataFrameAsTable(df, "table1")
-        >>> df2 = sqlCtx.sql("SELECT * from table1")
-        >>> df2.collect()
-        [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
-
-        >>> from datetime import date, datetime
-        >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
-        ...     date(2010, 1, 1),
-        ...     datetime(2010, 1, 1, 1, 1, 1),
-        ...     {"a": 1}, (2,), [1, 2, 3], None)])
-        >>> schema = StructType([
-        ...     StructField("byte1", ByteType(), False),
-        ...     StructField("byte2", ByteType(), False),
-        ...     StructField("short1", ShortType(), False),
-        ...     StructField("short2", ShortType(), False),
-        ...     StructField("int1", IntegerType(), False),
-        ...     StructField("float1", FloatType(), False),
-        ...     StructField("date1", DateType(), False),
-        ...     StructField("time1", TimestampType(), False),
-        ...     StructField("map1",
-        ...         MapType(StringType(), IntegerType(), False), False),
-        ...     StructField("struct1",
-        ...         StructType([StructField("b", ShortType(), False)]), False),
-        ...     StructField("list1", ArrayType(ByteType(), False), False),
-        ...     StructField("null1", DoubleType(), True)])
-        >>> df = sqlCtx.applySchema(rdd, schema)
-        >>> results = df.map(
-        ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
-        ...         x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
-        >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
-        (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
-            datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-
-        >>> df.registerTempTable("table2")
-        >>> sqlCtx.sql(
-        ...     "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
-        ...     "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
-        ...     "float1 + 1.5 as float1 FROM table2").collect()
-        [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)]
-
-        >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
-        >>> rdd = sc.parallelize([(127, -32768, 1.0,
-        ...     datetime(2010, 1, 1, 1, 1, 1),
-        ...     {"a": 1}, (2,), [1, 2, 3])])
-        >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
-        >>> schema = _parse_schema_abstract(abstract)
-        >>> typedSchema = _infer_schema_type(rdd.first(), schema)
-        >>> df = sqlCtx.applySchema(rdd, typedSchema)
         >>> df.collect()
-        [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])]
+        [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
         """
 
         if isinstance(rdd, DataFrame):
@@ -459,46 +387,28 @@ class SQLContext(object):
         >>> import tempfile, shutil
         >>> jsonFile = tempfile.mkdtemp()
         >>> shutil.rmtree(jsonFile)
-        >>> ofn = open(jsonFile, 'w')
-        >>> for json in jsonStrings:
-        ...     print>>ofn, json
-        >>> ofn.close()
+        >>> with open(jsonFile, 'w') as f:
+        ...     f.writelines(jsonStrings)
         >>> df1 = sqlCtx.jsonFile(jsonFile)
-        >>> sqlCtx.registerDataFrameAsTable(df1, "table1")
-        >>> df2 = sqlCtx.sql(
-        ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, "
-        ...     "field6 as f4 from table1")
-        >>> for r in df2.collect():
-        ...     print r
-        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
-        Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
-        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
-        >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema)
-        >>> sqlCtx.registerDataFrameAsTable(df3, "table2")
-        >>> df4 = sqlCtx.sql(
-        ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, "
-        ...     "field6 as f4 from table2")
-        >>> for r in df4.collect():
-        ...     print r
-        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
-        Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
-        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+        >>> df1.printSchema()
+        root
+         |-- field1: long (nullable = true)
+         |-- field2: string (nullable = true)
+         |-- field3: struct (nullable = true)
+         |    |-- field4: long (nullable = true)
 
         >>> from pyspark.sql.types import *
         >>> schema = StructType([
-        ...     StructField("field2", StringType(), True),
+        ...     StructField("field2", StringType()),
         ...     StructField("field3",
-        ...         StructType([
-        ...             StructField("field5",
-        ...                 ArrayType(IntegerType(), False), True)]), False)])
-        >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
-        >>> sqlCtx.registerDataFrameAsTable(df5, "table3")
-        >>> df6 = sqlCtx.sql(
-        ...     "SELECT field2 AS f1, field3.field5 as f2, "
-        ...     "field3.field5[0] as f3 from table3")
-        >>> df6.collect()
-        [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
+        ...         StructType([StructField("field5", ArrayType(IntegerType()))]))])
+        >>> df2 = sqlCtx.jsonFile(jsonFile, schema)
+        >>> df2.printSchema()
+        root
+         |-- field2: string (nullable = true)
+         |-- field3: struct (nullable = true)
+         |    |-- field5: array (nullable = true)
+         |    |    |-- element: integer (containsNull = true)
         """
         if schema is None:
             df = self._ssql_ctx.jsonFile(path, samplingRatio)
@@ -517,48 +427,23 @@ class SQLContext(object):
         determine the schema.
 
         >>> df1 = sqlCtx.jsonRDD(json)
-        >>> sqlCtx.registerDataFrameAsTable(df1, "table1")
-        >>> df2 = sqlCtx.sql(
-        ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, "
-        ...     "field6 as f4 from table1")
-        >>> for r in df2.collect():
-        ...     print r
-        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
-        Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
-        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
-        >>> df3 = sqlCtx.jsonRDD(json, df1.schema)
-        >>> sqlCtx.registerDataFrameAsTable(df3, "table2")
-        >>> df4 = sqlCtx.sql(
-        ...     "SELECT field1 AS f1, field2 as f2, field3 as f3, "
-        ...     "field6 as f4 from table2")
-        >>> for r in df4.collect():
-        ...     print r
-        Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
-        Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
-        Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+        >>> df1.first()
+        Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
+
+        >>> df2 = sqlCtx.jsonRDD(json, df1.schema)
+        >>> df2.first()
+        Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
 
         >>> from pyspark.sql.types import *
         >>> schema = StructType([
-        ...     StructField("field2", StringType(), True),
+        ...     StructField("field2", StringType()),
         ...     StructField("field3",
-        ...         StructType([
-        ...             StructField("field5",
-        ...                 ArrayType(IntegerType(), False), True)]), False)])
-        >>> df5 = sqlCtx.jsonRDD(json, schema)
-        >>> sqlCtx.registerDataFrameAsTable(df5, "table3")
-        >>> df6 = sqlCtx.sql(
-        ...     "SELECT field2 AS f1, field3.field5 as f2, "
-        ...     "field3.field5[0] as f3 from table3")
-        >>> df6.collect()
-        [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
-
-        >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
-        ...     '{"key0": {"key1": "value1"}}'])).collect()
-        [Row(key0=None), Row(key0=Row(key1=u'value1'))]
-        >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
-        ...     '{"key0": {"key1": "value1"}}'])).collect()
-        [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+        ...         StructType([StructField("field5", ArrayType(IntegerType()))]))
+        ...     ])
+        >>> df3 = sqlCtx.jsonRDD(json, schema)
+        >>> df3.first()
+        Row(field2=u'row1', field3=Row(field5=None))
+
         """
 
         def func(iterator):
@@ -848,7 +733,8 @@ def _test():
     globs['jsonStrings'] = jsonStrings
     globs['json'] = sc.parallelize(jsonStrings)
    (failure_count, test_count) = doctest.testmod(
-        pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS)
+        pyspark.sql.context, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
     globs['sc'].stop()
     if failure_count:
         exit(-1)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 6f746d136b22d734ae99843dd42cf9c2cf9f8bcc..6d42410020b64f39b9229c1f127a5a4b552a16f7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -96,7 +96,7 @@ class DataFrame(object):
         return self._lazy_rdd
 
     def toJSON(self, use_unicode=False):
-        """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+        """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row.
 
         >>> df.toJSON().first()
         '{"age":2,"name":"Alice"}'
@@ -108,7 +108,7 @@ class DataFrame(object):
         """Save the contents as a Parquet file, preserving the schema.
 
         Files that are written out using this method can be read back in as
-        a DataFrame using the L{SQLContext.parquetFile} method.
+        a :class:`DataFrame` using the L{SQLContext.parquetFile} method.
 
         >>> import tempfile, shutil
         >>> parquetFile = tempfile.mkdtemp()
@@ -139,7 +139,7 @@ class DataFrame(object):
         self.registerTempTable(name)
 
     def insertInto(self, tableName, overwrite=False):
-        """Inserts the contents of this DataFrame into the specified table.
+        """Inserts the contents of this :class:`DataFrame` into the specified table.
 
         Optionally overwriting any existing data.
         """
@@ -165,7 +165,7 @@ class DataFrame(object):
         return jmode
 
     def saveAsTable(self, tableName, source=None, mode="append", **options):
-        """Saves the contents of the DataFrame to a data source as a table.
+        """Saves the contents of the :class:`DataFrame` to a data source as a table.
 
         The data source is specified by the `source` and a set of `options`.
         If `source` is not specified, the default data source configured by
@@ -174,12 +174,13 @@ class DataFrame(object):
         Additionally, mode is used to specify the behavior of the saveAsTable operation when
         table already exists in the data source. There are four modes:
 
-        * append: Contents of this DataFrame are expected to be appended to existing table.
-        * overwrite: Data in the existing table is expected to be overwritten by the contents of \
-            this DataFrame.
+        * append: Contents of this :class:`DataFrame` are expected to be appended \
+            to existing table.
+        * overwrite: Data in the existing table is expected to be overwritten by \
+            the contents of this DataFrame.
         * error: An exception is expected to be thrown.
-        * ignore: The save operation is expected to not save the contents of the DataFrame and \
-            to not change the existing table.
+        * ignore: The save operation is expected to not save the contents of the \
+            :class:`DataFrame` and to not change the existing table.
         """
         if source is None:
             source = self.sql_ctx.getConf("spark.sql.sources.default",
@@ -190,7 +191,7 @@ class DataFrame(object):
         self._jdf.saveAsTable(tableName, source, jmode, joptions)
 
     def save(self, path=None, source=None, mode="append", **options):
-        """Saves the contents of the DataFrame to a data source.
+        """Saves the contents of the :class:`DataFrame` to a data source.
 
         The data source is specified by the `source` and a set of `options`.
         If `source` is not specified, the default data source configured by
@@ -199,11 +200,11 @@ class DataFrame(object):
         Additionally, mode is used to specify the behavior of the save operation when data
         already exists in the data source. There are four modes:
 
-        * append: Contents of this DataFrame are expected to be appended to existing data.
+        * append: Contents of this :class:`DataFrame` are expected to be appended to existing data.
         * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
         * error: An exception is expected to be thrown.
-        * ignore: The save operation is expected to not save the contents of the DataFrame and \
-            to not change the existing data.
+        * ignore: The save operation is expected to not save the contents of \
+            the :class:`DataFrame` and to not change the existing data.
         """
         if path is not None:
             options["path"] = path
@@ -217,7 +218,7 @@ class DataFrame(object):
 
     @property
     def schema(self):
-        """Returns the schema of this DataFrame (represented by
+        """Returns the schema of this :class:`DataFrame` (represented by
         a L{StructType}).
 
         >>> df.schema
@@ -275,12 +276,12 @@ class DataFrame(object):
         """
         Print the first 20 rows.
 
+        >>> df
+        DataFrame[age: int, name: string]
         >>> df.show()
         age name
         2   Alice
         5   Bob
-        >>> df
-        DataFrame[age: int, name: string]
         """
         print self._jdf.showString().encode('utf8', 'ignore')
 
@@ -481,8 +482,8 @@ class DataFrame(object):
 
     def join(self, other, joinExprs=None, joinType=None):
         """
-        Join with another DataFrame, using the given join expression.
-        The following performs a full outer join between `df1` and `df2`::
+        Join with another :class:`DataFrame`, using the given join expression.
+        The following performs a full outer join between `df1` and `df2`.
 
         :param other: Right side of the join
         :param joinExprs: Join expression
@@ -582,8 +583,6 @@ class DataFrame(object):
     def select(self, *cols):
         """ Selecting a set of expressions.
 
-        >>> df.select().collect()
-        [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('*').collect()
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
@@ -591,8 +590,6 @@ class DataFrame(object):
         >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
-        if not cols:
-            cols = ["*"]
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
         jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
@@ -612,7 +609,7 @@ class DataFrame(object):
     def filter(self, condition):
         """
         Filtering rows using the given condition, which could be
-        Column expression or string of SQL expression.
+        :class:`Column` expression or string of SQL expression.
 
         where() is an alias for filter().
 
@@ -666,7 +663,7 @@ class DataFrame(object):
         return self.groupBy().agg(*exprs)
 
     def unionAll(self, other):
-        """ Return a new DataFrame containing union of rows in this
+        """ Return a new :class:`DataFrame` containing union of rows in this
         frame and another frame.
 
         This is equivalent to `UNION ALL` in SQL.
@@ -919,9 +916,10 @@ class Column(object):
 
     """
     A column in a DataFrame.
 
-    `Column` instances can be created by::
+    :class:`Column` instances can be created by::
 
         # 1. Select a column out of a DataFrame
+
         df.colName
         df["colName"]
@@ -975,7 +973,7 @@ class Column(object):
 
     def substr(self, startPos, length):
         """
-        Return a Column which is a substring of the column
+        Return a :class:`Column` which is a substring of the column
 
         :param startPos: start position (int or Column)
         :param length: length of the substring (int or Column)
@@ -996,8 +994,10 @@ class Column(object):
     __getslice__ = substr
 
     # order
-    asc = _unary_op("asc")
-    desc = _unary_op("desc")
+    asc = _unary_op("asc", "Returns a sort expression based on the"
+                           " ascending order of the given column name.")
+    desc = _unary_op("desc", "Returns a sort expression based on the"
+                             " descending order of the given column name.")
 
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 8aa44765205c1d7b316be07fa1f2500c45f6ec08..5873f09ae3275a360a0bed54a4667899a0cc9077 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -72,6 +72,7 @@ for _name, _doc in _functions.items():
     globals()[_name] = _create_function(_name, _doc)
 del _name, _doc
 __all__ += _functions.keys()
+__all__.sort()
 
 
 def countDistinct(col, *cols):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 39071e7e35ca18b78afab557e4bfe6ec3fcab807..83899ad4b1b1230e0ab0018e98e6d378fae51870 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -36,9 +36,9 @@ if sys.version_info[:2] <= (2, 6):
 else:
     import unittest
 
-from pyspark.sql import SQLContext, HiveContext, Column
-from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType, LongType, StringType, _infer_type
+from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql.types import *
+from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
 
 
@@ -204,6 +204,68 @@ class SQLTests(ReusedPySparkTestCase):
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
+    def test_infer_nested_schema(self):
+        NestedRow = Row("f1", "f2")
+        nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+                                          NestedRow([2, 3], {"row2": 2.0})])
+        df = self.sqlCtx.inferSchema(nestedRdd1)
+        self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
+        nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+                                          NestedRow([[2, 3], [3, 4]], [2, 3])])
+        df = self.sqlCtx.inferSchema(nestedRdd2)
+        self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
+        from collections import namedtuple
+        CustomRow = namedtuple('CustomRow', 'field1 field2')
+        rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+                                   CustomRow(field1=2, field2="row2"),
+                                   CustomRow(field1=3, field2="row3")])
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+
+    def test_apply_schema(self):
+        from datetime import date, datetime
+        rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+                                    {"a": 1}, (2,), [1, 2, 3], None)])
+        schema = StructType([
+            StructField("byte1", ByteType(), False),
+            StructField("byte2", ByteType(), False),
+            StructField("short1", ShortType(), False),
+            StructField("short2", ShortType(), False),
+            StructField("int1", IntegerType(), False),
StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.sqlCtx.applySchema(rdd, schema) + results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + df.registerTempTable("table2") + r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3])]) + abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" + schema = _parse_schema_abstract(abstract) + typedSchema = _infer_schema_type(rdd.first(), schema) + df = self.sqlCtx.applySchema(rdd, typedSchema) + r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) + self.assertEqual(r, tuple(df.first())) + def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] df = self.sc.parallelize(d).toDF() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b6e41cf0b29fff6b85cc31691c87157d44adea49..0f5dc2be6dab889b65e9c0575b9b6670c4305db7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -28,7 +28,7 @@ from operator import itemgetter __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", - "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ] + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] class DataType(object):