From 0ee53ebce9944722e76b2b28fae79d9956be9f17 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <cloud0fan@outlook.com> Date: Mon, 9 Feb 2015 16:39:34 -0800 Subject: [PATCH] [SPARK-2096][SQL] support dot notation on array of struct ~~The rule is simple: If you want `a.b` work, then `a` must be some level of nested array of struct(level 0 means just a StructType). And the result of `a.b` is same level of nested array of b-type. An optimization is: the resolve chain looks like `Attribute -> GetItem -> GetField -> GetField ...`, so we could transmit the nested array information between `GetItem` and `GetField` to avoid repeated computation of `innerDataType` and `containsNullList` of that nested array.~~ marmbrus Could you take a look? to evaluate `a.b`, if `a` is array of struct, then `a.b` means get field `b` on each element of `a`, and return a result of array. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #2405 from cloud-fan/nested-array-dot and squashes the following commits: 08a228a [Wenchen Fan] support dot notation on array of struct --- .../sql/catalyst/analysis/Analyzer.scala | 30 +++++++++------- .../catalyst/expressions/complexTypes.scala | 34 ++++++++++++++++--- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../ExpressionEvaluationSuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 6 ++-- 5 files changed, 53 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0b59ed1739..fb2ff014ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType} /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing @@ -311,18 +310,25 @@ class Analyzer(catalog: Catalog, * desired fields are found. */ protected def resolveGetField(expr: Expression, fieldName: String): Expression = { + def findField(fields: Array[StructField]): Int = { + val checkField = (f: StructField) => resolver(f.name, fieldName) + val ordinal = fields.indexWhere(checkField) + if (ordinal == -1) { + sys.error( + s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") + } else if (fields.indexWhere(checkField, ordinal + 1) != -1) { + sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}") + } else { + ordinal + } + } expr.dataType match { case StructType(fields) => - val actualField = fields.filter(f => resolver(f.name, fieldName)) - if (actualField.length == 0) { - sys.error( - s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}") - } else if (actualField.length == 1) { - val field = actualField(0) - GetField(expr, field, fields.indexOf(field)) - } else { - sys.error(s"Ambiguous reference to fields ${actualField.mkString(", ")}") - } + val ordinal = findField(fields) + StructGetField(expr, fields(ordinal), ordinal) + case ArrayType(StructType(fields), containsNull) => + val ordinal = findField(fields) + ArrayGetField(expr, fields(ordinal), ordinal, containsNull) case otherType => sys.error(s"GetField is not valid on fields of type $otherType") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 66e2e5c4ba..68051a2a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -70,22 +70,48 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { } } + +trait GetField extends UnaryExpression { + self: Product => + + type EvaluatedType = Any + override def foldable = child.foldable + override def toString = s"$child.${field.name}" + + def field: StructField +} + /** * Returns the value of fields in the Struct `child`. */ -case class GetField(child: Expression, field: StructField, ordinal: Int) extends UnaryExpression { - type EvaluatedType = Any +case class StructGetField(child: Expression, field: StructField, ordinal: Int) extends GetField { def dataType = field.dataType override def nullable = child.nullable || field.nullable - override def foldable = child.foldable override def eval(input: Row): Any = { val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } +} - override def toString = s"$child.${field.name}" +/** + * Returns the array of value of fields in the Array of Struct `child`. + */ +case class ArrayGetField(child: Expression, field: StructField, ordinal: Int, containsNull: Boolean) + extends GetField { + + def dataType = ArrayType(field.dataType, containsNull) + override def nullable = child.nullable + + override def eval(input: Row): Any = { + val baseValue = child.eval(input).asInstanceOf[Seq[Row]] + if (baseValue == null) null else { + baseValue.map { row => + if (row == null) null else row(ordinal) + } + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fd58b9681e..0da081ed1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -209,7 +209,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ IsNotNull(c) if !c.nullable => Literal(true, BooleanType) case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) - case e @ GetField(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ StructGetField(Literal(null, _), _, _) => Literal(null, e.dataType) + case e @ ArrayGetField(Literal(null, _), _, _, _) => Literal(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 7cf6c80194..dcfd8b28cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -851,7 +851,7 @@ class ExpressionEvaluationSuite extends FunSuite { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get - GetField(expr, field, fields.indexOf(field)) + StructGetField(expr, field, fields.indexOf(field)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 926ba68828..7870cf9b0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -342,21 +342,19 @@ class JsonSuite extends QueryTest { ) } - ignore("Complex field and type inferring (Ignored)") { + test("GetField operation on complex data type") { val jsonDF = jsonRDD(complexFieldAndType1) jsonDF.registerTempTable("jsonTable") - // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), Row(true, "str1") ) - // Right now, the analyzer cannot resolve arrayOfStruct.field1 and arrayOfStruct.field2. // Getting all values of a specific field from an array of structs. checkAnswer( sql("select arrayOfStruct.field1, arrayOfStruct.field2 from jsonTable"), - Row(Seq(true, false), Seq("str1", null)) + Row(Seq(true, false, null), Seq("str1", null, null)) ) } -- GitLab