diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index afc190e6978d4c6f57bf1ffa7b92d7363c6d19d6..bacedec1ae2032871383f736a72d9d500cfad2d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -64,19 +64,75 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val trueEval = trueValue.genCode(ctx) val falseEval = falseValue.genCode(ctx) - ev.copy(code = s""" - ${condEval.code} - boolean ${ev.isNull} = false; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${condEval.isNull} && ${condEval.value}) { - ${trueEval.code} - ${ev.isNull} = ${trueEval.isNull}; - ${ev.value} = ${trueEval.value}; - } else { - ${falseEval.code} - ${ev.isNull} = ${falseEval.isNull}; - ${ev.value} = ${falseEval.value}; - }""") + // place generated code of condition, true value and false value in separate methods if + // their code combined is large + val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length + val generatedCode = if (combinedLength > 1024 && + // Split these expressions only if they are created from a row object + (ctx.INPUT_ROW != null && ctx.currentVars == null)) { + + val (condFuncName, condGlobalIsNull, condGlobalValue) = + createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr") + val (trueFuncName, trueGlobalIsNull, trueGlobalValue) = + createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr") + val (falseFuncName, falseGlobalIsNull, falseGlobalValue) = + createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr") + s""" + $condFuncName(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!$condGlobalIsNull && $condGlobalValue) { + $trueFuncName(${ctx.INPUT_ROW}); + ${ev.isNull} = $trueGlobalIsNull; + ${ev.value} = $trueGlobalValue; + } else { + $falseFuncName(${ctx.INPUT_ROW}); + ${ev.isNull} = $falseGlobalIsNull; + ${ev.value} = $falseGlobalValue; + } + """ + } + else { + s""" + ${condEval.code} + boolean ${ev.isNull} = false; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${condEval.isNull} && ${condEval.value}) { + ${trueEval.code} + ${ev.isNull} = ${trueEval.isNull}; + ${ev.value} = ${trueEval.value}; + } else { + ${falseEval.code} + ${ev.isNull} = ${falseEval.isNull}; + ${ev.value} = ${falseEval.value}; + } + """ + } + + ev.copy(code = generatedCode) + } + + private def createAndAddFunction( + ctx: CodegenContext, + ev: ExprCode, + dataType: DataType, + baseFuncName: String): (String, String, String) = { + val globalIsNull = ctx.freshName("isNull") + ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") + val globalValue = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(dataType), globalValue, + s"$globalValue = ${ctx.defaultValue(dataType)};") + val funcName = ctx.freshName(baseFuncName) + val funcBody = + s""" + |private void $funcName(InternalRow ${ctx.INPUT_ROW}) { + | ${ev.code.trim} + | $globalIsNull = ${ev.isNull}; + | $globalValue = ${ev.value}; + |} + """.stripMargin + ctx.addNewFunction(funcName, funcBody) + (funcName, globalIsNull, globalValue) } override def toString: String = s"if ($predicate) $trueValue else $falseValue" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 0cb201e4dae3e649d55692832f8a8ef990ba8e68..0f4b4b5bc8dd6aa14258c080cdf61dccacdc38e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -97,6 +97,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { assert(actual(0) == cases) } + test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") { + val inStr = "StringForTesting" + val row = create_row(inStr) + val inputStrAttr = 'a.string.at(0) + + var strExpr: Expression = inputStrAttr + for (_ <- 1 to 13) { + strExpr = If(EqualTo(Decode(Encode(strExpr, "utf-8"), "utf-8"), inputStrAttr), + strExpr, strExpr) + } + + val expressions = Seq(strExpr) + val plan = GenerateUnsafeProjection.generate(expressions, true) + val actual = plan(row).toSeq(expressions.map(_.dataType)) + val expected = Seq(UTF8String.fromString(inStr)) + + if (!checkResult(actual, expected)) { + fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected") + } + } + test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") { val length = 5000 val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))