diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala similarity index 86% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 2b25ba03579ec2d3655ec40b8690d76cdf833272..73cc930c458327ac928179c1459d2c2fdfa789f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines all the expressions to extract values out of complex types. +// For example, getting a field out of an array, map, or struct. +//////////////////////////////////////////////////////////////////////////////////////////////////// + object ExtractValue { /** @@ -73,11 +78,10 @@ object ExtractValue { } } - def unapply(g: ExtractValue): Option[(Expression, Expression)] = { - g match { - case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case s: ExtractValueWithStruct => Some((s.child, null)) - } + def unapply(g: ExtractValue): Option[(Expression, Expression)] = g match { + case o: GetArrayItem => Some((o.child, o.ordinal)) + case o: GetMapValue => Some((o.child, o.key)) + case s: ExtractValueWithStruct => Some((s.child, null)) } /** @@ -117,6 +121,8 @@ abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue /** * Returns the value of fields in the Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetStructField(child: Expression, field: StructField, ordinal: Int) extends ExtractValueWithStruct { @@ -142,6 +148,8 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) /** * Returns the array of value of fields in the Array of Struct `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetArrayStructFields( child: Expression, @@ -178,25 +186,21 @@ case class GetArrayStructFields( } } -abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { - self: Product => +/** + * Returns the field at `ordinal` in the Array `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. + */ +case class GetArrayItem(child: Expression, ordinal: Expression) + extends BinaryExpression with ExtractValue { - def ordinal: Expression - def child: Expression + override def toString: String = s"$child[$ordinal]" override def left: Expression = child override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true - override def toString: String = s"$child[$ordinal]" -} - -/** - * Returns the field at `ordinal` in the Array `child` - */ -case class GetArrayItem(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType @@ -227,10 +231,20 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } /** - * Returns the value of key `ordinal` in Map `child` + * Returns the value of key `ordinal` in Map `child`. + * + * No need to do type checking since it is handled by [[ExtractValue]]. */ -case class GetMapValue(child: Expression, ordinal: Expression) - extends ExtractValueWithOrdinal { +case class GetMapValue(child: Expression, key: Expression) + extends BinaryExpression with ExtractValue { + + override def toString: String = s"$child[$key]" + + override def left: Expression = child + override def right: Expression = key + + /** `Null` is returned for invalid ordinals. */ + override def nullable: Boolean = true override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType