diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index a8018b9213f2bad1c6aff14492ce42c519b6aeb2..861a2c21ad9a09f646c2411e17b2fbfff3cf08cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -99,7 +99,10 @@ case class GeneratedAggregate( // but really, common sub expression elimination would be better.... val zero = Cast(Literal(0), calcType) val updateFunction = Coalesce( - Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil) + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType) + ) :: currentSum :: zero :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -109,6 +112,45 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case cs @ CombineSum(expr) => + val calcType = expr.dataType + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + val currentSum = AttributeReference("currentSum", calcType, nullable = true)() + val initialValue = Literal.create(null, calcType) + + // Coalasce avoids double calculation... + // but really, common sub expression elimination would be better.... + val zero = Cast(Literal(0), calcType) + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val actualExpr = expr match { + case UnscaledValue(e) => e + case _ => expr + } + // partial sum result can be null only when no input rows present + val updateFunction = If( + IsNotNull(actualExpr), + Coalesce( + Add( + Coalesce(currentSum :: zero :: Nil), + Cast(expr, calcType)) :: currentSum :: zero :: Nil), + currentSum) + + val result = + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(currentSum, cs.dataType) + case _ => currentSum + } + + AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) + case a @ Average(expr) => val calcType = expr.dataType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f754fa770d1b57ee202d8579c5229b401d98a75d..23f7e5609414b9edfd2eaeb075999574b606e445 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 87e7cf8c8af9fbd7fdcf893145792d0b404adc6e..1ad92a394118738f6bbcdd1cf859a2ef408ac290 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.test.TestSQLContext import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -102,14 +103,99 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT ABS(2.5)"), Row(2.5)) } - + test("aggregation with codegen") { val originalValue = conf.codegenEnabled setConf(SQLConf.CODEGEN_ENABLED, "true") - sql("SELECT key FROM testData GROUP BY key").collect() + // Prepare a table that we can group some rows. + table("testData") + .unionAll(table("testData")) + .unionAll(table("testData")) + .registerTempTable("testData3x") + + def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { + val df = sql(sqlText) + // First, check if we have GeneratedAggregate. + var hasGeneratedAgg = false + df.queryExecution.executedPlan.foreach { + case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case _ => + } + if (!hasGeneratedAgg) { + fail( + s""" + |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan. + |${df.queryExecution.simpleString} + """.stripMargin) + } + // Then, check results. + checkAnswer(df, expectedResults) + } + + // Just to group rows. + testCodeGen( + "SELECT key FROM testData3x GROUP BY key", + (1 to 100).map(Row(_))) + // COUNT + testCodeGen( + "SELECT key, count(value) FROM testData3x GROUP BY key", + (1 to 100).map(i => Row(i, 3))) + testCodeGen( + "SELECT count(key) FROM testData3x", + Row(300) :: Nil) + // COUNT DISTINCT ON int + testCodeGen( + "SELECT value, count(distinct key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 1))) + testCodeGen( + "SELECT count(distinct key) FROM testData3x", + Row(100) :: Nil) + // SUM + testCodeGen( + "SELECT value, sum(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, 3 * i))) + testCodeGen( + "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x", + Row(5050 * 3, 5050 * 3.0) :: Nil) + // AVERAGE + testCodeGen( + "SELECT value, avg(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT avg(key) FROM testData3x", + Row(50.5) :: Nil) + // MAX + testCodeGen( + "SELECT value, max(key) FROM testData3x GROUP BY value", + (1 to 100).map(i => Row(i.toString, i))) + testCodeGen( + "SELECT max(key) FROM testData3x", + Row(100) :: Nil) + // Some combinations. + testCodeGen( + """ + |SELECT + | value, + | sum(key), + | max(key), + | avg(key), + | count(key), + | count(distinct key) + |FROM testData3x + |GROUP BY value + """.stripMargin, + (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1))) + testCodeGen( + "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x", + Row(100, 50.5, 300, 100) :: Nil) + // Aggregate with Code generation handling all null values + testCodeGen( + "SELECT sum('a'), avg('a'), count(null) FROM testData", + Row(0, null, 0) :: Nil) + + dropTempTable("testData3x") setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } - test("Add Parser of SQL COALESCE()") { checkAnswer( sql("""SELECT COALESCE(1, 2)"""),