diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c94b2c0e270b653a55e739e20a2c6fb449a2f891..397abc7391ec65efe3022c8caed8fd3e2dd1818a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ /** @@ -68,7 +68,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def newInstance(): NamedExpression = this - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 6f199cfc5d8cd49cee147b2c0ad204f402d5928f..1072158f0458583bded998cc9e865e2cef6bbd8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -446,7 +446,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = cast(input) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) eval.code + @@ -460,7 +460,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def nullSafeCastFunction( from: DataType, to: DataType, - ctx: CodeGenContext): CastFunction = to match { + ctx: CodegenContext): CastFunction = to match { case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" @@ -491,7 +491,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. - private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String, resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { s""" boolean $resultNull = $childNull; @@ -502,7 +502,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } - private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" @@ -524,7 +524,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def castToDateCode( from: DataType, - ctx: CodeGenContext): CastFunction = from match { + ctx: CodegenContext): CastFunction = from match { case StringType => val intOpt = ctx.freshName("intOpt") (c, evPrim, evNull) => s""" @@ -556,7 +556,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def castToDecimalCode( from: DataType, target: DecimalType, - ctx: CodeGenContext): CastFunction = { + ctx: CodegenContext): CastFunction = { val tmp = ctx.freshName("tmpDecimal") from match { case StringType => @@ -614,7 +614,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { private[this] def castToTimestampCode( from: DataType, - ctx: CodeGenContext): CastFunction = from match { + ctx: CodegenContext): CastFunction = from match { case StringType => val longOpt = ctx.freshName("longOpt") (c, evPrim, evNull) => @@ -826,7 +826,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] def castArrayCode( - fromType: DataType, toType: DataType, ctx: CodeGenContext): CastFunction = { + fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = { val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") @@ -861,7 +861,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { """ } - private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = { val keysCast = castArrayCode(from.keyType, to.keyType, ctx) val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) @@ -889,7 +889,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { } private[this] def castStructCode( - from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { + from: StructType, to: StructType, ctx: CodegenContext): CastFunction = { val fieldsCasts = from.fields.zip(to.fields).map { case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index d6219514b752bb680527044abd729b24ed077eb2..25cf210c4b527642761242e9803e9bd1b2d1f827 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -86,22 +86,22 @@ abstract class Expression extends TreeNode[Expression] { def eval(input: InternalRow = null): Any /** - * Returns an [[GeneratedExpressionCode]], which contains Java source code that + * Returns an [[ExprCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. * - * @param ctx a [[CodeGenContext]] - * @return [[GeneratedExpressionCode]] + * @param ctx a [[CodegenContext]] + * @return [[ExprCode]] */ - def gen(ctx: CodeGenContext): GeneratedExpressionCode = { + def gen(ctx: CodegenContext): ExprCode = { ctx.subExprEliminationExprs.get(this).map { subExprState => // This expression is repeated meaning the code to evaluated has already been added // as a function and called in advance. Just use it. val code = s"/* ${this.toCommentSafeString} */" - GeneratedExpressionCode(code, subExprState.isNull, subExprState.value) + ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) + val ve = ExprCode("", isNull, primitive) ve.code = genCode(ctx, ve) // Add `this` in the comment. ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) @@ -113,11 +113,11 @@ abstract class Expression extends TreeNode[Expression] { * The default behavior is to call the eval method of the expression. Concrete expression * implementations should override this to do actual code generation. * - * @param ctx a [[CodeGenContext]] - * @param ev an [[GeneratedExpressionCode]] with unique terms. + * @param ctx a [[CodegenContext]] + * @param ev an [[ExprCode]] with unique terms. * @return Java source code */ - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String + protected def genCode(ctx: CodegenContext, ev: ExprCode): String /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -245,7 +245,7 @@ trait Unevaluable extends Expression { final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - final override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + final override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } @@ -330,8 +330,8 @@ abstract class UnaryExpression extends Expression { * @param f function that accepts a variable name and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: String => String): String = { nullSafeCodeGen(ctx, ev, eval => { s"${ev.value} = ${f(eval)};" @@ -346,8 +346,8 @@ abstract class UnaryExpression extends Expression { * code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: String => String): String = { val eval = child.gen(ctx) if (nullable) { @@ -420,8 +420,8 @@ abstract class BinaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.value} = ${f(eval1, eval2)};" @@ -437,8 +437,8 @@ abstract class BinaryExpression extends Expression { * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -560,8 +560,8 @@ abstract class TernaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { s"${ev.value} = ${f(eval1, eval2, eval3)};" @@ -577,8 +577,8 @@ abstract class TernaryExpression extends Expression { * and returns Java code to compute the output. */ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, + ctx: CodegenContext, + ev: ExprCode, f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) val resultCode = f(evals(0).value, evals(1).value, evals(2).value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 827dce8af100e36187758257287773e3281e2601..c49c601c3034b90bc460477f43937eed8e0da864 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -43,7 +43,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { SqlNewHadoopRDDState.getInputFileName() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 94f8801dec3692ef44e56692a94140a6b77b25c0..5d28f8fbde8be141e0313eb8ac103dd50c513a34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, LongType} /** @@ -65,7 +65,7 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with partitionMask + currentCount } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 3a6c909fffce732cd8a21ae9530be44710cab9b0..4035c9dfa4f8cd883ad68df8c987fc41d86d8433 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -974,7 +974,7 @@ case class ScalaUDF( // scalastyle:on line.size.limit // Generate codes used to convert the arguments to Scala type for user-defined funtions - private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { val converterClassName = classOf[Any => Any].getName val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" val expressionClassName = classOf[Expression].getName @@ -990,8 +990,8 @@ case class ScalaUDF( } override def genCode( - ctx: CodeGenContext, - ev: GeneratedExpressionCode): String = { + ctx: CodegenContext, + ev: ExprCode): String = { ctx.references += this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 1cb1b9da3049b39a08d1fb8062d1e8bb7832c479..bd1d91487275baa71d36b9ed79d261cfc20a65dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator @@ -69,7 +69,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childCode = child.child.gen(ctx) val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index aa3951480c5033497bbc0234ba7fe71fa2178595..377f08eb105fa43680ed29305e6a8f4ba44b2bca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -44,7 +44,7 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm override protected def evalInternal(input: InternalRow): Int = partitionId - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val idTerm = ctx.freshName("partitionId") ctx.addMutableState(ctx.JAVA_INT, idTerm, s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7bd851c059d0e8c46d91b6313605ec72d28ec932..1cacd3f76aa3670eebedf7ef5d4408cecf714de4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -34,7 +34,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -65,7 +65,7 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects override def dataType: DataType = child.dataType - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + override def genCode(ctx: CodegenContext, ev: ExprCode): String = defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input @@ -87,7 +87,7 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => @@ -109,7 +109,7 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -141,7 +141,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => @@ -170,7 +170,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => @@ -225,7 +225,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic /** * Special case handling due to division by 0 => null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -287,7 +287,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet /** * Special case handling for x % 0 ==> null. */ - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -344,7 +344,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) @@ -398,7 +398,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) @@ -449,7 +449,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { dataType match { case dt: DecimalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a1e48c4210877c70ef57780d977b73c436e76a96..a97bd9edcef84e8ac83a8fd8a2b29f25235fc7db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -118,7 +118,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6daa8ee2f42bf0968f5a43dcea4a09fc458280b0..1c7083bbdacb27874c37962040cb7f8483612b20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -42,14 +42,14 @@ import org.apache.spark.util.Utils * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class GeneratedExpressionCode(var code: String, var isNull: String, var value: String) +case class ExprCode(var code: String, var isNull: String, var value: String) /** * A context for codegen, which is used to bookkeeping the expressions those are not supported * by codegen, then they are evaluated directly. The unsupported expression is appended at the * end of `references`, the position of it is kept in the code, used to access and evaluate it. */ -class CodeGenContext { +class CodegenContext { /** * Holding all the expressions those do not support codegen, will be evaluated directly. @@ -454,7 +454,7 @@ class CodeGenContext { * expression will be combined in the `expressions` order. */ def generateExpressions(expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { if (doSubexpressionElimination) subexpressionElimination(expressions) expressions.map(e => e.gen(this)) } @@ -479,17 +479,17 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodeGenContext): String = { + protected def declareMutableStates(ctx: CodegenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" }.mkString("\n") } - protected def initMutableStates(ctx: CodeGenContext): String = { + protected def initMutableStates(ctx: CodegenContext): String = { ctx.mutableStates.map(_._3).mkString("\n") } - protected def declareAddedFunctions(ctx: CodeGenContext): String = { + protected def declareAddedFunctions(ctx: CodegenContext): String = { ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } @@ -591,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * Create a new codegen context for expression evaluator, used to store those * expressions that don't support codegen */ - def newCodeGenContext(): CodeGenContext = { - new CodeGenContext + def newCodeGenContext(): CodegenContext = { + new CodegenContext } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 3353580148799febd16c0b10474f1211b7561ab1..c98b7350b3594eabd675c3a721e2b71d012f0dfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic} */ trait CodegenFallback extends Expression { - protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { foreach { case n: Nondeterministic => n.setInitialValues() case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 1af7c73cd4bf57d4ae7054371d1f132a8461ed13..88bcf5b4ed369e0f2e624ac43210115d68aea0be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -55,7 +55,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR * Generates the code for comparing a struct type according to its natural ordering * (i.e. ascending order by field 1, then field 2, ..., then field n. */ - def genComparisons(ctx: CodeGenContext, schema: StructType): String = { + def genComparisons(ctx: CodegenContext, schema: StructType): String = { val ordering = schema.fields.map(_.dataType).zipWithIndex.map { case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending) } @@ -65,7 +65,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR /** * Generates the code for ordering based on the given order. */ - def genComparisons(ctx: CodeGenContext, ordering: Seq[SortOrder]): String = { + def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val comparisons = ordering.map { order => val eval = order.child.gen(ctx) val asc = order.direction == Ascending diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 364dbb770f5e5d1d8e271b9dfaf9cfb9da951dd5..865170764640e68fd7cf86bad3f9cabb7b97bcb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -40,9 +40,9 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] in.map(BindReferences.bindReference(_, inputSchema)) private def createCodeForStruct( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - schema: StructType): GeneratedExpressionCode = { + schema: StructType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") @@ -68,13 +68,13 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final InternalRow $output = new $rowClass($values); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def createCodeForArray( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - elementType: DataType): GeneratedExpressionCode = { + elementType: DataType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeArray") val values = ctx.freshName("values") @@ -96,14 +96,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final ArrayData $output = new $arrayClass($values); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def createCodeForMap( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, keyType: DataType, - valueType: DataType): GeneratedExpressionCode = { + valueType: DataType): ExprCode = { val tmp = ctx.freshName("tmp") val output = ctx.freshName("safeMap") val mapClass = classOf[ArrayBasedMapData].getName @@ -117,20 +117,20 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); """ - GeneratedExpressionCode(code, "false", output) + ExprCode(code, "false", output) } private def convertToSafe( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, - dataType: DataType): GeneratedExpressionCode = dataType match { + dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => GeneratedExpressionCode("", "false", s"$input.clone()") + case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) - case _ => GeneratedExpressionCode("", "false", input) + case _ => ExprCode("", "false", input) } protected def create(expressions: Seq[Expression]): Projection = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d0e031f27990c70c6b0cc48c64b1864e2dee607e..3a929927c3f604fe4013be5e13c92a97d156e4f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -48,7 +48,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of field is correct, we can use it to save null check. private def writeStructToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, fieldTypes: Seq[DataType], bufferHolder: String): String = { @@ -56,7 +56,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldName = ctx.freshName("fieldName") val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};" val isNull = s"$input.isNullAt($i)" - GeneratedExpressionCode(code, isNull, fieldName) + ExprCode(code, isNull, fieldName) } s""" @@ -69,9 +69,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } private def writeExpressionsToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, row: String, - inputs: Seq[GeneratedExpressionCode], + inputs: Seq[ExprCode], inputTypes: Seq[DataType], bufferHolder: String): String = { val rowWriter = ctx.freshName("rowWriter") @@ -160,7 +160,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of array element is correct, we can use it to save null check. private def writeArrayToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, elementType: DataType, bufferHolder: String): String = { @@ -232,7 +232,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // TODO: if the nullability of value element is correct, we can use it to save null check. private def writeMapToBuffer( - ctx: CodeGenContext, + ctx: CodegenContext, input: String, keyType: DataType, valueType: DataType, @@ -270,7 +270,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * If the input is already in unsafe format, we don't need to go through all elements/fields, * we can directly write it. */ - private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = { + private def writeUnsafeData(ctx: CodegenContext, input: String, bufferHolder: String) = { val sizeInBytes = ctx.freshName("sizeInBytes") s""" final int $sizeInBytes = $input.getSizeInBytes(); @@ -282,9 +282,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } def createCode( - ctx: CodeGenContext, + ctx: CodegenContext, expressions: Seq[Expression], - useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + useSubexprElimination: Boolean = false): ExprCode = { val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) @@ -305,7 +305,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) } protected def canonicalize(in: Seq[Expression]): Seq[Expression] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7aac2e5e6c1b844ec2f123bf03f64245faad2815..e36c9852491bbbf62166edaa9d4e8501e5d8fe39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -35,7 +35,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType case _: MapType => value.asInstanceOf[MapData].numElements() } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") } } @@ -170,7 +170,7 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") val getValue = ctx.getValue(arr, right.dataType, i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d71bbd63c8e89ad6bc658942c4d564212f37dba9..0df8101d9417b892def56e33e3cf9c38a2c84d44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -46,7 +46,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { new GenericArrayData(children.map(_.eval(input)).toArray) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName val values = ctx.freshName("values") s""" @@ -94,7 +94,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") s""" @@ -171,7 +171,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") s""" @@ -223,7 +223,7 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, children) ev.isNull = eval.isNull ev.value = eval.value @@ -263,7 +263,7 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) ev.isNull = eval.isNull ev.value = eval.value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5bd97cc7467ab345c3ca9f7a3b2e1032a1b0f15d..5256baaf432a2ccaa695e12aed6fb038b32f739a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -113,7 +113,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { if (nullable) { s""" @@ -170,7 +170,7 @@ case class GetArrayStructFields( new GenericArrayData(result) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { s""" @@ -225,7 +225,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int index = (int) $eval2; @@ -285,7 +285,7 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 83abbcdc61175eb11c3685fd60d3b462a3e24ed4..2a24235a29c9cf3121d9354f8b1474492ad377f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -52,7 +52,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val condEval = predicate.gen(ctx) val trueEval = trueValue.gen(ctx) val falseEval = falseValue.gen(ctx) @@ -136,7 +136,7 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Generate code that looks like: // // condA = ... @@ -275,11 +275,11 @@ case class Least(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = { + def updateEval(eval: ExprCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || @@ -334,11 +334,11 @@ case class Greatest(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evalChildren = children.map(_.gen(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) - def updateEval(eval: GeneratedExpressionCode): String = { + def updateEval(eval: ExprCode): String = { s""" ${eval.code} if (!${eval.isNull} && (${ev.isNull} || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 17f1df06f2fad58bc26ef053c6e0957a350507a8..1d0ea68d7a7bf967e0b7839eceb9260b915fa8f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -23,8 +23,8 @@ import java.util.{Calendar, TimeZone} import scala.util.Try import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, - GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, + ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -84,7 +84,7 @@ case class DateAdd(startDate: Expression, days: Expression) start.asInstanceOf[Int] + d.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd + $d;""" }) @@ -109,7 +109,7 @@ case class DateSub(startDate: Expression, days: Expression) start.asInstanceOf[Int] - d.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd - $d;""" }) @@ -128,7 +128,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } @@ -144,7 +144,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } @@ -160,7 +160,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } @@ -176,7 +176,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } @@ -193,7 +193,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } @@ -209,7 +209,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI DateTimeUtils.getQuarter(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } @@ -225,7 +225,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp DateTimeUtils.getMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } @@ -241,7 +241,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } @@ -265,7 +265,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa c.get(Calendar.WEEK_OF_YEAR) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") @@ -295,7 +295,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) @@ -386,7 +386,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { left.dataType match { case StringType if right.foldable => val sdf = classOf[SimpleDateFormat].getName @@ -503,7 +503,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val sdf = classOf[SimpleDateFormat].getName if (format.foldable) { if (constFormat == null) { @@ -555,7 +555,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") } @@ -591,7 +591,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") val dayOfWeekTerm = ctx.freshName("dayOfWeek") @@ -643,7 +643,7 @@ case class TimeAdd(start: Expression, interval: Expression) start.asInstanceOf[Long], itvl.months, itvl.microseconds) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" @@ -666,7 +666,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() @@ -718,7 +718,7 @@ case class TimeSub(start: Expression, interval: Expression) start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" @@ -743,7 +743,7 @@ case class AddMonths(startDate: Expression, numMonths: Expression) DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, m) => { s"""$dtu.dateAddMonths($sd, $m)""" @@ -770,7 +770,7 @@ case class MonthsBetween(date1: Expression, date2: Expression) DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { s"""$dtu.monthsBetween($l, $r)""" @@ -795,7 +795,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() @@ -840,7 +840,7 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def eval(input: InternalRow): Any = child.eval(input) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, d => d) } @@ -882,7 +882,7 @@ case class TruncDate(date: Expression, format: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { @@ -933,7 +933,7 @@ case class DateDiff(endDate: Expression, startDate: Expression) end.asInstanceOf[Int] - start.asInstanceOf[Int] } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 5f8b544edb51186a88b54a3f2d42fc30c1f690c0..74e86f40c0364a79fbaf561a9be3bc9495fe40af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types._ /** @@ -34,7 +34,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Decimal].toUnscaledLong - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -53,7 +53,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un protected override def nullSafeEval(input: Any): Any = Decimal(input.asInstanceOf[Long], precision, scale) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); @@ -70,8 +70,8 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" override def prettyName: String = "promote_precision" override def sql: String = child.sql } @@ -93,7 +93,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { val tmp = ctx.freshName("tmp") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e0b020330278bca6bae730807f3fa6ea96750056..db30845fdab6c7926761db991c3af7b13e7a2c23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -171,7 +171,7 @@ case class Literal protected (value: Any, dataType: DataType) override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 66d8631a846ab946b0cea73cbc5c8a9830b6c475..8b9a60f97ce6e43d166d532290e12235e661fdd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -67,7 +67,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } @@ -87,7 +87,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) if (d <= yAsymptote) null else f(d) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -119,7 +119,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -172,7 +172,7 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -207,7 +207,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre toBase.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val numconv = NumberConverter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (num, from, to) => s""" @@ -240,7 +240,7 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -299,7 +299,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" if ($eval > 20 || $eval < 0) { @@ -317,7 +317,7 @@ case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -369,7 +369,7 @@ case class Bin(child: Expression) protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c) => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } @@ -464,7 +464,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { @@ -489,7 +489,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp protected override def nullSafeEval(num: Any): Any = Hex.unhex(num.asInstanceOf[UTF8String].getBytes) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s""" @@ -516,14 +516,14 @@ case class Atan2(left: Expression, right: Expression) math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -549,7 +549,7 @@ case class ShiftLeft(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } } @@ -575,7 +575,7 @@ case class ShiftRight(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } } @@ -601,7 +601,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } } @@ -635,7 +635,7 @@ case class Logarithm(left: Expression, right: Expression) if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (left.isInstanceOf[EulerNumber]) { nullSafeCodeGen(ctx, ev, (c1, c2) => s""" @@ -758,7 +758,7 @@ case class Round(child: Expression, scale: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val ce = child.gen(ctx) val evaluationCode = child.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4751fbe4146fe4bdfa7768ab66f192abf788f210..2c12de08f4115c1da6b38bb86beb034586316308 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -47,7 +47,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } @@ -100,7 +100,7 @@ case class Sha2(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val digestUtils = "org.apache.commons.codec.digest.DigestUtils" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -145,7 +145,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.shaHex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" ) @@ -171,7 +171,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp checksum.getValue } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val CRC32 = "java.util.zip.CRC32" nullSafeCodeGen(ctx, ev, value => { s""" @@ -323,7 +323,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" val childrenHash = children.zipWithIndex.map { case (child, dt) => @@ -347,12 +347,12 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression input: String, dataType: DataType, seed: String, - ctx: CodeGenContext): GeneratedExpressionCode = { + ctx: CodegenContext): ExprCode = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): GeneratedExpressionCode = - GeneratedExpressionCode(code = "", isNull = "false", value = v) + def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") + def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") + def inlineValue(v: String): ExprCode = + ExprCode(code = "", isNull = "false", value = v) dataType match { case NullType => inlineValue(seed) @@ -369,7 +369,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" val offset = "Platform.BYTE_ARRAY_OFFSET" val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) } case CalendarIntervalType => val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" @@ -400,7 +400,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) case MapType(kt, vt, _) => val result = ctx.freshName("result") @@ -427,7 +427,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression } } """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) case StructType(fields) => val result = ctx.freshName("result") @@ -448,7 +448,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression int $result = $seed; $fieldsHash """ - GeneratedExpressionCode(code, "false", result) + ExprCode(code, "false", result) case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index b6d7a7f5e8d01389d5b7d96bbfd5e10b7fe83956..7983501ada9bd3ce3fe0363633cc6cfc299a3f20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -133,8 +133,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple passthrough for code generation. */ - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 89aec2b20fd0c1d5dae125809c030d4491aa0178..667d3513d32b9cb4e3c9606c6310062bdd3917a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -61,7 +61,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val first = children(0) val rest = children.drop(1) val firstEval = first.gen(ctx) @@ -110,7 +110,7 @@ case class IsNaN(child: Expression) extends UnaryExpression } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) child.dataType match { case DoubleType | FloatType => @@ -150,7 +150,7 @@ case class NaNvl(left: Expression, right: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val leftGen = left.gen(ctx) val rightGen = right.gen(ctx) left.dataType match { @@ -189,7 +189,7 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) == null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.value = eval.isNull @@ -210,7 +210,7 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) != null } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval = child.gen(ctx) ev.isNull = "false" ev.value = s"(!(${eval.isNull}))" @@ -250,7 +250,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 8385f7e1da591789e2b17a2dbba709225dfbcc17..79fe0033b71ab416abe76a25605756198acd34e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -24,7 +24,7 @@ import org.apache.spark.SparkConf import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -56,7 +56,7 @@ case class StaticInvoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -145,7 +145,7 @@ case class Invoke( case _ => identity[String] _ } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val obj = targetObject.gen(ctx) val argGen = arguments.map(_.gen(ctx)) @@ -214,7 +214,7 @@ case class NewInstance( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.gen(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -277,7 +277,7 @@ case class UnwrapOption( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val inputObject = child.gen(ctx) @@ -309,7 +309,7 @@ case class WrapOption(child: Expression, optType: DataType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val inputObject = child.gen(ctx) s""" @@ -332,8 +332,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext override def nullable: Boolean = true - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - GeneratedExpressionCode(code = "", value = value, isNull = isNull) + override def gen(ctx: CodegenContext): ExprCode = { + ExprCode(code = "", value = value, isNull = isNull) } } @@ -415,7 +415,7 @@ case class MapObjects( override def dataType: DataType = ArrayType(lambdaFunction.dataType) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) val genInputData = inputData.gen(ctx) @@ -491,7 +491,7 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericRow].getName val values = ctx.freshName("values") s""" @@ -521,7 +521,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -560,7 +560,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) extends Unary case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression { - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -605,7 +605,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val instanceGen = beanInstance.gen(ctx) val initialize = setters.map { @@ -648,7 +648,7 @@ case class AssertNotNull( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val childGen = child.gen(ctx) ev.isNull = "false" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index bca12a8d21023f035c7010d4eeff171650b1f6c6..a3c10c81c35e5ed4ce416c34a7aa5bd94ffc3aad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -98,7 +98,7 @@ case class Not(child: Expression) protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } @@ -154,7 +154,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val valueGen = value.gen(ctx) val listGen = list.map(_.gen(ctx)) val listCode = listGen.map(x => @@ -213,7 +213,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with def getHSet(): Set[Any] = hset - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.gen(ctx) @@ -267,7 +267,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -318,7 +318,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) @@ -347,7 +347,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType @@ -394,7 +394,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } @@ -428,7 +428,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8de47e9ddc28de3efc3d9dd868d0df3047a95c5b..2e703671fcd66693d43865f760104b17e8ba7c07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, DoubleType} import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -65,7 +65,7 @@ case class Rand(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, @@ -88,7 +88,7 @@ case class Randn(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index db266639b8560d1ad4b0e59b0caebe886776b188..b68009331b0ad2394dc1da74cf72cefe69c0855c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -76,7 +76,7 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" val pattern = ctx.freshName("pattern") @@ -125,7 +125,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val patternClass = classOf[Pattern].getName val pattern = ctx.freshName("pattern") @@ -182,7 +182,7 @@ case class StringSplit(str: Expression, pattern: Expression) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. @@ -238,7 +238,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def children: Seq[Expression] = subject :: regexp :: rep :: Nil override def prettyName: String = "regexp_replace" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") @@ -318,7 +318,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def children: Seq[Expression] = subject :: regexp :: idx :: Nil override def prettyName: String = "regexp_extract" - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 931f752b4dc1af5188cf82469ee53cfe0eb4a3f5..b965212f2777769f88703ea577c061216c1a5590 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -49,7 +49,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas UTF8String.concat(inputs : _*) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.value}" @@ -102,7 +102,7 @@ case class ConcatWs(children: Seq[Expression]) UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.gen(ctx)) @@ -183,7 +183,7 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -198,7 +198,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def convert(v: UTF8String): UTF8String = v.toLowerCase - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -223,7 +223,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -234,7 +234,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -245,7 +245,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } @@ -291,7 +291,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac srcEval.asInstanceOf[UTF8String].translate(dict) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val termLastMatching = ctx.freshName("lastMatching") val termLastReplace = ctx.freshName("lastReplace") val termDict = ctx.freshName("dict") @@ -338,7 +338,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override protected def nullSafeEval(word: Any, set: Any): Any = set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);" ) @@ -359,7 +359,7 @@ case class StringTrim(child: Expression) override def prettyName: String = "trim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trim()") } } @@ -374,7 +374,7 @@ case class StringTrimLeft(child: Expression) override def prettyName: String = "ltrim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trimLeft()") } } @@ -389,7 +389,7 @@ case class StringTrimRight(child: Expression) override def prettyName: String = "rtrim" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).trimRight()") } } @@ -415,7 +415,7 @@ case class StringInstr(str: Expression, substr: Expression) override def prettyName: String = "instr" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } @@ -441,7 +441,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: count.asInstanceOf[Int]) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") } } @@ -484,7 +484,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { val substrGen = substr.gen(ctx) val strGen = str.gen(ctx) val startGen = start.gen(ctx) @@ -526,7 +526,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } @@ -547,7 +547,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } @@ -583,7 +583,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val pattern = children.head.gen(ctx) val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) @@ -634,7 +634,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def nullSafeEval(string: Any): Any = { string.asInstanceOf[UTF8String].toTitleCase } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") } } @@ -656,7 +656,7 @@ case class StringRepeat(str: Expression, times: Expression) override def prettyName: String = "repeat" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } } @@ -669,7 +669,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 override def prettyName: String = "reverse" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"($c).reverse()") } } @@ -688,7 +688,7 @@ case class StringSpace(child: Expression) UTF8String.blankString(if (length < 0) 0 else length) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (length) => s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } @@ -723,7 +723,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { @@ -746,7 +746,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => value.asInstanceOf[Array[Byte]].length } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { child.dataType match { case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") @@ -766,7 +766,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } @@ -783,7 +783,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } } @@ -805,7 +805,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { val bytes = ctx.freshName("bytes") s""" @@ -833,7 +833,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn bytes.asInstanceOf[Array[Byte]])) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); @@ -852,7 +852,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (child) => { s""" ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); @@ -878,7 +878,7 @@ case class Decode(bin: Expression, charset: Expression) UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (bytes, charset) => s""" try { @@ -908,7 +908,7 @@ case class Encode(value: Expression, charset: Expression) input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (string, charset) => s""" try { @@ -985,7 +985,7 @@ case class FormatNumber(x: Expression, d: Expression) } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { nullSafeCodeGen(ctx, ev, (num, d) => { def typeHelper(p: String): String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index 118fd695fe2f56ca3945d90b907e040011d4b87c..ff34b1e37be935b92012239c57258bd51b9954ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -35,7 +35,7 @@ case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpres override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { Literal.create(value, dataType).genCode(ctx, ev) } }