From c46aaf47f38163e9c7be671d7b8398512df34e62 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <cloud0fan@outlook.com> Date: Mon, 6 Jul 2015 22:13:50 -0700 Subject: [PATCH] [SPARK-8759][SQL] add default eval to binary and unary expression according to default behavior of nullable We have `nullSafeCodeGen` to provide default code generation for binary and unary expression, and we can do the same thing for `eval`. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7157 from cloud-fan/refactor and squashes the following commits: f3987c6 [Wenchen Fan] refactor Expression --- .../spark/sql/catalyst/expressions/Cast.scala | 7 +- .../sql/catalyst/expressions/Expression.scala | 69 ++++++- .../catalyst/expressions/ExtractValue.scala | 51 ++--- .../sql/catalyst/expressions/arithmetic.scala | 97 ++++----- .../sql/catalyst/expressions/bitwise.scala | 8 +- .../expressions/decimalFunctions.scala | 33 +--- .../spark/sql/catalyst/expressions/math.scala | 186 +++++------------- .../spark/sql/catalyst/expressions/misc.scala | 153 ++++++-------- .../sql/catalyst/expressions/predicates.scala | 84 +++----- .../spark/sql/catalyst/expressions/sets.scala | 9 +- .../expressions/stringOperations.scala | 138 +++---------- 11 files changed, 292 insertions(+), 543 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2d99d1a3fe..4f73ba40b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -114,8 +114,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } - override def foldable: Boolean = child.foldable - override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable override def toString: String = s"CAST($child, $dataType)" @@ -426,10 +424,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) - override def eval(input: InternalRow): Any = { - val evaluated = child.eval(input) - if (evaluated == null) null else cast(evaluated) - } + protected override def nullSafeEval(input: Any): Any = cast(input) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { // TODO: Add support for more data types. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cafbbafdca..386feb95b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -183,6 +183,27 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable + /** + * Default behavior of evaluation according to the default nullability of UnaryExpression. + * If subclass of UnaryExpression override nullable, probably should also override this. + */ + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + nullSafeEval(value) + } + } + + /** + * Called by default [[eval]] implementation. If subclass of UnaryExpression keep the default + * nullability, they can override this method to save null-check code. If we need full control + * of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(input: Any): Any = + sys.error(s"UnaryExpressions must override either eval or nullSafeEval") + /** * Called by unary expressions to generate a code block that returns null if its parent returns * null, and if not not null, use `f` to generate the expression. @@ -198,21 +219,24 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ctx: CodeGenContext, ev: GeneratedExpressionCode, f: String => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval) => { - s"$result = ${f(eval)};" + nullSafeCodeGen(ctx, ev, eval => { + s"${ev.primitive} = ${f(eval)};" }) } /** * Called by unary expressions to generate a code block that returns null if its parent returns * null, and if not not null, use `f` to generate the expression. + * + * @param f function that accepts the non-null evaluation result name of child and returns Java + * code to compute the output. */ protected def nullSafeCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + f: String => String): String = { val eval = child.gen(ctx) - val resultCode = f(ev.primitive, eval.primitive) + val resultCode = f(eval.primitive) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; @@ -235,6 +259,32 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def nullable: Boolean = left.nullable || right.nullable + /** + * Default behavior of evaluation according to the default nullability of BinaryExpression. + * If subclass of BinaryExpression override nullable, probably should also override this. + */ + override def eval(input: InternalRow): Any = { + val value1 = left.eval(input) + if (value1 == null) { + null + } else { + val value2 = right.eval(input) + if (value2 == null) { + null + } else { + nullSafeEval(value1, value2) + } + } + } + + /** + * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default + * nullability, they can override this method to save null-check code. If we need full control + * of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(input1: Any, input2: Any): Any = + sys.error(s"BinaryExpressions must override either eval or nullSafeEval") + /** * Short hand for generating binary evaluation code. * If either of the sub-expressions is null, the result of this computation @@ -246,8 +296,8 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ctx: CodeGenContext, ev: GeneratedExpressionCode, f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s"${ev.primitive} = ${f(eval1, eval2)};" }) } @@ -255,14 +305,17 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express * Short hand for generating binary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. + * + * @param f function that accepts the 2 non-null evaluation result names of children + * and returns Java code to compute the output. */ protected def nullSafeCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { + f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + val resultCode = f(eval1.primitive, eval2.primitive) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 3020e7fc96..e451c7ffbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -122,18 +122,16 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def eval(input: InternalRow): Any = { - val baseValue = child.eval(input).asInstanceOf[InternalRow] - if (baseValue == null) null else baseValue(ordinal) - } + protected override def nullSafeEval(input: Any): Any = + input.asInstanceOf[InternalRow](ordinal) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, eval) => { + nullSafeCodeGen(ctx, ev, eval => { s""" if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - $result = ${ctx.getColumn(eval, dataType, ordinal)}; + ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)}; } """ }) @@ -152,12 +150,9 @@ case class GetArrayStructFields( override def dataType: DataType = ArrayType(field.dataType, containsNull) override def nullable: Boolean = child.nullable || containsNull || field.nullable - override def eval(input: InternalRow): Any = { - val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] - if (baseValue == null) null else { - baseValue.map { row => - if (row == null) null else row(ordinal) - } + protected override def nullSafeEval(input: Any): Any = { + input.asInstanceOf[Seq[InternalRow]].map { row => + if (row == null) null else row(ordinal) } } @@ -165,7 +160,7 @@ case class GetArrayStructFields( val arraySeqClass = "scala.collection.mutable.ArraySeq" // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives - nullSafeCodeGen(ctx, ev, (result, eval) => { + nullSafeCodeGen(ctx, ev, eval => { s""" final int n = $eval.size(); final $arraySeqClass<Object> values = new $arraySeqClass<Object>(n); @@ -175,7 +170,7 @@ case class GetArrayStructFields( values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); } } - $result = (${ctx.javaType(dataType)}) values; + ${ev.primitive} = (${ctx.javaType(dataType)}) values; """ }) } @@ -193,22 +188,6 @@ abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValu /** `Null` is returned for invalid ordinals. */ override def nullable: Boolean = true override def toString: String = s"$child[$ordinal]" - - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - null - } else { - val o = ordinal.eval(input) - if (o == null) { - null - } else { - evalNotNull(value, o) - } - } - } - - protected def evalNotNull(value: Any, ordinal: Any): Any } /** @@ -219,7 +198,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType - protected def evalNotNull(value: Any, ordinal: Any) = { + protected override def nullSafeEval(value: Any, ordinal: Any): Any = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] @@ -232,13 +211,13 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int index = (int)$eval2; if (index >= $eval1.size() || index < 0) { ${ev.isNull} = true; } else { - $result = (${ctx.boxedType(dataType)})$eval1.apply(index); + ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index); } """ }) @@ -253,16 +232,16 @@ case class GetMapValue(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType - protected def evalNotNull(value: Any, ordinal: Any) = { + protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" if ($eval1.contains($eval2)) { - $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply($eval2); } else { ${ev.isNull} = true; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4fbf4c8700..dca6642665 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -26,18 +26,6 @@ abstract class UnaryArithmetic extends UnaryExpression { self: Product => override def dataType: DataType = child.dataType - - override def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - evalInternal(evalE) - } - } - - protected def evalInternal(evalE: Any): Any = - sys.error(s"UnaryArithmetics must override either eval or evalInternal") } case class UnaryMinus(child: Expression) extends UnaryArithmetic { @@ -53,7 +41,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") } - protected override def evalInternal(evalE: Any) = numeric.negate(evalE) + protected override def nullSafeEval(input: Any): Any = numeric.negate(input) } case class UnaryPositive(child: Expression) extends UnaryArithmetic { @@ -62,7 +50,7 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) - protected override def evalInternal(evalE: Any) = evalE + protected override def nullSafeEval(input: Any): Any = input } /** @@ -74,7 +62,7 @@ case class Abs(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def evalInternal(evalE: Any) = numeric.abs(evalE) + protected override def nullSafeEval(input: Any): Any = numeric.abs(input) } abstract class BinaryArithmetic extends BinaryOperator { @@ -94,20 +82,6 @@ abstract class BinaryArithmetic extends BinaryOperator { protected def checkTypesInternal(t: DataType): TypeCheckResult - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if(evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - evalInternal(evalE1, evalE2) - } - } - } - /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -122,9 +96,6 @@ abstract class BinaryArithmetic extends BinaryOperator { case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } - - protected def evalInternal(evalE1: Any, evalE2: Any): Any = - sys.error(s"BinaryArithmetics must override either eval or evalInternal") } private[sql] object BinaryArithmetic { @@ -143,7 +114,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.plus(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { @@ -158,7 +129,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.minus(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { @@ -173,7 +144,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = numeric.times(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { @@ -194,15 +165,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } override def eval(input: InternalRow): Any = { - val evalE2 = right.eval(input) - if (evalE2 == null || evalE2 == 0) { + val input2 = right.eval(input) + if (input2 == null || input2 == 0) { null } else { - val evalE1 = left.eval(input) - if (evalE1 == null) { + val input1 = left.eval(input) + if (input1 == null) { null } else { - div(evalE1, evalE2) + div(input1, input2) } } } @@ -260,15 +231,15 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } override def eval(input: InternalRow): Any = { - val evalE2 = right.eval(input) - if (evalE2 == null || evalE2 == 0) { + val input2 = right.eval(input) + if (input2 == null || input2 == 0) { null } else { - val evalE1 = left.eval(input) - if (evalE1 == null) { + val input1 = left.eval(input) + if (input1 == null) { null } else { - integral.rem(evalE1, evalE2) + integral.rem(input1, input2) } } } @@ -317,17 +288,17 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - val evalE2 = right.eval(input) - if (evalE1 == null) { - evalE2 - } else if (evalE2 == null) { - evalE1 + val input1 = left.eval(input) + val input2 = right.eval(input) + if (input1 == null) { + input2 + } else if (input2 == null) { + input1 } else { - if (ordering.compare(evalE1, evalE2) < 0) { - evalE2 + if (ordering.compare(input1, input2) < 0) { + input2 } else { - evalE1 + input1 } } } @@ -371,17 +342,17 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val ordering = TypeUtils.getOrdering(dataType) override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - val evalE2 = right.eval(input) - if (evalE1 == null) { - evalE2 - } else if (evalE2 == null) { - evalE1 + val input1 = left.eval(input) + val input2 = right.eval(input) + if (input1 == null) { + input2 + } else if (input2 == null) { + input1 } else { - if (ordering.compare(evalE1, evalE2) < 0) { - evalE1 + if (ordering.compare(input1, input2) < 0) { + input1 } else { - evalE2 + input2 } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index 9002dda7bf..2d47124d24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -45,7 +45,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any] } - protected override def evalInternal(evalE1: Any, evalE2: Any) = and(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) } /** @@ -70,7 +70,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any] } - protected override def evalInternal(evalE1: Any, evalE2: Any) = or(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) } /** @@ -95,7 +95,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any] } - protected override def evalInternal(evalE1: Any, evalE2: Any): Any = xor(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) } /** @@ -122,5 +122,5 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } - protected override def evalInternal(evalE: Any) = not(evalE) + protected override def nullSafeEval(input: Any): Any = not(input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index f5c2dde191..2fa74b4ffc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -30,14 +30,8 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" - override def eval(input: InternalRow): Any = { - val childResult = child.eval(input) - if (childResult == null) { - null - } else { - childResult.asInstanceOf[Decimal].toUnscaledLong - } - } + protected override def nullSafeEval(input: Any): Any = + input.asInstanceOf[Decimal].toUnscaledLong override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") @@ -54,26 +48,15 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def dataType: DataType = DecimalType(precision, scale) override def toString: String = s"MakeDecimal($child,$precision,$scale)" - override def eval(input: InternalRow): Decimal = { - val childResult = child.eval(input) - if (childResult == null) { - null - } else { - new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) - } - } + protected override def nullSafeEval(input: Any): Any = + Decimal(input.asInstanceOf[Long], precision, scale) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.decimalType} ${ev.primitive} = null; - - if (!${ev.isNull}) { - ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull( - ${eval.primitive}, $precision, $scale); + nullSafeCodeGen(ctx, ev, eval => { + s""" + ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale); ${ev.isNull} = ${ev.primitive} == null; - } """ + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 9250045398..9dca8513c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -61,21 +61,16 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def nullable: Boolean = true override def toString: String = s"$name($child)" - override def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - val result = f(evalE.asInstanceOf[Double]) - if (result.isNaN) null else result - } + protected override def nullSafeEval(input: Any): Any = { + val result = f(input.asInstanceOf[Double]) + if (result.isNaN) null else result } // name of function in java.lang.Math def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, eval) => { + nullSafeCodeGen(ctx, ev, eval => { s""" ${ev.primitive} = java.lang.Math.${funcName}($eval); if (Double.valueOf(${ev.primitive}).isNaN()) { @@ -101,19 +96,9 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) override def dataType: DataType = DoubleType - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - val result = f(evalE1.asInstanceOf[Double], evalE2.asInstanceOf[Double]) - if (result.isNaN) null else result - } - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) + if (result.isNaN) null else result } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -194,39 +179,29 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu override def dataType: DataType = LongType - override def foldable: Boolean = child.foldable - // If the value not in the range of [0, 20], it still will be null, so set it to be true here. override def nullable: Boolean = true - override def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null) { + protected override def nullSafeEval(input: Any): Any = { + val value = input.asInstanceOf[jl.Integer] + if (value > 20 || value < 0) { null } else { - val input = evalE.asInstanceOf[jl.Integer] - if (input > 20 || input < 0) { - null - } else { - Factorial.factorial(input) - } + Factorial.factorial(value) } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - if (${eval.primitive} > 20 || ${eval.primitive} < 0) { + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval > 20 || $eval < 0) { ${ev.isNull} = true; } else { ${ev.primitive} = - org.apache.spark.sql.catalyst.expressions.Factorial.factorial(${eval.primitive}); + org.apache.spark.sql.catalyst.expressions.Factorial.factorial($eval); } - } - """ + """ + }) } } @@ -235,17 +210,14 @@ case class Log(child: Expression) extends UnaryMathExpression(math.log, "LOG") case class Log2(child: Expression) extends UnaryMathExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.log(${eval.primitive}) / java.lang.Math.log(2); + nullSafeCodeGen(ctx, ev, eval => { + s""" + ${ev.primitive} = java.lang.Math.log($eval) / java.lang.Math.log(2); if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; } - } - """ + """ + }) } } @@ -283,14 +255,8 @@ case class Bin(child: Expression) override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = { - val evalE = child.eval(input) - if (evalE == null) { - null - } else { - UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long])) - } - } + protected override def nullSafeEval(input: Any): Any = + UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c) => @@ -326,17 +292,10 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes override def dataType: DataType = StringType - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - child.dataType match { - case LongType => hex(num.asInstanceOf[Long]) - case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String].getBytes) - } - } + protected override def nullSafeEval(num: Any): Any = child.dataType match { + case LongType => hex(num.asInstanceOf[Long]) + case BinaryType => hex(num.asInstanceOf[Array[Byte]]) + case StringType => hex(num.asInstanceOf[UTF8String].getBytes) } private[this] def hex(bytes: Array[Byte]): UTF8String = { @@ -377,14 +336,8 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp override def nullable: Boolean = true override def dataType: DataType = BinaryType - override def eval(input: InternalRow): Any = { - val num = child.eval(input) - if (num == null) { - null - } else { - unhex(num.asInstanceOf[UTF8String].getBytes) - } - } + protected override def nullSafeEval(num: Any): Any = + unhex(num.asInstanceOf[UTF8String].getBytes) private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { val out = new Array[Byte]((bytes.length + 1) >> 1) @@ -429,21 +382,10 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 + val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -480,25 +422,15 @@ case class ShiftLeft(left: Expression, right: Expression) override def dataType: DataType = left.dataType - override def eval(input: InternalRow): Any = { - val valueLeft = left.eval(input) - if (valueLeft != null) { - val valueRight = right.eval(input) - if (valueRight != null) { - valueLeft match { - case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer] - case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer] - } - } else { - null - } - } else { - null + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + input1 match { + case l: jl.Long => l << input2.asInstanceOf[jl.Integer] + case i: jl.Integer => i << input2.asInstanceOf[jl.Integer] } } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;") + defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } } @@ -516,25 +448,15 @@ case class ShiftRight(left: Expression, right: Expression) override def dataType: DataType = left.dataType - override def eval(input: InternalRow): Any = { - val valueLeft = left.eval(input) - if (valueLeft != null) { - val valueRight = right.eval(input) - if (valueRight != null) { - valueLeft match { - case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer] - case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer] - } - } else { - null - } - } else { - null + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + input1 match { + case l: jl.Long => l >> input2.asInstanceOf[jl.Integer] + case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer] } } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;") + defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } } @@ -552,25 +474,15 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) override def dataType: DataType = left.dataType - override def eval(input: InternalRow): Any = { - val valueLeft = left.eval(input) - if (valueLeft != null) { - val valueRight = right.eval(input) - if (valueRight != null) { - valueLeft match { - case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer] - case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer] - } - } else { - null - } - } else { - null + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + input1 match { + case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer] + case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer] } } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;") + defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e008af3966..3b59cd431b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -37,14 +37,8 @@ case class Md5(child: Expression) extends UnaryExpression with ExpectsInputTypes override def inputTypes: Seq[DataType] = Seq(BinaryType) - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - null - } else { - UTF8String.fromString(DigestUtils.md5Hex(value.asInstanceOf[Array[Byte]])) - } - } + protected override def nullSafeEval(input: Any): Any = + UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => @@ -67,76 +61,56 @@ case class Sha2(left: Expression, right: Expression) override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - val bitLength = evalE2.asInstanceOf[Int] - val input = evalE1.asInstanceOf[Array[Byte]] - bitLength match { - case 224 => - // DigestUtils doesn't support SHA-224 now - try { - val md = MessageDigest.getInstance("SHA-224") - md.update(input) - UTF8String.fromBytes(md.digest()) - } catch { - // SHA-224 is not supported on the system, return null - case noa: NoSuchAlgorithmException => null - } - case 256 | 0 => - UTF8String.fromString(DigestUtils.sha256Hex(input)) - case 384 => - UTF8String.fromString(DigestUtils.sha384Hex(input)) - case 512 => - UTF8String.fromString(DigestUtils.sha512Hex(input)) - case _ => null + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val bitLength = input2.asInstanceOf[Int] + val input = input1.asInstanceOf[Array[Byte]] + bitLength match { + case 224 => + // DigestUtils doesn't support SHA-224 now + try { + val md = MessageDigest.getInstance("SHA-224") + md.update(input) + UTF8String.fromBytes(md.digest()) + } catch { + // SHA-224 is not supported on the system, return null + case noa: NoSuchAlgorithmException => null } - } + case 256 | 0 => + UTF8String.fromString(DigestUtils.sha256Hex(input)) + case 384 => + UTF8String.fromString(DigestUtils.sha384Hex(input)) + case 512 => + UTF8String.fromString(DigestUtils.sha512Hex(input)) + case _ => null } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) val digestUtils = "org.apache.commons.codec.digest.DigestUtils" - - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - if (${eval2.primitive} == 224) { - try { - java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); - md.update(${eval1.primitive}); - ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); - } catch (java.security.NoSuchAlgorithmException e) { - ${ev.isNull} = true; - } - } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) { - ${ev.primitive} = - ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive})); - } else if (${eval2.primitive} == 384) { - ${ev.primitive} = - ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive})); - } else if (${eval2.primitive} == 512) { - ${ev.primitive} = - ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive})); - } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + if ($eval2 == 224) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); + md.update($eval1); + ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + } catch (java.security.NoSuchAlgorithmException e) { ${ev.isNull} = true; } + } else if ($eval2 == 256 || $eval2 == 0) { + ${ev.primitive} = + ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1)); + } else if ($eval2 == 384) { + ${ev.primitive} = + ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1)); + } else if ($eval2 == 512) { + ${ev.primitive} = + ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1)); } else { ${ev.isNull} = true; } - } - """ + """ + }) } } @@ -150,19 +124,12 @@ case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputType override def inputTypes: Seq[DataType] = Seq(BinaryType) - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - null - } else { - UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]])) - } - } + protected override def nullSafeEval(input: Any): Any = + UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]])) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - "org.apache.spark.unsafe.types.UTF8String.fromString" + - s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" ) } } @@ -177,30 +144,20 @@ case class Crc32(child: Expression) extends UnaryExpression with ExpectsInputTyp override def inputTypes: Seq[DataType] = Seq(BinaryType) - override def eval(input: InternalRow): Any = { - val value = child.eval(input) - if (value == null) { - null - } else { - val checksum = new CRC32 - checksum.update(value.asInstanceOf[Array[Byte]], 0, value.asInstanceOf[Array[Byte]].length) - checksum.getValue - } + protected override def nullSafeEval(input: Any): Any = { + val checksum = new CRC32 + checksum.update(input.asInstanceOf[Array[Byte]], 0, input.asInstanceOf[Array[Byte]].length) + checksum.getValue } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val value = child.gen(ctx) val CRC32 = "java.util.zip.CRC32" - s""" - ${value.code} - boolean ${ev.isNull} = ${value.isNull}; - long ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${CRC32} checksum = new ${CRC32}(); - checksum.update(${value.primitive}, 0, ${value.primitive}.length); + nullSafeCodeGen(ctx, ev, value => { + s""" + $CRC32 checksum = new $CRC32(); + checksum.update($value, 0, $value.length); ${ev.primitive} = checksum.getValue(); - } - """ + """ + }) } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 0b479f466c..402a0aa232 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -74,12 +74,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def eval(input: InternalRow): Any = { - child.eval(input) match { - case null => null - case b: Boolean => !b - } - } + protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") @@ -105,17 +100,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { * Optimized version of In clause, when all filter values of In clause are * static. */ -case class InSet(value: Expression, hset: Set[Any]) - extends Predicate { - - override def children: Seq[Expression] = value :: Nil +case class InSet(child: Expression, hset: Set[Any]) + extends UnaryExpression with Predicate { - override def foldable: Boolean = value.foldable override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN. - override def toString: String = s"$value INSET ${hset.mkString("(", ",", ")")}" + override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { - hset.contains(value.eval(input)) + hset.contains(child.eval(input)) } } @@ -127,15 +119,15 @@ case class And(left: Expression, right: Expression) override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def eval(input: InternalRow): Any = { - val l = left.eval(input) - if (l == false) { + val input1 = left.eval(input) + if (input1 == false) { false } else { - val r = right.eval(input) - if (r == false) { + val input2 = right.eval(input) + if (input2 == false) { false } else { - if (l != null && r != null) { + if (input1 != null && input2 != null) { true } else { null @@ -176,15 +168,15 @@ case class Or(left: Expression, right: Expression) override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def eval(input: InternalRow): Any = { - val l = left.eval(input) - if (l == true) { + val input1 = left.eval(input) + if (input1 == true) { true } else { - val r = right.eval(input) - if (r == true) { + val input2 = right.eval(input) + if (input2 == true) { true } else { - if (l != null && r != null) { + if (input1 != null && input2 != null) { false } else { null @@ -232,20 +224,6 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { protected def checkTypesInternal(t: DataType): TypeCheckResult - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - evalInternal(evalE1, evalE2) - } - } - } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isPrimitiveType(left.dataType)) { // faster version @@ -254,9 +232,6 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { defineCodeGen(ctx, ev, (c1, c2) => s"${ctx.genComp(left.dataType, c1, c2)} $symbol 0") } } - - protected def evalInternal(evalE1: Any, evalE2: Any): Any = - sys.error(s"BinaryComparisons must override either eval or evalInternal") } private[sql] object BinaryComparison { @@ -277,9 +252,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess - protected override def evalInternal(l: Any, r: Any) = { - if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + if (left.dataType != BinaryType) input1 == input2 + else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -295,15 +270,18 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess override def eval(input: InternalRow): Any = { - val l = left.eval(input) - val r = right.eval(input) - if (l == null && r == null) { + val input1 = left.eval(input) + val input2 = right.eval(input) + if (input1 == null && input2 == null) { true - } else if (l == null || r == null) { + } else if (input1 == null || input2 == null) { false } else { - if (left.dataType != BinaryType) l == r - else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) + if (left.dataType != BinaryType) { + input1 == input2 + } else { + java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + } } } @@ -327,7 +305,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso private lazy val ordering = TypeUtils.getOrdering(left.dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lt(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { @@ -338,7 +316,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo private lazy val ordering = TypeUtils.getOrdering(left.dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.lteq(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { @@ -349,7 +327,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar private lazy val ordering = TypeUtils.getOrdering(left.dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gt(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { @@ -360,5 +338,5 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar private lazy val ordering = TypeUtils.getOrdering(left.dataType) - protected override def evalInternal(evalE1: Any, evalE2: Any) = ordering.gteq(evalE1, evalE2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 5d51a4ca65..9b44fb1ed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -135,6 +135,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { + override def nullable: Boolean = left.nullable override def dataType: DataType = left.dataType override def eval(input: InternalRow): Any = { @@ -183,12 +184,8 @@ case class CountSet(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType - override def eval(input: InternalRow): Any = { - val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]] - if (childEval != null) { - childEval.size.toLong - } - } + protected override def nullSafeEval(input: Any): Any = + input.asInstanceOf[OpenHashSet[Any]].size.toLong override def toString: String = s"$child.count()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 1a14a7a449..6e6a7fb171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -31,7 +31,6 @@ trait StringRegexExpression extends ExpectsInputTypes { def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType override def inputTypes: Seq[DataType] = Seq(StringType, StringType) @@ -50,22 +49,12 @@ trait StringRegexExpression extends ExpectsInputTypes { protected def pattern(str: String) = if (cache == null) compile(str) else cache - override def eval(input: InternalRow): Any = { - val l = left.eval(input) - if (l == null) { + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString()) + if(regex == null) { null } else { - val r = right.eval(input) - if(r == null) { - null - } else { - val regex = pattern(r.asInstanceOf[UTF8String].toString()) - if(regex == null) { - null - } else { - matches(regex, l.asInstanceOf[UTF8String].toString()) - } - } + matches(regex, input1.asInstanceOf[UTF8String].toString()) } } } @@ -120,14 +109,8 @@ trait CaseConversionExpression extends ExpectsInputTypes { override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType) - override def eval(input: InternalRow): Any = { - val evaluated = child.eval(input) - if (evaluated == null) { - null - } else { - convert(evaluated.asInstanceOf[UTF8String]) - } - } + protected override def nullSafeEval(input: Any): Any = + convert(input.asInstanceOf[UTF8String]) } /** @@ -160,20 +143,10 @@ trait StringComparison extends ExpectsInputTypes { def compare(l: UTF8String, r: UTF8String): Boolean - override def nullable: Boolean = left.nullable || right.nullable - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def eval(input: InternalRow): Any = { - val leftEval = left.eval(input) - if(leftEval == null) { - null - } else { - val rightEval = right.eval(input) - if (rightEval == null) null - else compare(leftEval.asInstanceOf[UTF8String], rightEval.asInstanceOf[UTF8String]) - } - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = + compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) override def toString: String = s"$nodeName($left, $right)" } @@ -288,10 +261,8 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) - override def eval(input: InternalRow): Any = { - val string = child.eval(input) - if (string == null) null else string.asInstanceOf[UTF8String].length - } + protected override def nullSafeEval(string: Any): Any = + string.asInstanceOf[UTF8String].length override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") @@ -310,24 +281,13 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = IntegerType - override def eval(input: InternalRow): Any = { - val leftValue = left.eval(input) - if (leftValue == null) { - null - } else { - val rightValue = right.eval(input) - if(rightValue == null) { - null - } else { - StringUtils.getLevenshteinDistance(leftValue.toString, rightValue.toString) - } - } - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = + StringUtils.getLevenshteinDistance(input1.toString, input2.toString) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val stringUtils = classOf[StringUtils].getName - nullSafeCodeGen(ctx, ev, (res, left, right) => - s"$res = $stringUtils.getLevenshteinDistance($left.toString(), $right.toString());") + defineCodeGen(ctx, ev, (left, right) => + s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())") } } @@ -338,17 +298,12 @@ case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTyp override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType) - override def eval(input: InternalRow): Any = { - val string = child.eval(input) - if (string == null) { - null + protected override def nullSafeEval(string: Any): Any = { + val bytes = string.asInstanceOf[UTF8String].getBytes + if (bytes.length > 0) { + bytes(0).asInstanceOf[Int] } else { - val bytes = string.asInstanceOf[UTF8String].getBytes - if (bytes.length > 0) { - bytes(0).asInstanceOf[Int] - } else { - 0 - } + 0 } } } @@ -360,15 +315,10 @@ case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTy override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType) - override def eval(input: InternalRow): Any = { - val bytes = child.eval(input) - if (bytes == null) { - null - } else { - UTF8String.fromBytes( - org.apache.commons.codec.binary.Base64.encodeBase64( - bytes.asInstanceOf[Array[Byte]])) - } + protected override def nullSafeEval(bytes: Any): Any = { + UTF8String.fromBytes( + org.apache.commons.codec.binary.Base64.encodeBase64( + bytes.asInstanceOf[Array[Byte]])) } } @@ -379,14 +329,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType) - override def eval(input: InternalRow): Any = { - val string = child.eval(input) - if (string == null) { - null - } else { - org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - } - } + protected override def nullSafeEval(string: Any): Any = + org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) } /** @@ -402,19 +346,9 @@ case class Decode(bin: Expression, charset: Expression) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType) - override def eval(input: InternalRow): Any = { - val l = bin.eval(input) - if (l == null) { - null - } else { - val r = charset.eval(input) - if (r == null) { - null - } else { - val fromCharset = r.asInstanceOf[UTF8String].toString - UTF8String.fromString(new String(l.asInstanceOf[Array[Byte]], fromCharset)) - } - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val fromCharset = input2.asInstanceOf[UTF8String].toString + UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } } @@ -431,19 +365,9 @@ case class Encode(value: Expression, charset: Expression) override def dataType: DataType = BinaryType override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - override def eval(input: InternalRow): Any = { - val l = value.eval(input) - if (l == null) { - null - } else { - val r = charset.eval(input) - if (r == null) { - null - } else { - val toCharset = r.asInstanceOf[UTF8String].toString - l.asInstanceOf[UTF8String].toString.getBytes(toCharset) - } - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val toCharset = input2.asInstanceOf[UTF8String].toString + input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } } -- GitLab