diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 09007b7c89fe3210d987b9e3c27b20288e80e03c..d7746ca7a052e2f3ed72664549b2db07de610727 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -92,10 +92,10 @@ class CodegenContext { * This is for minor objects not to store the object into field but refer it from the references * field at the time of use because number of fields in class is limited so we should reduce it. */ - def addReferenceObj(obj: Any): String = { + def addReferenceMinorObj(obj: Any, className: String = null): String = { val idx = references.length references += obj - val clsName = obj.getClass.getName + val clsName = Option(className).getOrElse(obj.getClass.getName) s"(($clsName) references[$idx])" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 1985e68c94e2da4dccf46a0cb771be5be743adeb..ab45c41bc0bf94137bea9de73a0bd6c7a8cec7d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -220,7 +220,7 @@ object DecimalLiteral { /** * In order to do type checking, use Literal.create() instead of constructor */ -case class Literal (value: Any, dataType: DataType) extends LeafExpression with CodegenFallback { +case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def foldable: Boolean = true override def nullable: Boolean = value == null @@ -271,45 +271,28 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression with ev.isNull = "true" ev.copy(s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};") } else { - dataType match { - case BooleanType => - ev.isNull = "false" - ev.value = value.toString - ev.copy("") + ev.isNull = "false" + ev.value = dataType match { + case BooleanType | IntegerType | DateType => value.toString case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - super[CodegenFallback].doGenCode(ctx, ev) + ctx.addReferenceMinorObj(v) } else { - ev.isNull = "false" - ev.value = s"${value}f" - ev.copy("") + s"${value}f" } case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - super[CodegenFallback].doGenCode(ctx, ev) + ctx.addReferenceMinorObj(v) } else { - ev.isNull = "false" - ev.value = s"${value}D" - ev.copy("") + s"${value}D" } - case ByteType | ShortType => - ev.isNull = "false" - ev.value = s"(${ctx.javaType(dataType)})$value" - ev.copy("") - case IntegerType | DateType => - ev.isNull = "false" - ev.value = value.toString - ev.copy("") - case TimestampType | LongType => - ev.isNull = "false" - ev.value = s"${value}L" - ev.copy("") - // eval() version may be faster for non-primitive types - case other => - super[CodegenFallback].doGenCode(ctx, ev) + case ByteType | ShortType => s"(${ctx.javaType(dataType)})$value" + case TimestampType | LongType => s"${value}L" + case other => ctx.addReferenceMinorObj(value, ctx.javaType(dataType)) } + ev.copy("") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a874a1cf37086e99566135fb9c3f79e503188cb5..bb9368cf6d77417386ed8ae2b6e01df0a1d13f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -78,7 +78,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index fc323693a24ad894741dfd0855af8d7230604bab..36bf3017d4cdba458fc4db7bc05980b747db5272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -961,7 +961,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) val code = s""" ${childGen.code} @@ -998,7 +998,7 @@ case class GetExternalRowField( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the field is null. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) val row = child.genCode(ctx) val code = s""" ${row.code} @@ -1038,7 +1038,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the type doesn't match. - val errMsgField = ctx.addReferenceObj(errMsg) + val errMsgField = ctx.addReferenceMinorObj(errMsg) val input = child.genCode(ctx) val obj = input.value