Skip to content
Snippets Groups Projects
Commit c100d31d authored by Davies Liu's avatar Davies Liu Committed by Davies Liu
Browse files

[SPARK-13873] [SQL] Avoid copy of UnsafeRow when there is no join in whole stage codegen

## What changes were proposed in this pull request?

We need to copy the UnsafeRow because a Join could produce multiple rows from a single input row. We can avoid that copy if there is no join (or the join will not produce multiple rows) inside WholeStageCodegen.

Updated the benchmark for `collect`; it shows a 20-30% speedup.

## How was this patch tested?

existing unit tests.

Author: Davies Liu <davies@databricks.com>

Closes #11740 from davies/avoid_copy2.
parent 917f4000
No related branches found
No related tags found
No related merge requests found
Showing
with 35 additions and 8 deletions
......@@ -77,6 +77,16 @@ class CodegenContext {
*/
var currentVars: Seq[ExprCode] = null
/**
* Whether we should copy the result rows or not.
*
* If any operator inside WholeStageCodegen generates multiple rows from a single row (for
* example, Join), this should be true.
*
* If an operator starts a new pipeline, this should be reset to false before calling `consume()`.
*/
var copyResult: Boolean = false
/**
* Holding expressions' mutable states like `MonotonicallyIncreasingID.count` as a
* 3-tuple: java type, variable name, code to init it.
......
......@@ -187,6 +187,7 @@ case class Expand(
val i = ctx.freshName("i")
// these columns have to be declared before the loop.
val evaluate = evaluateVariables(outputColumns)
ctx.copyResult = true
s"""
|$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
......
......@@ -115,7 +115,8 @@ class GroupedIterator private(
false
} else {
// Skip to next group.
while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
// currentRow may be overwritten by `hasNext`, so we should compare them first.
while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) {
currentRow = input.next()
}
......
......@@ -111,7 +111,6 @@ case class Sort(
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
......@@ -132,6 +131,10 @@ case class Sort(
| }
""".stripMargin.trim)
// The child could change `copyResult` to true, but we had already consumed all the rows,
// so `copyResult` should be reset to `false`.
ctx.copyResult = false
val outputRow = ctx.freshName("outputRow")
val dataSize = metricTerm(ctx, "dataSize")
val spillSize = metricTerm(ctx, "spillSize")
......
......@@ -379,10 +379,15 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
input: Seq[ExprCode],
row: String = null): String = {
val doCopy = if (ctx.copyResult) {
".copy()"
} else {
""
}
if (row != null) {
// There is an UnsafeRow already
s"""
|append($row.copy());
|append($row$doCopy);
""".stripMargin.trim
} else {
assert(input != null)
......@@ -397,7 +402,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
s"""
|$evaluateInputs
|${code.code.trim}
|append(${code.value}.copy());
|append(${code.value}$doCopy);
""".stripMargin.trim
} else {
// There are no columns
......
......@@ -465,6 +465,10 @@ case class TungstenAggregate(
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
val numOutput = metricTerm(ctx, "numOutputRows")
// The child could change `copyResult` to true, but we had already consumed all the rows,
// so `copyResult` should be reset to `false`.
ctx.copyResult = false
s"""
if (!$initAgg) {
$initAgg = true;
......
......@@ -230,6 +230,7 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
......@@ -303,6 +304,7 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
ctx.copyResult = true
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
......
......@@ -404,6 +404,7 @@ case class SortMergeJoin(
}
override def doProduce(ctx: CodegenContext): String = {
ctx.copyResult = true
val leftInput = ctx.freshName("leftInput")
ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
val rightInput = ctx.freshName("rightInput")
......
......@@ -457,12 +457,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
/**
* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
collect 1 million 775 / 1170 1.4 738.9 1.0X
collect 2 millions 1153 / 1758 0.9 1099.3 0.7X
collect 4 millions 4451 / 5124 0.2 4244.9 0.2X
collect 1 million 439 / 654 2.4 418.7 1.0X
collect 2 millions 961 / 1907 1.1 916.4 0.5X
collect 4 millions 3193 / 3895 0.3 3044.7 0.1X
*/
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment