From c43899a04e4de18e238a1761bf4fe9f54e182320 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Tue, 1 Mar 2016 08:43:02 -0800
Subject: [PATCH] [SPARK-13511] [SQL] Add wholestage codegen for limit

JIRA: https://issues.apache.org/jira/browse/SPARK-13511

## What changes were proposed in this pull request?

Current limit operator doesn't support wholestage codegen. This is open to add support for it.

In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to count the processed rows. Once the row numbers catches the limit number, we set the variable `stopEarly` of `BufferedRowIterator` newly added in this pr to `true` that indicates we want to stop processing remaining rows. Then when the wholestage codegen framework checks `shouldStop()`, it will stop the processing of the row iterator.

Before this, the executed plan for a query `sqlContext.range(N).limit(100).groupBy().sum()` is:

    TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L])
    +- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L])
       +- GlobalLimit 100
          +- Exchange SinglePartition, None
             +- LocalLimit 100
                +- Range 0, 1, 1, 524288000, [id#5L]

After add wholestage codegen support:

    WholeStageCodegen
    :  +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L])
    :     +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L])
    :        +- GlobalLimit 100
    :           +- INPUT
    +- Exchange SinglePartition, None
       +- WholeStageCodegen
          :  +- LocalLimit 100
          :     +- Range 0, 1, 1, 524288000, [id#40L]

## How was this patch tested?

A test is added into BenchmarkWholeStageCodegen.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #11391 from viirya/wholestage-limit.
---
 .../apache/spark/sql/execution/limit.scala    | 35 +++++++++++++++++--
 .../BenchmarkWholeStageCodegen.scala          | 14 ++++++++
 2 files changed, 47 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index cd543d4195..45175d36d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.metric.SQLMetrics
 
 
 /**
@@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
 /**
  * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
  */
-trait BaseLimit extends UnaryNode {
+trait BaseLimit extends UnaryNode with CodegenSupport {
   val limit: Int
   override def output: Seq[Attribute] = child.output
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode {
   protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
     iter.take(limit)
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val stopEarly = ctx.freshName("stopEarly")
+    ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
+
+    ctx.addNewFunction("shouldStop", s"""
+      @Override
+      protected boolean shouldStop() {
+        return !currentRows.isEmpty() || $stopEarly;
+      }
+    """)
+    val countTerm = ctx.freshName("count")
+    ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+    s"""
+       | if ($countTerm < $limit) {
+       |   $countTerm += 1;
+       |   ${consume(ctx, input)}
+       | } else {
+       |   $stopEarly = true;
+       | }
+     """.stripMargin
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 6d6cc0186a..2d3e34d0e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     */
   }
 
+  ignore("range/limit/sum") {
+    val N = 500 << 20
+    runBenchmark("range/limit/sum", N) {
+      sqlContext.range(N).limit(1000000).groupBy().sum().collect()
+    }
+    /*
+    Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    range/limit/sum:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    range/limit/sum codegen=false             609 /  672        861.6           1.2       1.0X
+    range/limit/sum codegen=true              561 /  621        935.3           1.1       1.1X
+    */
+  }
+
   ignore("stat functions") {
     val N = 100 << 20
 
-- 
GitLab