From b362239df566bc949283f2ac195ee89af105605a Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Wed, 20 Jan 2016 15:24:01 -0800
Subject: [PATCH] [SPARK-12797] [SQL] Generated TungstenAggregate (without
 grouping keys)

As discussed in #10786, the generated TungstenAggregate does not support imperative functions.

For a query
```
sqlContext.range(10).filter("id > 1").groupBy().count()
```

The generated code will looks like:
```
/* 032 */     if (!initAgg0) {
/* 033 */       initAgg0 = true;
/* 034 */
/* 035 */       // initialize aggregation buffer
/* 037 */       long bufValue2 = 0L;
/* 038 */
/* 039 */
/* 040 */       // initialize Range
/* 041 */       if (!range_initRange5) {
/* 042 */         range_initRange5 = true;
       ...
/* 071 */       }
/* 072 */
/* 073 */       while (!range_overflow8 && range_number7 < range_partitionEnd6) {
/* 074 */         long range_value9 = range_number7;
/* 075 */         range_number7 += 1L;
/* 076 */         if (range_number7 < range_value9 ^ 1L < 0) {
/* 077 */           range_overflow8 = true;
/* 078 */         }
/* 079 */
/* 085 */         boolean primitive11 = false;
/* 086 */         primitive11 = range_value9 > 1L;
/* 087 */         if (!false && primitive11) {
/* 092 */           // do aggregate and update aggregation buffer
/* 099 */           long primitive17 = -1L;
/* 100 */           primitive17 = bufValue2 + 1L;
/* 101 */           bufValue2 = primitive17;
/* 105 */         }
/* 107 */       }
/* 109 */
/* 110 */       // output the result
/* 112 */       bufferHolder25.reset();
/* 114 */       rowWriter26.initialize(bufferHolder25, 1);
/* 118 */       rowWriter26.write(0, bufValue2);
/* 120 */       result24.pointTo(bufferHolder25.buffer, bufferHolder25.totalSize());
/* 121 */       currentRow = result24;
/* 122 */       return;
/* 124 */     }
/* 125 */
```

cc nongli

Author: Davies Liu <davies@databricks.com>

Closes #10840 from davies/gen_agg.
---
 .../sql/execution/WholeStageCodegen.scala     | 12 ++-
 .../aggregate/TungstenAggregate.scala         | 87 ++++++++++++++++++-
 .../BenchmarkWholeStageCodegen.scala          |  8 +-
 .../execution/WholeStageCodegenSuite.scala    | 12 +++
 .../execution/metric/SQLMetricsSuite.scala    |  4 +-
 5 files changed, 111 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index c15fabab80..57f4945de9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
   */
 private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
 
+  private def supportCodegen(e: Expression): Boolean = e match {
+    case e: LeafExpression => true
+    // CodegenFallback requires the input to be an InternalRow
+    case e: CodegenFallback => false
+    case _ => true
+  }
+
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
-      // Non-leaf with CodegenFallback does not work with whole stage codegen
-      val willFallback = plan.expressions.exists(
-        _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
-      )
+      val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val haveManyColumns = plan.output.length > 200
       !willFallback && !haveManyColumns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 8dcbab4c8c..23e54f344d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 
@@ -35,7 +36,7 @@ case class TungstenAggregate(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -113,6 +114,86 @@ case class TungstenAggregate(
     }
   }
 
+  override def supportCodegen: Boolean = {
+    groupingExpressions.isEmpty &&
+      // ImperativeAggregate is not supported right now
+      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+      // final aggregation only have one row, do not need to codegen
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+  }
+
+  // The variables used as aggregation buffer
+  private var bufVars: Seq[ExprCode] = _
+
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // generate variables for aggregation buffer
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    bufVars = initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      // The initial expression should not access any column
+      val ev = e.gen(ctx)
+      val initVars = s"""
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+       """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    val source =
+      s"""
+         | if (!$initAgg) {
+         |   $initAgg = true;
+         |
+         |   // initialize aggregation buffer
+         |   ${bufVars.map(_.code).mkString("\n")}
+         |
+         |   $childSource
+         |
+         |   // output the result
+         |   ${consume(ctx, bufVars)}
+         | }
+       """.stripMargin
+
+    (rdd, source)
+  }
+
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    // only have DeclarativeAggregate
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    // the mode could be only Partial or PartialMerge
+    val updateExpr = if (modes.contains(Partial)) {
+      functions.flatMap(_.updateExpressions)
+    } else {
+      functions.flatMap(_.mergeExpressions)
+    }
+
+    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
+    ctx.currentVars = bufVars ++ input
+    // TODO: support subexpression elimination
+    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
+      s"""
+         | ${ev.code}
+         | ${bufVars(i).isNull} = ${ev.isNull};
+         | ${bufVars(i).value} = ${ev.value};
+       """.stripMargin
+    }
+
+    s"""
+       | // do aggregate and update aggregation buffer
+       | ${codes.mkString("")}
+     """.stripMargin
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8..c4aad398bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:      Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-      -------------------------------------------------------------------------
-      Without whole stage codegen       6725.52            31.18         1.00 X
-      With whole stage codegen          2233.05            93.91         3.01 X
+      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      Without whole stage codegen             7775.53            26.97         1.00 X
+      With whole stage codegen                 342.15           612.94        22.73 X
     */
     benchmark.run()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c54fc6ba2d..300788c88a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.functions.{avg, col, max}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       sortAnswers = false
     )
   }
+
+  test("Aggregate should be included in WholeStageCodegen") {
+    val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(df.collect() === Array(Row(9, 4.5)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 4339f7260d..51285431a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
     val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-    df.collect()
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      df.collect()
+    }
     sparkContext.listenerBus.waitUntilEmpty(10000)
     val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
     assert(executionIds.size === 1)
-- 
GitLab