From 19530da6903fa59b051eec69b9c17e231c68454b Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Tue, 24 Nov 2015 11:09:01 -0800
Subject: [PATCH] [SPARK-11926][SQL] unify GetStructField and
 GetInternalRowField

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9909 from cloud-fan/get-struct.
---
 .../spark/sql/catalyst/ScalaReflection.scala  |  2 +-
 .../sql/catalyst/analysis/unresolved.scala    |  8 +++----
 .../catalyst/encoders/ExpressionEncoder.scala |  2 +-
 .../sql/catalyst/encoders/RowEncoder.scala    |  2 +-
 .../sql/catalyst/expressions/Expression.scala |  2 +-
 .../expressions/complexTypeExtractors.scala   | 18 ++++++++--------
 .../expressions/namedExpressions.scala        |  4 ++--
 .../sql/catalyst/expressions/objects.scala    | 21 -------------------
 .../expressions/ComplexTypeSuite.scala        |  4 ++--
 9 files changed, 21 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 476becec4d..d133ad3f0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection {
 
     /** Returns the current path with a field at ordinal extracted. */
     def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
-      .map(p => GetInternalRowField(p, ordinal, dataType))
+      .map(p => GetStructField(p, ordinal))
       .getOrElse(BoundReference(ordinal, dataType, false))
 
     /** Returns the current path or `BoundReference`. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 6485bdfb30..1b2a8dc4c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu
     if (attribute.isDefined) {
       // This target resolved to an attribute in child. It must be a struct. Expand it.
       attribute.get.dataType match {
-        case s: StructType => {
-          s.fields.map( f => {
-            val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get)
+        case s: StructType => s.zipWithIndex.map {
+          case (f, i) =>
+            val extract = GetStructField(attribute.get, i)
             Alias(extract, target.get + "." + f.name)()
-          })
         }
+
         case _ => {
           throw new AnalysisException("Can only star expand struct data types. Attribute: `" +
             target.get + "`")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 7bc9aed0b2..0c10a56c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -111,7 +111,7 @@ object ExpressionEncoder {
           case UnresolvedAttribute(nameParts) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
-          case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
+          case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
         }
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index fa553e7c53..67518f52d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -220,7 +220,7 @@ object RowEncoder {
         If(
           Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, externalDataTypeFor(f.dataType)),
-          constructorFor(GetInternalRowField(input, i, f.dataType)))
+          constructorFor(GetStructField(input, i)))
       }
       CreateExternalRow(convertedFields)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 540ed35006..169435a10e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] {
    */
   def prettyString: String = {
     transform {
-      case a: AttributeReference => PrettyAttribute(a.name)
+      case a: AttributeReference => PrettyAttribute(a.name, a.dataType)
       case u: UnresolvedAttribute => PrettyAttribute(u.name)
     }.toString
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index f871b737ff..10ce10aaf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -51,7 +51,7 @@ object ExtractValue {
       case (StructType(fields), NonNullLiteral(v, StringType)) =>
         val fieldName = v.toString
         val ordinal = findField(fields, fieldName, resolver)
-        GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+        GetStructField(child, ordinal, Some(fieldName))
 
       case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
         val fieldName = v.toString
@@ -97,18 +97,18 @@ object ExtractValue {
  * Returns the value of fields in the Struct `child`.
  *
  * No need to do type checking since it is handled by [[ExtractValue]].
- * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
+ *
+ * Note that we can pass in the field name directly to keep case preserving in `toString`.
+ * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
  */
-case class GetStructField(child: Expression, field: StructField, ordinal: Int)
+case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
   extends UnaryExpression {
 
-  override def dataType: DataType = child.dataType match {
-    case s: StructType => s(ordinal).dataType
-    // This is a hack to avoid breaking existing code until we remove the need for the struct field
-    case _ => field.dataType
-  }
+  private lazy val field = child.dataType.asInstanceOf[StructType](ordinal)
+
+  override def dataType: DataType = field.dataType
   override def nullable: Boolean = child.nullable || field.nullable
-  override def toString: String = s"$child.${field.name}"
+  override def toString: String = s"$child.${name.getOrElse(field.name)}"
 
   protected override def nullSafeEval(input: Any): Any =
     input.asInstanceOf[InternalRow].get(ordinal, field.dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 00b7970bd1..26b6aca799 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -273,7 +273,8 @@ case class AttributeReference(
  * A place holder used when printing expressions without debugging information such as the
  * expression id or the unresolved indicator.
  */
-case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
+case class PrettyAttribute(name: String, dataType: DataType = NullType)
+  extends Attribute with Unevaluable {
 
   override def toString: String = name
 
@@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable {
   override def qualifiers: Seq[String] = throw new UnsupportedOperationException
   override def exprId: ExprId = throw new UnsupportedOperationException
   override def nullable: Boolean = throw new UnsupportedOperationException
-  override def dataType: DataType = NullType
 }
 
 object VirtualColumn {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 4a1f419f0a..62d09f0f55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression {
   }
 }
 
-case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType)
-  extends UnaryExpression {
-
-  override def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any =
-    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
-
-  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    nullSafeCodeGen(ctx, ev, eval => {
-      s"""
-        if ($eval.isNullAt($ordinal)) {
-          ${ev.isNull} = true;
-        } else {
-          ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)};
-        }
-      """
-    })
-  }
-}
-
 /**
  * Serializes an input object using a generic serializer (Kryo or Java).
  * @param kryo if true, use Kryo. Otherwise, use Java.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index e60990aeb4..62fd47234b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     def getStructField(expr: Expression, fieldName: String): GetStructField = {
       expr.dataType match {
         case StructType(fields) =>
-          val field = fields.find(_.name == fieldName).get
-          GetStructField(expr, field, fields.indexOf(field))
+          val index = fields.indexWhere(_.name == fieldName)
+          GetStructField(expr, index)
       }
     }
 
-- 
GitLab