From 6ddbf467b41126c894e2a725f2460ba0a1e9292b Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Tue, 27 Dec 2016 06:22:12 -0800 Subject: [PATCH] [SPARK-18999][SQL][MINOR] simplify Literal codegen ## What changes were proposed in this pull request? `Literal` can use `CodegenContex.addReferenceObj` to implement codegen, instead of `CodegenFallback`. This can also simplify the generated code a little bit, before we will generate: `((Expression) references[1]).eval(null)`, now it's just `references[1]`. ## How was this patch tested? N/A Author: Wenchen Fan <wenchen@databricks.com> Closes #16402 from cloud-fan/minor. --- .../expressions/codegen/CodeGenerator.scala | 4 +- .../sql/catalyst/expressions/literals.scala | 41 ++++++------------- .../spark/sql/catalyst/expressions/misc.scala | 2 +- .../expressions/objects/objects.scala | 6 +-- 4 files changed, 18 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 09007b7c89..d7746ca7a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -92,10 +92,10 @@ class CodegenContext { * This is for minor objects not to store the object into field but refer it from the references * field at the time of use because number of fields in class is limited so we should reduce it. */ - def addReferenceObj(obj: Any): String = { + def addReferenceMinorObj(obj: Any, className: String = null): String = { val idx = references.length references += obj - val clsName = obj.getClass.getName + val clsName = Option(className).getOrElse(obj.getClass.getName) s"(($clsName) references[$idx])" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 1985e68c94..ab45c41bc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -220,7 +220,7 @@ object DecimalLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal (value: Any, dataType: DataType) extends LeafExpression with CodegenFallback { +case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -271,45 +271,28 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression with ev.isNull = "true" ev.copy(s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};") } else { - dataType match { - case BooleanType => - ev.isNull = "false" - ev.value = value.toString - ev.copy("") + ev.isNull = "false" + ev.value = dataType match { + case BooleanType | IntegerType | DateType => value.toString case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - super[CodegenFallback].doGenCode(ctx, ev) + ctx.addReferenceMinorObj(v) } else { - ev.isNull = "false" - ev.value = s"${value}f" - ev.copy("") + s"${value}f" } case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - super[CodegenFallback].doGenCode(ctx, ev) + ctx.addReferenceMinorObj(v) } else { - ev.isNull = "false" - ev.value = s"${value}D" - ev.copy("") + s"${value}D" } - case ByteType | ShortType => - ev.isNull = "false" - ev.value = s"(${ctx.javaType(dataType)})$value" - ev.copy("") - case IntegerType | DateType => - ev.isNull = "false" - ev.value = value.toString - ev.copy("") - case TimestampType | LongType => - ev.isNull = "false" - ev.value = s"${value}L" - ev.copy("") - // eval() version may be faster for non-primitive types - case other => - super[CodegenFallback].doGenCode(ctx, ev) + case ByteType | ShortType => s"(${ctx.javaType(dataType)})$value" + case TimestampType | LongType => s"${value}L" + case other => ctx.addReferenceMinorObj(value, ctx.javaType(dataType)) } + ev.copy("") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a874a1cf37..bb9368cf6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -78,7 +78,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index fc323693a2..36bf3017d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -961,7 +961,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) val code = s""" ${childGen.code} @@ -998,7 +998,7 @@ case class GetExternalRowField( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) val row = child.genCode(ctx) val code = s""" ${row.code} @@ -1038,7 +1038,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) val input = child.genCode(ctx) val obj = input.value -- GitLab