From 4fdb491775bb9c4afa40477dc0069ff6fcadfe25 Mon Sep 17 00:00:00 2001
From: Kan Zhang <kzhang@apache.org>
Date: Mon, 16 Jun 2014 11:11:29 -0700
Subject: [PATCH] [SPARK-2010] Support for nested data in PySpark SQL

JIRA issue https://issues.apache.org/jira/browse/SPARK-2010

This PR adds support for nested collection types in PySpark SQL, including
array, dict, list, set, and tuple. For example:

```
>>> from array import array
>>> from pyspark.sql import SQLContext
>>> sqlCtx = SQLContext(sc)
>>> rdd = sc.parallelize([
...         {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...         {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
True
>>> rdd = sc.parallelize([
...         {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...         {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == \
... [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
...  {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
True
```
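
As a usage sketch (assuming the `SchemaRDD.registerAsTable` and `SQLContext.sql`
APIs already present in this release; the table name is illustrative and the
output is omitted), the inferred nested columns can then be queried like any
other column:

```
>>> srdd.registerAsTable("nested")
>>> sqlCtx.sql("SELECT f1 FROM nested").collect()
```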

Author: Kan Zhang <kzhang@apache.org>

Closes #1041 from kanzhang/SPARK-2010 and squashes the following commits:

1b2891d [Kan Zhang] [SPARK-2010] minor doc change and adding a TODO
504f27e [Kan Zhang] [SPARK-2010] Support for nested data in PySpark SQL
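
For reference, the schema of nested values is inferred recursively on the Scala
side by the new `typeFor` helper in the diff below. The following is a minimal
Python sketch of that mapping, for illustration only; `sketch_type_for` and the
type-name strings are not part of this patch:

```
from array import array

def sketch_type_for(obj):
    """Map a Python value to a Catalyst-style type name, peeking at the first
    element of nested collections the same way the Scala typeFor helper does."""
    if isinstance(obj, bool):                       # check bool before int
        return "BooleanType"
    if isinstance(obj, str):
        return "StringType"
    if isinstance(obj, int):
        return "IntegerType"
    if isinstance(obj, float):
        return "DoubleType"
    if isinstance(obj, dict):                       # dict maps to MapType
        key, value = next(iter(obj.items()))
        return "MapType(%s, %s)" % (sketch_type_for(key), sketch_type_for(value))
    if isinstance(obj, (list, set, tuple, array)):  # sequences map to ArrayType
        return "ArrayType(%s)" % sketch_type_for(next(iter(obj)))
    raise TypeError("Object of type %s cannot be used" % type(obj))

print(sketch_type_for({"row1": 1.0}))        # MapType(StringType, DoubleType)
print(sketch_type_for(array('i', [1, 2])))   # ArrayType(IntegerType)
print(sketch_type_for([[1, 2], [2, 3]]))     # ArrayType(ArrayType(IntegerType))
```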
---
 python/pyspark/sql.py                         | 22 +++++++++++++-
 .../org/apache/spark/sql/SQLContext.scala     | 29 ++++++++++++-------
 2 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e344610b1f..c31d49ce83 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -77,12 +77,25 @@ class SQLContext:
         """Infer and apply a schema to an RDD of L{dict}s.
 
         We peek at the first row of the RDD to determine the field names
-        and types, and then use that to extract all the dictionaries.
+        and types, and then use that to extract all the dictionaries. Nested
+        collections are also supported, including array, dict, list, set,
+        and tuple.
 
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ...                    {"field1" : 3, "field2": "row3"}]
         True
+
+        >>> from array import array
+        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
+        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
+        True
+
+        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
+        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
+        True
         """
         if (rdd.__class__ is SchemaRDD):
             raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
@@ -413,6 +426,7 @@ class SchemaRDD(RDD):
 
 def _test():
     import doctest
+    from array import array
     from pyspark.context import SparkContext
     globs = globals().copy()
     # The small batch size here ensures that we see multiple batches,
@@ -422,6 +436,12 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
         {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
+    globs['nestedRdd1'] = sc.parallelize([
+        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
+        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
+    globs['nestedRdd2'] = sc.parallelize([
+        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
+        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
     (failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
     globs['sc'].stop()
     if failure_count:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 378ff54531..131c130bbb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   /**
    * Peek at the first row of the RDD and infer its schema.
-   * TODO: We only support primitive types, add support for nested types.
+   * TODO: consolidate this with the type system developed in SPARK-2060.
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    def typeFor(obj: Any): DataType = obj match {
+      case c: java.lang.String => StringType
+      case c: java.lang.Integer => IntegerType
+      case c: java.lang.Long => LongType
+      case c: java.lang.Double => DoubleType
+      case c: java.lang.Boolean => BooleanType
+      case c: java.util.List[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Set[_] => ArrayType(typeFor(c.head))
+      case c: java.util.Map[_, _] =>
+        val (key, value) = c.head
+        MapType(typeFor(key), typeFor(value))
+      case c if c.getClass.isArray =>
+        val elem = c.asInstanceOf[Array[_]].head
+        ArrayType(typeFor(elem))
+      case c => throw new Exception(s"Object of type ${c.getClass} cannot be used")
+    }
     val schema = rdd.first().map { case (fieldName, obj) =>
-      val dataType = obj.getClass match {
-        case c: Class[_] if c == classOf[java.lang.String] => StringType
-        case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-        case c: Class[_] if c == classOf[java.lang.Long] => LongType
-        case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
-        case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-        case c => throw new Exception(s"Object of type $c cannot be used")
-      }
-      AttributeReference(fieldName, dataType, true)()
+      AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
 
     val rowRdd = rdd.mapPartitions { iter =>
-- 
GitLab