Commit 4fdb4917 authored by Kan Zhang's avatar Kan Zhang Committed by Reynold Xin

[SPARK-2010] Support for nested data in PySpark SQL

JIRA issue https://issues.apache.org/jira/browse/SPARK-2010

This PR adds support for nested collection types in PySpark SQL, including
array, dict, list, set, and tuple. For example:

```
>>> from array import array
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([
...         {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...         {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
True
>>> rdd = sc.parallelize([
...         {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...         {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == \
... [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...  {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
True
```

Author: Kan Zhang <kzhang@apache.org>

Closes #1041 from kanzhang/SPARK-2010 and squashes the following commits:

1b2891d [Kan Zhang] [SPARK-2010] minor doc change and adding a TODO
504f27e [Kan Zhang] [SPARK-2010] Support for nested data in PySpark SQL
parent 716c88aa
```
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -77,12 +77,25 @@ class SQLContext:
         """Infer and apply a schema to an RDD of L{dict}s.
 
         We peek at the first row of the RDD to determine the fields names
-        and types, and then use that to extract all the dictionaries.
+        and types, and then use that to extract all the dictionaries. Nested
+        collections are supported, which include array, dict, list, set, and
+        tuple.
 
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ...                    {"field1" : 3, "field2": "row3"}]
         True
+
+        >>> from array import array
+        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
+        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        True
+
+        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        True
         """
         if (rdd.__class__ is SchemaRDD):
             raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
@@ -413,6 +426,7 @@ class SchemaRDD(RDD):
 def _test():
     import doctest
+    from array import array
     from pyspark.context import SparkContext
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
@@ -422,6 +436,12 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    globs['nestedRdd1'] = sc.parallelize([
+        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+    globs['nestedRdd2'] = sc.parallelize([
+        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
```
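A practical note that follows from the "peek at the first row" behavior documented above: element types for nested collections are inferred from the first element of the first row, so a nested collection that is empty in that row gives the inference nothing to inspect. A minimal sketch of the pitfall (the RDD names here are hypothetical, and the exact exception surfaced depends on the Py4J plumbing):

```
>>> good = sc.parallelize([{"f1" : [1, 2]}, {"f1" : []}])
>>> srdd = sqlCtx.inferSchema(good)  # fine: the first row's list is non-empty
>>> bad = sc.parallelize([{"f1" : []}, {"f1" : [1, 2]}])
>>> srdd = sqlCtx.inferSchema(bad)   # expected to fail: typeFor reads c.head of an empty list
```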
```
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
   /**
    * Peek at the first row of the RDD and infer its schema.
-   * TODO: We only support primitive types, add support for nested types.
+   * TODO: consolidate this with the type system developed in SPARK-2060.
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    def typeFor(obj: Any): DataType = obj match {
+      case c: java.lang.String => StringType
+      case c: java.lang.Integer => IntegerType
+      case c: java.lang.Long => LongType
+      case c: java.lang.Double => DoubleType
+      case c: java.lang.Boolean => BooleanType
+      case c: java.util.List[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Map[_, _] =>
+        val (key, value) = c.head
+        MapType(typeFor(key), typeFor(value))
+      case c if c.getClass.isArray =>
+        val elem = c.asInstanceOf[Array[_]].head
+        ArrayType(typeFor(elem))
+      case c => throw new Exception(s"Object of type $c cannot be used")
+    }
     val schema = rdd.first().map { case (fieldName, obj) =>
-      val dataType = obj.getClass match {
-        case c: Class[_] if c == classOf[java.lang.String] => StringType
-        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-        case c: Class[_] if c == classOf[java.lang.Long] => LongType
-        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
-        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-        case c => throw new Exception(s"Object of type $c cannot be used")
-      }
-      AttributeReference(fieldName, dataType, true)()
+      AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
     val rowRdd = rdd.mapPartitions { iter =>
```
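For readers following the PySpark side, a rough Python transcription of the `typeFor` match above may help. This is an illustrative sketch only: the name `type_for` is made up, it returns type names as strings rather than Catalyst `DataType` objects, and the real `typeFor` operates on the Java objects produced when Python rows are unpickled on the JVM.

```
from array import array

def type_for(obj):
    """Illustrative Python analogue of the Scala typeFor above."""
    if isinstance(obj, bool):  # test bool before int: bool is an int subclass
        return "BooleanType"
    if isinstance(obj, str):
        return "StringType"
    if isinstance(obj, int):
        return "IntegerType"
    if isinstance(obj, float):
        return "DoubleType"
    if isinstance(obj, (list, set, tuple, array)):
        # Mirrors c.head in the Scala code: only the first element is
        # inspected, and an empty collection raises StopIteration.
        return "ArrayType(%s)" % type_for(next(iter(obj)))
    if isinstance(obj, dict):
        key, value = next(iter(obj.items()))
        return "MapType(%s, %s)" % (type_for(key), type_for(value))
    raise TypeError("Object of type %s cannot be used" % type(obj).__name__)
```

For instance, `type_for({"row1" : 1.0})` yields `MapType(StringType, DoubleType)`, matching the `f2` field in the doctests above. Note that, like the Scala version, only the first element of a collection is consulted, so a heterogeneous collection is typed by its first element.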