From 19530da6903fa59b051eec69b9c17e231c68454b Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Tue, 24 Nov 2015 11:09:01 -0800 Subject: [PATCH] [SPARK-11926][SQL] unify GetStructField and GetInternalRowField Author: Wenchen Fan <wenchen@databricks.com> Closes #9909 from cloud-fan/get-struct. --- .../spark/sql/catalyst/ScalaReflection.scala | 2 +- .../sql/catalyst/analysis/unresolved.scala | 8 +++---- .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/encoders/RowEncoder.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../expressions/complexTypeExtractors.scala | 18 ++++++++-------- .../expressions/namedExpressions.scala | 4 ++-- .../sql/catalyst/expressions/objects.scala | 21 ------------------- .../expressions/ComplexTypeSuite.scala | 4 ++-- 9 files changed, 21 insertions(+), 42 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 476becec4d..d133ad3f0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -130,7 +130,7 @@ object ScalaReflection extends ScalaReflection { /** Returns the current path with a field at ordinal extracted. */ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path - .map(p => GetInternalRowField(p, ordinal, dataType)) + .map(p => GetStructField(p, ordinal)) .getOrElse(BoundReference(ordinal, dataType, false)) /** Returns the current path or `BoundReference`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 6485bdfb30..1b2a8dc4c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -201,12 +201,12 @@ case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevalu if (attribute.isDefined) { // This target resolved to an attribute in child. It must be a struct. Expand it. attribute.get.dataType match { - case s: StructType => { - s.fields.map( f => { - val extract = GetStructField(attribute.get, f, s.getFieldIndex(f.name).get) + case s: StructType => s.zipWithIndex.map { + case (f, i) => + val extract = GetStructField(attribute.get, i) Alias(extract, target.get + "." + f.name)() - }) } + case _ => { throw new AnalysisException("Can only star expand struct data types. Attribute: `" + target.get + "`") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7bc9aed0b2..0c10a56c55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -111,7 +111,7 @@ object ExpressionEncoder { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt) + case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index fa553e7c53..67518f52d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -220,7 +220,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetInternalRowField(input, i, f.dataType))) + constructorFor(GetStructField(input, i))) } CreateExternalRow(convertedFields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 540ed35006..169435a10e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -206,7 +206,7 @@ abstract class Expression extends TreeNode[Expression] { */ def prettyString: String = { transform { - case a: AttributeReference => PrettyAttribute(a.name) + case a: AttributeReference => PrettyAttribute(a.name, a.dataType) case u: UnresolvedAttribute => PrettyAttribute(u.name) }.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index f871b737ff..10ce10aaf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -51,7 +51,7 @@ object ExtractValue { case (StructType(fields), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + GetStructField(child, ordinal, Some(fieldName)) case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString @@ -97,18 +97,18 @@ object ExtractValue { * Returns the value of fields in the Struct `child`. * * No need to do type checking since it is handled by [[ExtractValue]]. - * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]]. + * + * Note that we can pass in the field name directly to keep case preserving in `toString`. + * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`. */ -case class GetStructField(child: Expression, field: StructField, ordinal: Int) +case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - override def dataType: DataType = child.dataType match { - case s: StructType => s(ordinal).dataType - // This is a hack to avoid breaking existing code until we remove the need for the struct field - case _ => field.dataType - } + private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + + override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${field.name}" + override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, field.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 00b7970bd1..26b6aca799 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -273,7 +273,8 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute with Unevaluable { +case class PrettyAttribute(name: String, dataType: DataType = NullType) + extends Attribute with Unevaluable { override def toString: String = name @@ -286,7 +287,6 @@ case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException - override def dataType: DataType = NullType } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4a1f419f0a..62d09f0f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -517,27 +517,6 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { } } -case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) - extends UnaryExpression { - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported") - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { - ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ - }) - } -} - /** * Serializes an input object using a generic serializer (Kryo or Java). * @param kryo if true, use Kryo. Otherwise, use Java. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e60990aeb4..62fd47234b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -79,8 +79,8 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => - val field = fields.find(_.name == fieldName).get - GetStructField(expr, field, fields.indexOf(field)) + val index = fields.indexWhere(_.name == fieldName) + GetStructField(expr, index) } } -- GitLab