From 0ee53ebce9944722e76b2b28fae79d9956be9f17 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Mon, 9 Feb 2015 16:39:34 -0800
Subject: [PATCH] [SPARK-2096][SQL] support dot notation on array of struct

~~The rule is simple: if you want `a.b` to work, then `a` must be some level of nested array of struct (level 0 means just a StructType), and the result of `a.b` is the same level of nested array of b's type.
One possible optimization: the resolve chain looks like `Attribute -> GetItem -> GetField -> GetField ...`, so we could pass the nested-array information between `GetItem` and `GetField` to avoid repeatedly computing the `innerDataType` and `containsNullList` of that nested array.~~
marmbrus Could you take a look?

To evaluate `a.b` where `a` is an array of structs, `a.b` means getting field `b` from each element of `a` and returning the results as an array.
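
A minimal sketch of the same semantics in plain Scala (illustrative only; `Element` is a hypothetical stand-in for a struct row, not part of this patch):

```scala
// `a.b` over an array of structs projects field `b` from every element,
// preserving nulls for both the whole array and individual elements.
case class Element(b: String)

val a: Seq[Element] = Seq(Element("x"), null, Element("y"))
val result: Seq[String] = a.map(e => if (e == null) null else e.b)
// result == Seq("x", null, "y")
```

This mirrors what the updated `JsonSuite` test asserts for `select arrayOfStruct.field1 from jsonTable`.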

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #2405 from cloud-fan/nested-array-dot and squashes the following commits:

08a228a [Wenchen Fan] support dot notation on array of struct
---
 .../sql/catalyst/analysis/Analyzer.scala      | 30 +++++++++-------
 .../catalyst/expressions/complexTypes.scala   | 34 ++++++++++++++++---
 .../sql/catalyst/optimizer/Optimizer.scala    |  3 +-
 .../ExpressionEvaluationSuite.scala           |  2 +-
 .../org/apache/spark/sql/json/JsonSuite.scala |  6 ++--
 5 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b59ed1739..fb2ff014ce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{ArrayType, IntegerType, StructField, StructType}
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog,
      * desired fields are found.
      */
     protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
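+      // Find the ordinal of the unique field whose name matches `fieldName` under the
+      // analyzer's resolver, failing if the name is missing or matches several fields.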
+      def findField(fields: Array[StructField]): Int = {
+        val checkField = (f: StructField) => resolver(f.name, fieldName)
+        val ordinal = fields.indexWhere(checkField)
+        if (ordinal == -1) {
+          sys.error(
+            s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+        } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+          sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+        } else {
+          ordinal
+        }
+      }
       expr.dataType match {
         case StructType(fields) =>
-          val actualField = fields.filter(f => resolver(f.name, fieldName))
-          if (actualField.length == 0) {
-            sys.error(
-              s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-          } else if (actualField.length == 1) {
-            val field = actualField(0)
-            GetField(expr, field, fields.indexOf(field))
-          } else {
-            sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}")
-          }
+          val ordinal = findField(fields)
+          StructGetField(expr, fields(ordinal), ordinal)
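+        // Dot notation on an array of structs: extract the field from every element.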
+        case ArrayType(StructType(fields), containsNull) =>
+          val ordinal = findField(fields)
+          ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
         case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
index 66e2e5c4ba..68051a2a20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala
@@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
   }
 }
 
+
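+/**
+ * Base trait for expressions that extract a field from a complex-typed child expression.
+ */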
+trait GetField extends UnaryExpression {
+  self: Product =>
+
+  type EvaluatedType = Any
+  override def foldable = child.foldable
+  override def toString = s"$child.${field.name}"
+
+  def field: StructField
+}
+
 /**
  * Returns the value of the field at `ordinal` in the Struct `child`.
  */
-case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression {
-  type EvaluatedType = Any
+case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField {
 
   def dataType = field.dataType
   override def nullable = child.nullable || field.nullable
-  override def foldable = child.foldable
 
   override def eval(input: Row): Any = {
     val baseValue = child.eval(input).asInstanceOf[Row]
     if (baseValue == null) null else baseValue(ordinal)
   }
+}
 
-  override def toString = s"$child.${field.name}"
+/**
+ * Returns an array containing the value of the requested field from each struct in the array `child`.
+ */
+case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean)
+  extends GetField {
+
+  def dataType = ArrayType(field.dataType, containsNull)
+  override def nullable = child.nullable
+
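+  // Evaluate the child to a Seq[Row]; a null array yields null, and null elements
+  // yield null field values.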
+  override def eval(input: Row): Any = {
+    val baseValue = child.eval(input).asInstanceOf[Seq[Row]]
+    if (baseValue == null) null else {
+      baseValue.map { row =>
+        if (row == null) null else row(ordinal)
+      }
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fd58b9681e..0da081ed1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -209,7 +209,8 @@ object NullPropagation extends Rule[LogicalPlan] {
       case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType)
       case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType)
       case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType)
-      case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType)
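+      // Extracting a field from a null struct or a null array folds to a null literal.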
+      case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType)
+      case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType)
       case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
       case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case e @ Count(expr) if !expr.nullable => Count(Literal(1))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 7cf6c80194..dcfd8b28cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite {
       expr.dataType match {
         case StructType(fields) =>
           val field = fields.find(_.name == fieldName).get
-          GetField(expr, field, fields.indexOf(field))
+          StructGetField(expr, field, fields.indexOf(field))
       }
     }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 926ba68828..7870cf9b0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -342,21 +342,19 @@ class JsonSuite extends QueryTest {
     )
   }
 
-  ignore("Complex field and type inferring (Ignored)") {
+  test("GetField operation on complex data type") {
     val jsonDF = jsonRDD(complexFieldAndType1)
     jsonDF.registerTempTable("jsonTable")
 
-    // Right now, "field1" and "field2" are treated as aliases. We should fix it.
     checkAnswer(
       sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"),
       Row(true, "str1")
     )
 
-    // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2.
     // Getting all values of a specific field from an array of structs.
     checkAnswer(
       sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"),
-      Row(Seq(true, false), Seq("str1", null))
+      Row(Seq(true, false, null), Seq("str1", null, null))
     )
   }
 
-- 
GitLab