diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f7162e420d19a727a67edc5ddfc21ca319486112..affd1bdb327c30576e0cd5691e20022d8f233a88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback + /** * This class is used to compute equality of (sub)expression trees. Expressions can be added * to this class and they subsequently query for expression equality. Expression trees are @@ -67,7 +69,8 @@ class EquivalentExpressions { */ def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf - if (!skip && !addExpr(root)) { + // the children of CodegenFallback will not be used to generate code (call eval() instead) + if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) { root.children.foreach(addExprTree(_, ignoreLeaf)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 25cf210c4b527642761242e9803e9bd1b2d1f827..db17ba7c84ffc8facf5bcfa5019f3601dcc4de36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -100,8 +100,8 @@ abstract class Expression extends TreeNode[Expression] { ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") - val primitive = 
ctx.freshName("primitive") - val ve = ExprCode("", isNull, primitive) + val value = ctx.freshName("value") + val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) // Add `this` in the comment. ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 683029ff144d886fb43976e6d8c50bbc77b0aa0f..2747c315ad374e1c2e393b000262e1b8c06b1b2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -125,7 +125,7 @@ class CodegenContext { val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // The collection of sub-exression result resetting methods that need to be called on each row. - val subExprResetVariables = mutable.ArrayBuffer.empty[String] + val subexprFunctions = mutable.ArrayBuffer.empty[String] def declareAddedFunctions(): String = { addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") @@ -424,9 +424,9 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { val expr = e.head - val isNull = freshName("isNull") - val value = freshName("value") val fnName = freshName("evalExpr") + val isNull = s"${fnName}IsNull" + val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. 
val code = expr.gen(this) @@ -461,7 +461,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subExprResetVariables += s"$fnName($INPUT_ROW);" + subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 59ef0f5836a3cb3ea5a854d44bb76477765282c7..d9fe76133c6effa95b8dde3524560484af439c09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -38,12 +38,29 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean): (() => MutableProjection) = { + create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) + } + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + create(expressions, false) + } + + private def create( + expressions: Seq[Expression], + useSubexprElimination: Boolean): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCodes = expressions.zipWithIndex.map { - case (NoOp, _) => "" - case (e, i) => - val evaluationCode = e.gen(ctx) + val (validExpr, index) = expressions.zipWithIndex.filter { + case (NoOp, _) => false + case _ => true + }.unzip + val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + val 
projectionCodes = exprVals.zip(index).map { case (ev, i) => val e = expressions(i) if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" @@ -51,22 +68,25 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ctx.addMutableState(ctx.javaType(e.dataType), value, s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; + ${ev.code} + this.$isNull = ${ev.isNull}; + this.$value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - ${evaluationCode.code} - this.$value = ${evaluationCode.value}; + ${ev.code} + this.$value = ${ev.value}; """ } } - val updates = expressions.zipWithIndex.map { - case (NoOp, _) => "" + + // Evaluate all the subexpressions. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") + + val updates = validExpr.zip(index).map { case (e, i) => if (e.nullable) { if (e.dataType.isInstanceOf[DecimalType]) { @@ -128,6 +148,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $evalSubexpr $allProjections // copy all the results into MutableRow $allUpdates diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 61e7469ee4be276a81fa1aaa307c15bb442805c8..72bf39a0398b17792a28b9712a4a03b5f31e0a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -294,13 +294,13 @@ object GenerateUnsafeProjection 
extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - // Reset the subexpression values for each row. - val subexprReset = ctx.subExprResetVariables.mkString("\n") + // Evaluate all the subexpressions. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") val code = s""" $bufferHolder.reset(); - $subexprReset + $evalSubexpr ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index a61297b2c0395e39df76f5ab707b12b784bb8716..43a3eb9dec97c102c5376bea6a70644ce6eb280a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -154,4 +154,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite { equivalence.addExpr(sum) assert(equivalence.getAllEquivalentExprs.isEmpty) } + + test("Children of CodegenFallback") { + val one = Literal(1) + val two = Add(one, one) + val explode = Explode(two) + val add = Add(two, explode) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + // the `two` inside `explode` should not be added + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 0) + assert(equivalence.getAllEquivalentExprs.filter(_.size == 1).size == 3) // add, two, explode + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 
75101ea0fc6d20d9a9865d60f240e2ade16033cd..b19b772409d836100cb83dfffeada6881589eff5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[this] def isTesting: Boolean = sys.props.contains("spark.testing") protected def newMutableProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean = false): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") try { - GenerateMutableProjection.generate(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } catch { case e: Exception => if (isTesting) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 168b5ab0316d1f80e563d7f6396db2d6bc87d75d..26a7340f1ae10c9fb43c83828d503fcdf444be0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -194,7 +194,11 @@ case class Window( val functions = functionSeq.toArray // Construct an aggregate processor if we need one. - def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => newMutableProjection(expressions, schema)) // Create the factory val factory = key match { @@ -206,7 +210,7 @@ case class Window( ordinal, functions, child.output, - newMutableProjection, + (expressions, schema) => newMutableProjection(expressions, schema), offset) // Growing Frame. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 1d56592c40b960f882ad3bd93000602b46ec30b9..06a3991459f08b292e38fd6c4bbca08322aa0015 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -87,7 +87,8 @@ case class SortBasedAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a9cf04388d2e8c612a5fe1f62501290cdb12187b..8dcbab4c8cfbc6ca5c9204eb64ca2a246951d2d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -94,7 +94,8 @@ case class TungstenAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, testFallbackStartsAt, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d7f182352b4c9c1d72a9567e67d14386799b6f31..b159346bed9f72e9cd9cd3edeaf28fc4d57e3c03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,6 +26,7 @@ import 
org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.{aggregate, SparkQl} import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ @@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + // Would be nice if semantic equals for `+` understood commutative verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)