Commit c43899a0 authored by Liang-Chi Hsieh, committed by Davies Liu

[SPARK-13511] [SQL] Add wholestage codegen for limit

JIRA: https://issues.apache.org/jira/browse/SPARK-13511

## What changes were proposed in this pull request?

The current limit operators don't support whole-stage codegen; this PR adds support for it.

In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to count the processed rows. Once the row count reaches the limit, we set the variable `stopEarly` of `BufferedRowIterator` (newly added in this PR) to `true`, which indicates that we want to stop processing the remaining rows. When the whole-stage codegen framework next checks `shouldStop()`, it then stops processing the row iterator.
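
To make the mechanism concrete, here is a minimal, runnable sketch that simulates the control flow of the generated code: a row queue stands in for `BufferedRowIterator`'s `currentRows`, and the limit check from `doConsume` is inlined into the produce loop. The class and member names are illustrative, not the actual generated code:

    import scala.collection.mutable

    // Simulates the early-stop control flow described above (illustrative only;
    // this is neither Spark's BufferedRowIterator nor the literal generated code).
    class LimitIteratorSketch(input: Iterator[Long], limit: Int) {
      private val currentRows = mutable.Queue.empty[Long]
      private var count = 0          // the generated count term
      private var stopEarly = false  // flipped once the limit is reached

      // Mirrors the shouldStop() override installed by doConsume.
      private def shouldStop: Boolean = currentRows.nonEmpty || stopEarly

      // Mirrors the generated processNext(): the produce loop re-checks
      // shouldStop() after each consumed row, so setting stopEarly ends it.
      private def processNext(): Unit = {
        while (!shouldStop && input.hasNext) {
          val row = input.next()
          if (count < limit) {
            count += 1
            currentRows.enqueue(row)  // "consume": hand the row to the parent
          } else {
            stopEarly = true          // limit reached: skip all remaining rows
          }
        }
      }

      def hasNext: Boolean = {
        if (currentRows.isEmpty) processNext()
        currentRows.nonEmpty
      }

      def next(): Long = currentRows.dequeue()
    }

    object LimitIteratorSketch {
      def main(args: Array[String]): Unit = {
        val it = new LimitIteratorSketch(Iterator.range(0, 1000000).map(_.toLong), 100)
        var sum = 0L
        while (it.hasNext) sum += it.next()
        println(sum)  // only the first 100 rows are consumed: 0 + 1 + ... + 99 = 4950
      }
    }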

Before this change, the executed plan for the query `sqlContext.range(N).limit(100).groupBy().sum()` is:

    TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L])
    +- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L])
       +- GlobalLimit 100
          +- Exchange SinglePartition, None
             +- LocalLimit 100
                +- Range 0, 1, 1, 524288000, [id#5L]

After adding whole-stage codegen support:

    WholeStageCodegen
    :  +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L])
    :     +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L])
    :        +- GlobalLimit 100
    :           +- INPUT
    +- Exchange SinglePartition, None
       +- WholeStageCodegen
          :  +- LocalLimit 100
          :     +- Range 0, 1, 1, 524288000, [id#40L]
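
To reproduce these trees, one can print the physical plan with `explain()`. A usage sketch, assuming a `sqlContext` set up as in the query above (Spark 2.0-era API):

    // Usage sketch: confirm that GlobalLimit and LocalLimit now run inside
    // WholeStageCodegen stages. N matches the Range size in the plans above.
    val N = 500 << 20  // 524288000
    val df = sqlContext.range(N).limit(100).groupBy().sum()
    df.explain()  // prints the physical operator tree, as shown above
    df.collect()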

## How was this patch tested?

A benchmark test is added to `BenchmarkWholeStageCodegen`.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #11391 from viirya/wholestage-limit.
parent 12a2a57e
limit.scala:

    @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
     import org.apache.spark.serializer.Serializer
     import org.apache.spark.sql.catalyst.InternalRow
     import org.apache.spark.sql.catalyst.expressions._
    -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
    +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
     import org.apache.spark.sql.catalyst.plans.physical._
     import org.apache.spark.sql.execution.exchange.ShuffleExchange
    +import org.apache.spark.sql.execution.metric.SQLMetrics
     
     /**
    @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
     /**
      * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
      */
    -trait BaseLimit extends UnaryNode {
    +trait BaseLimit extends UnaryNode with CodegenSupport {
       val limit: Int
       override def output: Seq[Attribute] = child.output
       override def outputOrdering: Seq[SortOrder] = child.outputOrdering
    @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode {
       protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
         iter.take(limit)
       }
    +
    +  override def upstreams(): Seq[RDD[InternalRow]] = {
    +    child.asInstanceOf[CodegenSupport].upstreams()
    +  }
    +
    +  protected override def doProduce(ctx: CodegenContext): String = {
    +    child.asInstanceOf[CodegenSupport].produce(ctx, this)
    +  }
    +
    +  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
    +    val stopEarly = ctx.freshName("stopEarly")
    +    ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
    +
    +    ctx.addNewFunction("shouldStop", s"""
    +      @Override
    +      protected boolean shouldStop() {
    +        return !currentRows.isEmpty() || $stopEarly;
    +      }
    +    """)
    +    val countTerm = ctx.freshName("count")
    +    ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
    +    s"""
    +       | if ($countTerm < $limit) {
    +       |   $countTerm += 1;
    +       |   ${consume(ctx, input)}
    +       | } else {
    +       |   $stopEarly = true;
    +       | }
    +     """.stripMargin
    +  }
     }
     
     /**
BenchmarkWholeStageCodegen.scala:

    @@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
         */
       }
     
    +  ignore("range/limit/sum") {
    +    val N = 500 << 20
    +    runBenchmark("range/limit/sum", N) {
    +      sqlContext.range(N).limit(1000000).groupBy().sum().collect()
    +    }
    +    /*
    +    Westmere E56xx/L56xx/X56xx (Nehalem-C)
    +    range/limit/sum:                    Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
    +    -------------------------------------------------------------------------------------------
    +    range/limit/sum codegen=false               609 /  672        861.6           1.2       1.0X
    +    range/limit/sum codegen=true                561 /  621        935.3           1.1       1.1X
    +    */
    +  }
    +
       ignore("stat functions") {
         val N = 100 << 20