diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 4a1a1ed61ebe76b53cecf24d3ceb4f0df6c968a6..0daee1990a6e0a04781b9ed769b94cf30f5fcb19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.{errors, trees} -import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.errors import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.trees.TreeNode @@ -50,7 +49,7 @@ case class UnresolvedRelation( /** * Holds the name of an attribute that has yet to be resolved. */ -case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute { +case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Unevaluable { def name: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") @@ -66,10 +65,6 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute { override def withQualifiers(newQualifiers: Seq[String]): UnresolvedAttribute = this override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName) - // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"'$name" } @@ -78,16 +73,14 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { +case class UnresolvedFunction(name: String, children: Seq[Expression]) + extends Expression with Unevaluable { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"'$name(${children.mkString(",")})" } @@ -105,10 +98,6 @@ abstract class Star extends LeafExpression with NamedExpression { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override lazy val resolved = false - // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] } @@ -120,7 +109,7 @@ abstract class Star extends LeafExpression with NamedExpression { * @param table an optional table that should be the target of the expansion. If omitted all * tables' columns are produced. */ -case class UnresolvedStar(table: Option[String]) extends Star { +case class UnresolvedStar(table: Option[String]) extends Star with Unevaluable { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { @@ -149,7 +138,7 @@ case class UnresolvedStar(table: Option[String]) extends Star { * @param names the names to be associated with each output of computing [[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression { + extends UnaryExpression with NamedExpression with CodegenFallback { override def name: String = throw new UnresolvedException(this, "name") @@ -165,9 +154,6 @@ case class MultiAlias(child: Expression, names: Seq[String]) override lazy val resolved = false - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child AS $names" } @@ -178,7 +164,7 @@ case class MultiAlias(child: Expression, names: Seq[String]) * * @param expressions Expressions to expand. */ -case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { +case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable { override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")") } @@ -192,23 +178,21 @@ case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { * can be key of Map, index of Array, field name of Struct. */ case class UnresolvedExtractValue(child: Expression, extraction: Expression) - extends UnaryExpression { + extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child[$extraction]" } /** * Holds the expression that has yet to be aliased. */ -case class UnresolvedAlias(child: Expression) extends UnaryExpression with NamedExpression { +case class UnresolvedAlias(child: Expression) + extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") @@ -218,7 +202,4 @@ case class UnresolvedAlias(child: Expression) extends UnaryExpression with Named override def name: String = throw new UnresolvedException(this, "name") override lazy val resolved = false - - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 692b9fddbb0410ee33c7fb4fad35ec871b7f6320..3346d3c9f9e617bf46453501e2e6a69194b4eac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} -import java.sql.{Date, Timestamp} -import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{Interval, UTF8String} @@ -106,7 +104,8 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { +case class Cast(child: Expression, dataType: DataType) + extends UnaryExpression with CodegenFallback { override def checkInputDataTypes(): TypeCheckResult = { if (Cast.canCast(child.dataType, dataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0e128d8bdcd9699d24493d07dd2102cb4afa5b57..d0a1aa9a1e9120b76e60d93a0fd297e89538fb88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -101,19 +101,7 @@ abstract class Expression extends TreeNode[Expression] { * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - ctx.references += this - val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this} */ - Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ - } + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -182,6 +170,20 @@ abstract class Expression extends TreeNode[Expression] { } +/** + * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization + * time (e.g. Star). This trait is used by those expressions. + */ +trait Unevaluable { self: Expression => + + override def eval(input: InternalRow = null): Any = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + + /** * A leaf expression, i.e. one without any child expressions. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 22687acd68a97216ddb4baed4ba0a774e653a292..11c7950c0613b259ccc72a3a591aafbc51201b6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.DataType /** @@ -29,7 +30,8 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes { + inputTypes: Seq[DataType] = Nil) + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index b8f7068c9e5e5dac2e5a85ed8757548f8fc69a78..3f436c0eb893cf711e42aada57298a27a79b84b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,9 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types.DataType abstract sealed class SortDirection @@ -30,7 +27,8 @@ case object Descending extends SortDirection * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. */ -case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { +case class SortOrder(child: Expression, direction: SortDirection) + extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. */ override def foldable: Boolean = false @@ -38,9 +36,5 @@ case class SortOrder(child: Expression, direction: SortDirection) extends UnaryE override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable - // SortOrder itself is never evaluated. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index af9a674ab4958043892a893d8cd6efcb913cca82..d705a1286065cd0262395431b1712339dee3e8b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -trait AggregateExpression extends Expression { + +trait AggregateExpression extends Expression with Unevaluable { /** * Aggregate expressions should not be foldable. @@ -38,13 +39,6 @@ trait AggregateExpression extends Expression { * of input rows/ */ def newInstance(): AggregateFunction - - /** - * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are - * replaced with a physical aggregate operator at runtime. - */ - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index e83650fc8cb0eb9e49d2ac1133af64a0650fb5de..05b5ad88fee8f700c16eec101b7e89967edd2f01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.Interval @@ -65,7 +65,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects /** * A function that get the absolute value of the numeric value. */ -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with CodegenFallback { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala new file mode 100644 index 0000000000000000000000000000000000000000..bf4f600cb26e5cde9a0159a2fb829dd7480e444f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * A trait that can be used to provide a fallback mode for expression code generation. + */ +trait CodegenFallback { self: Expression => + + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ctx.references += this + val objectTerm = ctx.freshName("obj") + s""" + /* expression: ${this} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(i); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.primitive} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d1e4c458864f1b752745fd0185570ad09a43bfee..f9fd04c02aaef924e0de0e65887f92284a753735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Returns an Array containing the evaluation of all children expressions. */ -case class CreateArray(children: Seq[Expression]) extends Expression { +case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback { override def foldable: Boolean = children.forall(_.foldable) @@ -51,7 +52,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a Row containing the evaluation of all children expressions. * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[Expression]) extends Expression { +case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback { override def foldable: Boolean = children.forall(_.foldable) @@ -83,7 +84,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ -case class CreateNamedStruct(children: Seq[Expression]) extends Expression { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -103,11 +104,11 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { override def checkInputDataTypes(): TypeCheckResult = { if (children.size % 2 != 0) { - TypeCheckResult.TypeCheckFailure("CreateNamedStruct expects an even number of arguments.") + TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) - if (invalidNames.size != 0) { + if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( s"Odd position only allow foldable and not-null StringType expressions, got :" + s" ${invalidNames.mkString(",")}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index dd5ec330a771bfaa221b574f20e0a2c710ad1c14..4bed140cffbfa6a98d7aee8d140906368e958c34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.types._ * * There is no code generation since this expression should get constant folded by the optimizer. */ -case class CurrentDate() extends LeafExpression { +case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -44,7 +45,7 @@ case class CurrentDate() extends LeafExpression { * * There is no code generation since this expression should get constant folded by the optimizer. */ -case class CurrentTimestamp() extends LeafExpression { +case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c58a6d36141c1b34cf9aa3e7139da68bfa9efadc..2dbcf2830f876cfbe566629518d2434bf0c07681 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** @@ -73,7 +73,7 @@ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) - extends Generator { + extends Generator with CodegenFallback { @transient private[this] var inputRow: InterpretedProjection = _ @transient private[this] var convertToScala: (InternalRow) => Row = _ @@ -100,7 +100,7 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ -case class Explode(child: Expression) extends UnaryExpression with Generator { +case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e1fdb29541fa850c3d632de95598e7c376faad0c..f25ac32679587da0fed5ac7d5b2d67426efadb37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ @@ -75,7 +75,8 @@ object IntegerLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal protected (value: Any, dataType: DataType) extends LeafExpression { +case class Literal protected (value: Any, dataType: DataType) + extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -142,7 +143,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres // TODO: Specialize case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) - extends LeafExpression { + extends LeafExpression with CodegenFallback { def update(expression: Expression, input: InternalRow): Unit = { value = expression.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index eb5c065a34123370f2067843d44610110ef0059c..7ce64d29ba59ab80ca382dffefc1903075a0ba9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} -import java.util.Arrays import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} @@ -29,11 +28,14 @@ import org.apache.spark.unsafe.types.UTF8String /** * A leaf expression specifically for math constants. Math constants expect no input. + * + * There is no code generation because they should get constant folded by the optimizer. + * * @param c The math constant. * @param name The short name of the function */ abstract class LeafMathExpression(c: Double, name: String) - extends LeafExpression with Serializable { + extends LeafExpression with CodegenFallback { override def dataType: DataType = DoubleType override def foldable: Boolean = true @@ -41,13 +43,6 @@ abstract class LeafMathExpression(c: Double, name: String) override def toString: String = s"$name()" override def eval(input: InternalRow): Any = c - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - s""" - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.primitive} = java.lang.Math.$name; - """ - } } /** @@ -130,8 +125,16 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +/** + * Euler's number. Note that there is no code generation because this is only + * evaluated by the optimizer during constant folding. + */ case class EulerNumber() extends LeafMathExpression(math.E, "E") +/** + * Pi. Note that there is no code generation because this is only + * evaluated by the optimizer during constant folding. + */ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -161,7 +164,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes{ + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable @@ -171,6 +174,8 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) + override def dataType: DataType = StringType + /** Returns the result of evaluating this expression on a given input Row */ override def eval(input: InternalRow): Any = { val num = numExpr.eval(input) @@ -179,17 +184,13 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre if (num == null || fromBase == null || toBase == null) { null } else { - conv(num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int]) + conv( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) } } - /** - * Returns the [[DataType]] of the result of evaluating this expression. It is - * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). - */ - override def dataType: DataType = StringType - private val value = new Array[Byte](64) /** @@ -208,7 +209,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre // Two's complement => x = uval - 2*MAX - 2 // => uval = x + 2*MAX + 2 // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c - (x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m) + x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m } } @@ -220,7 +221,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre */ private def decode(v: Long, radix: Int): Unit = { var tmpV = v - Arrays.fill(value, 0.asInstanceOf[Byte]) + java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) var i = value.length - 1 while (tmpV != 0) { val q = unsignedLongDiv(tmpV, radix) @@ -254,7 +255,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre v = v * radix + value(i) i += 1 } - return v + v } /** @@ -292,16 +293,16 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv */ private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { - if (n == null || fromBase == null || toBase == null || n.isEmpty) { - return null - } - if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { return null } + if (n.length == 0) { + return null + } + var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) // Copy the digits in the right side of the array @@ -340,7 +341,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre resultStartPos = firstNonZeroPos - 1 value(resultStartPos) = '-' } - UTF8String.fromBytes( Arrays.copyOfRange(value, resultStartPos, value.length)) + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) } } @@ -495,8 +496,8 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - // TODO: Create code-gen version. +case class Hex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, BinaryType, StringType)) @@ -539,8 +540,8 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ -case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { - // TODO: Create code-gen version. +case class Unhex(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index c083ac08ded2198f3ab116eabca3953af896ef97..6f173b52ad9b938ab190fc67c5e5100bb14c5952 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ object NamedExpression { @@ -122,7 +120,9 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) + /** Just a simple passthrough for code generation. */ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable @@ -177,7 +177,7 @@ case class AttributeReference( override val metadata: Metadata = Metadata.empty)( val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) - extends Attribute { + extends Attribute with Unevaluable { /** * Returns true iff the expression id is the same for both attributes. @@ -236,10 +236,6 @@ case class AttributeReference( } } - // Unresolved attributes are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString: String = s"$name#${exprId.id}$typeSuffix" } @@ -247,7 +243,7 @@ case class AttributeReference( * A place holder used when printing expressions without debugging information such as the * expression id or the unresolved indicator. */ -case class PrettyAttribute(name: String) extends Attribute { +case class PrettyAttribute(name: String) extends Attribute with Unevaluable { override def toString: String = name @@ -259,7 +255,6 @@ case class PrettyAttribute(name: String) extends Attribute { override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifiers: Seq[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = throw new UnsupportedOperationException override def dataType: DataType = NullType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bddd2a9eccfc0b72fd3f3463f3bdf14e5628ae3f..40ec3df224ce14750db7f96d3c48713e50273f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ + object InterpretedPredicate { def create(expression: Expression, inputSchema: Seq[Attribute]): (InternalRow => Boolean) = create(BindReferences.bindReference(expression, inputSchema)) @@ -91,7 +92,7 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ -case class In(value: Expression, list: Seq[Expression]) extends Predicate { +case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback { override def children: Seq[Expression] = value +: list override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. @@ -109,7 +110,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * static. */ case class InSet(child: Expression, hset: Set[Any]) - extends UnaryExpression with Predicate { + extends UnaryExpression with Predicate with CodegenFallback { override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 49b2026364cd68e6ddeb49ba6501e3ca33f2597c..5b0fe8dfe2fc886918008c41a82e0e8e53a526ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -52,7 +52,7 @@ private[sql] class OpenHashSetUDT( /** * Creates a new set of the specified type */ -case class NewSet(elementType: DataType) extends LeafExpression { +case class NewSet(elementType: DataType) extends LeafExpression with CodegenFallback { override def nullable: Boolean = false @@ -82,7 +82,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class AddItemToSet(item: Expression, set: Expression) extends Expression { +case class AddItemToSet(item: Expression, set: Expression) + extends Expression with CodegenFallback { override def children: Seq[Expression] = item :: set :: Nil @@ -134,7 +135,8 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { +case class CombineSets(left: Expression, right: Expression) + extends BinaryExpression with CodegenFallback { override def nullable: Boolean = left.nullable override def dataType: DataType = left.dataType @@ -181,7 +183,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres * Note: this expression is internal and created only by the GeneratedAggregate, * we don't need to do type check for it. */ -case class CountSet(child: Expression) extends UnaryExpression { +case class CountSet(child: Expression) extends UnaryExpression with CodegenFallback { override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index b36354eff092ae552f60d09cc58401cddee0840e..560b1bc2d889f12fa1daaf8322caac3b2f7429e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -103,7 +103,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes { * Simple RegEx pattern matching function */ case class Like(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + extends BinaryExpression with StringRegexExpression with CodegenFallback { // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character @@ -133,14 +133,16 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" } + case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { + extends BinaryExpression with StringRegexExpression with CodegenFallback { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" } + trait String2StringExpression extends ImplicitCastInputTypes { self: UnaryExpression => @@ -156,7 +158,8 @@ trait String2StringExpression extends ImplicitCastInputTypes { /** * A function that converts the characters of a string to uppercase. */ -case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { +case class Upper(child: Expression) + extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toUpperCase @@ -301,7 +304,7 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) @@ -342,7 +345,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. */ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -380,7 +383,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. */ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -417,9 +420,9 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression { +case class StringFormat(children: Expression*) extends Expression with CodegenFallback { - require(children.length >=1, "printf() should take at least 1 argument") + require(children.nonEmpty, "printf() should take at least 1 argument") override def foldable: Boolean = children.forall(_.foldable) override def nullable: Boolean = children(0).nullable @@ -436,7 +439,7 @@ case class StringFormat(children: Expression*) extends Expression { val formatter = new java.util.Formatter(sb, Locale.US) val arglist = args.map(_.eval(input).asInstanceOf[AnyRef]) - formatter.format(pattern.asInstanceOf[UTF8String].toString(), arglist: _*) + formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*) UTF8String.fromString(sb.toString) } @@ -483,7 +486,8 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ -case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class StringSpace(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -503,7 +507,7 @@ case class StringSpace(child: Expression) extends UnaryExpression with ImplicitC * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { override def left: Expression = str override def right: Expression = pattern @@ -524,7 +528,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes { + extends Expression with ImplicitCastInputTypes with CodegenFallback { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -606,8 +610,6 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") } } - - override def prettyName: String = "length" } /** @@ -632,7 +634,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * Returns the numeric value of the first character of str. */ -case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Ascii(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -649,7 +653,9 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp /** * Converts the argument from binary to a base 64 string. */ -case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Base64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) @@ -663,7 +669,9 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Converts the argument from a base 64 string to BINARY. */ -case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class UnBase64(child: Expression) + extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { + override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) @@ -677,7 +685,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast * If either argument is null, the result will also be null. */ case class Decode(bin: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { override def left: Expression = bin override def right: Expression = charset @@ -696,7 +704,7 @@ case class Decode(bin: Expression, charset: Expression) * If either argument is null, the result will also be null. */ case class Encode(value: Expression, charset: Expression) - extends BinaryExpression with ImplicitCastInputTypes { + extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { override def left: Expression = value override def right: Expression = charset @@ -715,7 +723,7 @@ case class Encode(value: Expression, charset: Expression) * fractional part. */ case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ExpectsInputTypes with CodegenFallback { override def left: Expression = x override def right: Expression = d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index c8aa571df64fc59ff3ee568a2497c42b4fd5f427..50bbfd644d30270b4516c1aee70ba16e9ca93e14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.types.{DataType, NumericType} /** @@ -37,7 +36,7 @@ sealed trait WindowSpec case class WindowSpecDefinition( partitionSpec: Seq[Expression], orderSpec: Seq[SortOrder], - frameSpecification: WindowFrame) extends Expression with WindowSpec { + frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable { def validate: Option[String] = frameSpecification match { case UnspecifiedFrame => @@ -75,7 +74,6 @@ case class WindowSpecDefinition( override def toString: String = simpleString - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def nullable: Boolean = true override def foldable: Boolean = false override def dataType: DataType = throw new UnsupportedOperationException @@ -274,60 +272,43 @@ trait WindowFunction extends Expression { case class UnresolvedWindowFunction( name: String, children: Seq[Expression]) - extends Expression with WindowFunction { + extends Expression with WindowFunction with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - override def init(): Unit = - throw new UnresolvedException(this, "init") - override def reset(): Unit = - throw new UnresolvedException(this, "reset") + override def init(): Unit = throw new UnresolvedException(this, "init") + override def reset(): Unit = throw new UnresolvedException(this, "reset") override def prepareInputParameters(input: InternalRow): AnyRef = throw new UnresolvedException(this, "prepareInputParameters") - override def update(input: AnyRef): Unit = - throw new UnresolvedException(this, "update") + override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update") override def batchUpdate(inputs: Array[AnyRef]): Unit = throw new UnresolvedException(this, "batchUpdate") - override def evaluate(): Unit = - throw new UnresolvedException(this, "evaluate") - override def get(index: Int): Any = - throw new UnresolvedException(this, "get") - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate") + override def get(index: Int): Any = throw new UnresolvedException(this, "get") override def toString: String = s"'$name(${children.mkString(",")})" - override def newInstance(): WindowFunction = - throw new UnresolvedException(this, "newInstance") + override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance") } case class UnresolvedWindowExpression( child: UnresolvedWindowFunction, - windowSpec: WindowSpecReference) extends UnaryExpression { + windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false - - // Unresolved functions are transient at compile time and don't get evaluated during execution. - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } case class WindowExpression( windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression { - - override def children: Seq[Expression] = - windowFunction :: windowSpec :: Nil + windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil override def dataType: DataType = windowFunction.dataType override def foldable: Boolean = windowFunction.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 42dead7c28425b77d7dea1d531c4674b33d4e233..2dcfa19fec383dadaf568d10edbf9b165016ec1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -146,8 +144,7 @@ case object BroadcastPartitioning extends Partitioning { * in the same partition. */ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) - extends Expression - with Partitioning { + extends Expression with Partitioning with Unevaluable { override def children: Seq[Expression] = expressions override def nullable: Boolean = false @@ -169,9 +166,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) } override def keyExpressions: Seq[Expression] = expressions - - override def eval(input: InternalRow = null): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** @@ -187,8 +181,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * into its child. */ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) - extends Expression - with Partitioning { + extends Expression with Partitioning with Unevaluable { override def children: Seq[SortOrder] = ordering override def nullable: Boolean = false @@ -213,7 +206,4 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) } override def keyExpressions: Seq[Expression] = ordering.map(_.child) - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 2147d07e09bd30f0608b70f31893e300225ba033..dca8c881f21ab751a490cd123b626da94abf7c06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.types._ case class TestFunction( children: Seq[Expression], - inputTypes: Seq[AbstractDataType]) extends Expression with ImplicitCastInputTypes { + inputTypes: Seq[AbstractDataType]) + extends Expression with ImplicitCastInputTypes with Unevaluable { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def dataType: DataType = StringType } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index c9b3c69c6de89dc2bcb956c0162843c4a14ece7c..f9442bccc4a7aa077fb6a2506dc06a0b2f6de3ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -363,26 +363,26 @@ class HiveTypeCoercionSuite extends PlanTest { object HiveTypeCoercionSuite { case class AnyTypeUnaryExpression(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def dataType: DataType = NullType } case class NumericTypeUnaryExpression(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) override def dataType: DataType = NullType } case class AnyTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator { + extends BinaryOperator with Unevaluable { override def dataType: DataType = NullType override def inputType: AbstractDataType = AnyDataType override def symbol: String = "anytype" } case class NumericTypeBinaryOperator(left: Expression, right: Expression) - extends BinaryOperator { + extends BinaryOperator with Unevaluable { override def dataType: DataType = NullType override def inputType: AbstractDataType = NumericType override def symbol: String = "numerictype" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1bd7d4e5cdf0fdab6fade6af9598f432224aa95d..8fff39906b3426fd71aa90627aa135aeef15672d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types.{IntegerType, StringType, NullType} -case class Dummy(optKey: Option[Expression]) extends Expression { +case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { override def children: Seq[Expression] = optKey.toSeq override def nullable: Boolean = true override def dataType: NullType = NullType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 6d6e67dace177222166653f0191cdf70f3757699..e6e27a87c715159c4d5cfca735a75d78aa6dbeeb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -51,15 +51,11 @@ private[spark] case class PythonUDF( broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, - children: Seq[Expression]) extends Expression with SparkLogging { + children: Seq[Expression]) extends Expression with Unevaluable with SparkLogging { override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") - } } /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 0bc8adb16afc0fa3821a95db57a223e4a24b1930..4d23c7035c03d3e4233c40af39644257053ada56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.hive.HiveShim._ @@ -81,7 +81,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) } private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + extends Expression with HiveInspectors with CodegenFallback with Logging { type UDFType = UDF @@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) } private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { + extends Expression with HiveInspectors with CodegenFallback with Logging { type UDFType = GenericUDF override def deterministic: Boolean = isUDFDeterministic @@ -166,8 +166,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr @transient protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - (udfType != null && udfType.deterministic()) + val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() } override def foldable: Boolean = @@ -301,7 +301,7 @@ private[hive] case class HiveWindowFunction( pivotResult: Boolean, isUDAFBridgeRequired: Boolean, children: Seq[Expression]) extends WindowFunction - with HiveInspectors { + with HiveInspectors with Unevaluable { // Hive window functions are based on GenericUDAFResolver2. type UDFType = GenericUDAFResolver2 @@ -330,7 +330,7 @@ private[hive] case class HiveWindowFunction( evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } - def dataType: DataType = + override def dataType: DataType = if (!pivotResult) { inspectorToDataType(returnInspector) } else { @@ -344,10 +344,7 @@ private[hive] case class HiveWindowFunction( } } - def nullable: Boolean = true - - override def eval(input: InternalRow): Any = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def nullable: Boolean = true @transient lazy val inputProjection = new InterpretedProjection(children) @@ -406,7 +403,7 @@ private[hive] case class HiveWindowFunction( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - override def newInstance: WindowFunction = + override def newInstance(): WindowFunction = new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } @@ -476,7 +473,7 @@ private[hive] case class HiveUDAF( /** * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a - * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow + * [[Generator]]. Note that the semantics of Generators do not allow * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning * dependent operations like calls to `close()` before producing output will not operate the same as * in Hive. However, in practice this should not affect compatibility for most sane UDTFs @@ -488,7 +485,7 @@ private[hive] case class HiveUDAF( private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Generator with HiveInspectors { + extends Generator with HiveInspectors with CodegenFallback { @transient protected lazy val function: GenericUDTF = {