Skip to content
Snippets Groups Projects
Commit 65854797 authored by Takuya UESHIN's avatar Takuya UESHIN Committed by Wenchen Fan
Browse files

[SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke,...

[SPARK-18467][SQL] Extracts method for preparing arguments from StaticInvoke, Invoke and NewInstance and modify to short circuit if arguments have null when `needNullCheck == true`.

## What changes were proposed in this pull request?

This pr extracts method for preparing arguments from `StaticInvoke`, `Invoke` and `NewInstance` and modify to short circuit if arguments have `null` when `propageteNull == true`.

The steps are as follows:

1. Introduce `InvokeLike` to extract common logic from `StaticInvoke`, `Invoke` and `NewInstance` to prepare arguments.
`StaticInvoke` and `Invoke` had a risk to exceed 64kb JVM limit to prepare arguments but after this patch they can handle them because they share the preparing code of NewInstance, which handles the limit well.

2. Remove unneeded null checking and fix nullability of `NewInstance`.
Avoid some of nullabilty checking which are not needed because the expression is not nullable.

3. Modify to short circuit if arguments have `null` when `needNullCheck == true`.
If `needNullCheck == true`, preparing arguments can be skipped if we found one of them is `null`, so modified to short circuit in the case.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #15901 from ueshin/issues/SPARK-18467.
parent b625a36e
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,78 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
/**
* Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
*/
trait InvokeLike extends Expression with NonSQLExpression {
def arguments: Seq[Expression]
def propagateNull: Boolean
protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)
/**
* Prepares codes for arguments.
*
* - generate codes for argument.
* - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments.
* - avoid some of nullabilty checking which are not needed because the expression is not
* nullable.
* - when needNullCheck == true, short circuit if we found one of arguments is null because
* preparing rest of arguments can be skipped in the case.
*
* @param ctx a [[CodegenContext]]
* @return (code to prepare arguments, argument string, result of argument null check)
*/
def prepareArguments(ctx: CodegenContext): (String, String, String) = {
val resultIsNull = if (needNullCheck) {
val resultIsNull = ctx.freshName("resultIsNull")
ctx.addMutableState("boolean", resultIsNull, "")
resultIsNull
} else {
"false"
}
val argValues = arguments.map { e =>
val argValue = ctx.freshName("argValue")
ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
argValue
}
val argCodes = if (needNullCheck) {
val reset = s"$resultIsNull = false;"
val argCodes = arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
val updateResultIsNull = if (e.nullable) {
s"$resultIsNull = ${expr.isNull};"
} else {
""
}
s"""
if (!$resultIsNull) {
${expr.code}
$updateResultIsNull
${argValues(i)} = ${expr.value};
}
"""
}
reset +: argCodes
} else {
arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
s"""
${expr.code}
${argValues(i)} = ${expr.value};
"""
}
}
val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
(argCode, argValues.mkString(", "), resultIsNull)
}
}
/**
* Invokes a static function, returning the result. By default, any of the arguments being null
* will result in returning null instead of calling the function.
......@@ -50,7 +122,7 @@ case class StaticInvoke(
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
propagateNull: Boolean = true) extends InvokeLike {
val objectName = staticObject.getName.stripSuffix("$")
......@@ -62,16 +134,10 @@ case class StaticInvoke(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
val callFunc = s"$objectName.$functionName($argString)"
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = false;"
}
val callFunc = s"$objectName.$functionName($argString)"
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
......@@ -82,9 +148,9 @@ case class StaticInvoke(
}
val code = s"""
${argGen.map(_.code).mkString("\n")}
$setIsNull
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
$argCode
boolean ${ev.isNull} = $resultIsNull;
final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
ev.copy(code = code)
......@@ -103,13 +169,15 @@ case class StaticInvoke(
* @param functionName The name of the method to call.
* @param dataType The expected return type of the function.
* @param arguments An optional list of expressions, whos evaluation will be passed to the function.
* @param propagateNull When true, and any of the arguments is null, null will be returned instead
* of calling the function.
*/
case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil,
propagateNull: Boolean = true) extends Expression with NonSQLExpression {
propagateNull: Boolean = true) extends InvokeLike {
override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
......@@ -131,8 +199,8 @@ case class Invoke(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx)
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty
......@@ -164,12 +232,6 @@ case class Invoke(
"""
}
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
} else {
s"boolean ${ev.isNull} = ${obj.isNull};"
}
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
......@@ -177,15 +239,19 @@ case class Invoke(
} else {
""
}
val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
$evaluate
if (!${obj.isNull}) {
$argCode
${ev.isNull} = $resultIsNull;
if (!${ev.isNull}) {
$evaluate
}
$postNullCheck
}
$postNullCheck
"""
ev.copy(code = code)
}
......@@ -223,10 +289,10 @@ case class NewInstance(
arguments: Seq[Expression],
propagateNull: Boolean,
dataType: DataType,
outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression {
outerPointer: Option[() => AnyRef]) extends InvokeLike {
private val className = cls.getName
override def nullable: Boolean = propagateNull
override def nullable: Boolean = needNullCheck
override def children: Seq[Expression] = arguments
......@@ -245,52 +311,25 @@ case class NewInstance(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val argIsNulls = ctx.freshName("argIsNulls")
ctx.addMutableState("boolean[]", argIsNulls,
s"$argIsNulls = new boolean[${arguments.size}];")
val argValues = arguments.zipWithIndex.map { case (e, i) =>
val argValue = ctx.freshName("argValue")
ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
argValue
}
val argCodes = arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
expr.code + s"""
$argIsNulls[$i] = ${expr.isNull};
${argValues(i)} = ${expr.value};
"""
}
val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
val (argCode, argString, resultIsNull) = prepareArguments(ctx)
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
var isNull = ev.isNull
val setIsNull = if (propagateNull && arguments.nonEmpty) {
s"""
boolean $isNull = false;
for (int idx = 0; idx < ${arguments.length}; idx++) {
if ($argIsNulls[idx]) { $isNull = true; break; }
}
"""
} else {
isNull = "false"
""
}
ev.isNull = resultIsNull
val constructorCall = outer.map { gen =>
s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})"""
s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
s"new $className(${argValues.mkString(", ")})"
s"new $className($argString)"
}
val code = s"""
$argCode
${outer.map(_.code).getOrElse("")}
$setIsNull
final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
ev.copy(code = code, isNull = isNull)
final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
"""
ev.copy(code = code)
}
override def toString: String = s"newInstance($cls)"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment