diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 005de3166095ff242e044098685f50538fa2e5a3..fcadf9595e768b60383e9acdbad6fbf921b38680 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.trees @@ -43,7 +43,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = i.isNullAt($ordinal); ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 2a1f96409daf4bc9f195aa1099e62f0d8eb053dd..18102d1acb5b3a267ae07970419b51811ee72805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -435,7 +435,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w if (evaluated == null) null else cast(evaluated) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { // TODO(cg): Add support for more data types. (child.dataType, dataType) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 432d65eee54fb5e5a2ca8e1e147e7bd2cd372186..a9a9c0cfb7027f5f93d5be0c47c9143dd744effe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext, Term} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ @@ -76,7 +76,7 @@ abstract class Expression extends TreeNode[Expression] { * @param ev an [[GeneratedExpressionCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ctx.references += this val objectTerm = ctx.freshName("obj") s""" @@ -166,7 +166,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express protected def defineCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: (Term, Term) => Code): String = { + f: (String, String) => String): String = { // TODO: Right now some timestamp tests fail if we enforce this... if (left.dataType != right.dataType) { // log.warn(s"${left.dataType} != ${right.dataType}") @@ -182,7 +182,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${eval2.code} - if(!${eval2.isNull}) { + if (!${eval2.isNull}) { ${ev.primitive} = $resultCode; } else { ${ev.isNull} = true; @@ -217,7 +217,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio protected def defineCodeGen( ctx: CodeGenContext, ev: GeneratedExpressionCode, - f: Term => Code): Code = { + f: String => String): String = { val eval = child.gen(ctx) // reuse the previous isNull ev.isNull = eval.isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4efda2e04c2994c8ee0a8b79682729661380f05..124274c94203c0e6eb4781df6a7836ac876bc8ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -50,7 +50,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") } @@ -74,7 +74,7 @@ case class Sqrt(child: Expression) extends UnaryArithmetic { else math.sqrt(value) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; @@ -138,7 +138,7 @@ abstract class BinaryArithmetic extends BinaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = dataType match { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -236,7 +236,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic /** * Special case handling due to division by 0 => null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -296,7 +296,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet /** * Special case handling for x % 0 ==> null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val test = if (left.dataType.isInstanceOf[DecimalType]) { @@ -346,7 +346,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -400,7 +400,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { if (ctx.isNativeType(left.dataType)) { val eval1 = left.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala index ef34586261e7062ee8d30e44cfb8748151ce3df6..9002dda7bf4d0302d67e569624976559b779377d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.types._ /** * A function that calculates bitwise and(&) of two numbers. + * + * Code generation inherited from BinaryArithmetic. */ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "&" @@ -48,6 +50,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise or(|) of two numbers. + * + * Code generation inherited from BinaryArithmetic. */ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "|" @@ -71,6 +75,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet /** * A function that calculates bitwise xor of two numbers. + * + * Code generation inherited from BinaryArithmetic. */ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "^" @@ -112,8 +118,8 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic { ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)})~($c)") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } protected override def evalInternal(evalE: Any) = not(evalE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c8d0aaf79f5f2ecf981e0fd8a09b3847f372e283..e95682f952a7b327330f88072779394905f02691 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -40,7 +40,7 @@ class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] * @param primitive A term for a possible primitive value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: Code, var isNull: Term, var primitive: Term) +case class GeneratedExpressionCode(var code: String, var isNull: String, var primitive: String) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported @@ -65,14 +65,14 @@ class CodeGenContext { * (Since we aren't in a macro context we do not seem to have access to the built in `freshName` * function.) */ - def freshName(prefix: String): Term = { + def freshName(prefix: String): String = { s"$prefix${curId.getAndIncrement}" } /** * Return the code to access a column for given DataType */ - def getColumn(dataType: DataType, ordinal: Int): Code = { + def getColumn(dataType: DataType, ordinal: Int): String = { if (isNativeType(dataType)) { s"i.${accessorForType(dataType)}($ordinal)" } else { @@ -83,7 +83,7 @@ class CodeGenContext { /** * Return the code to update a column in Row for given DataType */ - def setColumn(dataType: DataType, ordinal: Int, value: Term): Code = { + def setColumn(dataType: DataType, ordinal: Int, value: String): String = { if (isNativeType(dataType)) { s"${mutatorForType(dataType)}($ordinal, $value)" } else { @@ -94,7 +94,7 @@ class CodeGenContext { /** * Return the name of accessor in Row for a DataType */ - def accessorForType(dt: DataType): Term = dt match { + def accessorForType(dt: DataType): String = dt match { case IntegerType => "getInt" case other => s"get${boxedType(dt)}" } @@ -102,7 +102,7 @@ class CodeGenContext { /** * Return the name of mutator in Row for a DataType */ - def mutatorForType(dt: DataType): Term = dt match { + def mutatorForType(dt: DataType): String = dt match { case IntegerType => "setInt" case other => s"set${boxedType(dt)}" } @@ -110,7 +110,7 @@ class CodeGenContext { /** * Return the Java type for a DataType */ - def javaType(dt: DataType): Term = dt match { + def javaType(dt: DataType): String = dt match { case IntegerType => "int" case LongType => "long" case ShortType => "short" @@ -131,7 +131,7 @@ class CodeGenContext { /** * Return the boxed type in Java */ - def boxedType(dt: DataType): Term = dt match { + def boxedType(dt: DataType): String = dt match { case IntegerType => "Integer" case LongType => "Long" case ShortType => "Short" @@ -146,7 +146,7 @@ class CodeGenContext { /** * Return the representation of default value for given DataType */ - def defaultValue(dt: DataType): Term = dt match { + def defaultValue(dt: DataType): String = dt match { case BooleanType => "false" case FloatType => "-1.0f" case ShortType => "(short)-1" @@ -161,7 +161,7 @@ class CodeGenContext { /** * Returns a function to generate equal expression in Java */ - def equalFunc(dataType: DataType): ((Term, Term) => Code) = dataType match { + def equalFunc(dataType: DataType): ((String, String) => String) = dataType match { case BinaryType => { case (eval1, eval2) => s"java.util.Arrays.equals($eval1, $eval2)" } case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala index 6f9589d20445e32547db8a95edf84e9b8aef1883..7f1b12cdd580035e5f2fa6f8df6bbb16fa31d59f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/package.scala @@ -27,9 +27,6 @@ import org.apache.spark.util.Utils */ package object codegen { - type Term = String - type Code = String - /** Canonicalizes an expression so those that differ only by names can reuse the same code. */ object ExpressionCanonicalizer extends rules.RuleExecutor[Expression] { val batches = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 3aa86edd7ab205e653482714c32ad589fbc70063..1a5cde26c9b13d83512dcbf38da63c92b63bd493 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -50,7 +50,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) @@ -155,7 +155,7 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { return res } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val len = branchesArr.length val got = ctx.freshName("got") @@ -248,7 +248,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW return res } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val keyEval = key.gen(ctx) val len = branchesArr.length val got = ctx.freshName("got") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index ddfadf314f838e5ce6f5bb2462f2e7fafa1adbd7..8ab6d977dd3a65c58e8c85a8d0cd0d2d7e1b6471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ /** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ @@ -37,7 +37,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -59,7 +59,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 3a9271678bc9c31c5ae7d2cf62b73f6b5e2cdfc1..297b35b4da94c077054ebb4e62ccea6a85d08348 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{Code, CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateUtils import org.apache.spark.sql.types._ @@ -88,7 +88,7 @@ case class Literal protected (value: Any, dataType: DataType) extends LeafExpres override def eval(input: Row): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a18067e4a58f1288e4bf344c69a24ba15c2cd314..7dacb6a9b47b6f92fd15b7dc733b10347e68b72e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -48,7 +48,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; @@ -93,7 +93,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -180,7 +180,7 @@ case class Atan2(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; @@ -194,7 +194,7 @@ case class Hypot(left: Expression, right: Expression) case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 9ecfb3ccc262f9a3fe77d3639aec1f90b07b8e37..c2d1a4eadae29cf61ec6a906c57a8bae29399c31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.types.DataType @@ -53,7 +53,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; @@ -81,7 +81,7 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr child.eval(input) == null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.primitive = eval.isNull @@ -100,7 +100,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E child.eval(input) != null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.primitive = s"(!(${eval.isNull}))" @@ -130,7 +130,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5edcf3bd77d201fd7117f724cf9601bbcf13331e..3cbdfdfb13847ba533111ba7bcda5ccde067259f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -84,7 +84,7 @@ case class Not(child: Expression) extends UnaryExpression with Predicate with Ex } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } } @@ -147,7 +147,7 @@ case class And(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -155,7 +155,7 @@ case class And(left: Expression, right: Expression) s""" ${eval1.code} boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; + boolean ${ev.primitive} = false; if (!${eval1.isNull} && !${eval1.primitive}) { } else { @@ -196,7 +196,7 @@ case class Or(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -249,7 +249,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { left.dataType match { case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, { (c1, c3) => s"$c1 $symbol $c3" @@ -280,7 +280,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison if (left.dataType != BinaryType) l == r else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType)) } } @@ -304,7 +304,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b39349b988389cac69ed22bfb8d40635cc01eada..2bcb960e9177ea8ed53a46b5407eb18458bb0f60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, Code, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -61,7 +61,7 @@ case class NewSet(elementType: DataType) extends LeafExpression { new OpenHashSet[Any]() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { elementType match { case IntegerType | LongType => ev.isNull = "false" @@ -103,7 +103,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val elementType = set.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => @@ -154,7 +154,7 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val elementType = left.dataType.asInstanceOf[OpenHashSetUDT].elementType elementType match { case IntegerType | LongType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 78adb509b470b1a6030cdd87d0d5516cc6fd0e12..aae122a981e4784b1d62e1d8d62725206bca34c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -139,7 +139,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE override def toString: String = s"Upper($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -153,7 +153,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE override def toString: String = s"Lower($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -190,7 +190,7 @@ trait StringComparison extends ExpectsInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -201,7 +201,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -212,7 +212,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with Predicate with StringComparison { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): Code = { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } }