Commit e6a02c66 authored by Davies Liu, committed by Davies Liu

[SPARK-12914] [SQL] generate aggregation with grouping keys

This PR adds support for grouping keys in the generated TungstenAggregate.

Spilling and performance improvements for BytesToBytesMap will be done in a follow-up PR.
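
For context, a minimal sketch of the kind of query this change lets whole-stage codegen handle, mirroring the grouped-aggregation benchmark added in this PR (the key expression and row count are illustrative, taken from that benchmark):

    // Grouped aggregation: with grouping-key support, TungstenAggregate can now
    // be compiled by whole-stage codegen instead of only the no-key case.
    sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
    sqlContext.range(1 << 20)
      .selectExpr("(id & 65535) as k")   // derive a grouping key
      .groupBy("k")
      .sum()                             // declarative aggregate per key
      .collect()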

Author: Davies Liu <davies@databricks.com>

Closes #10855 from davies/gen_keys.
parent 12252d1d
@@ -55,6 +55,20 @@ class CodegenContext {
    */
   val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
 
+  /**
+   * Add an object to `references`, create a class member to access it.
+   *
+   * Returns the name of class member.
+   */
+  def addReferenceObj(name: String, obj: Any, className: String = null): String = {
+    val term = freshName(name)
+    val idx = references.length
+    references += obj
+    val clsName = Option(className).getOrElse(obj.getClass.getName)
+    addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];")
+    term
+  }
+
   /**
    * Holding a list of generated columns as input of current operator, will be used by
    * BoundReference to generate code.
@@ -198,6 +212,39 @@ class CodegenContext {
     }
   }
 
+  /**
+   * Update a column in MutableRow from ExprCode.
+   */
+  def updateColumn(
+      row: String,
+      dataType: DataType,
+      ordinal: Int,
+      ev: ExprCode,
+      nullable: Boolean): String = {
+    if (nullable) {
+      // Can't call setNullAt on DecimalType, because we need to keep the offset
+      if (dataType.isInstanceOf[DecimalType]) {
+        s"""
+           if (!${ev.isNull}) {
+             ${setColumn(row, dataType, ordinal, ev.value)};
+           } else {
+             ${setColumn(row, dataType, ordinal, "null")};
+           }
+         """
+      } else {
+        s"""
+           if (!${ev.isNull}) {
+             ${setColumn(row, dataType, ordinal, ev.value)};
+           } else {
+             $row.setNullAt($ordinal);
+           }
+         """
+      }
+    } else {
+      s"""${setColumn(row, dataType, ordinal, ev.value)};"""
+    }
+  }
+
   /**
    * Returns the name used in accessor and setter for a Java primitive type.
    */
...
@@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
     val updates = validExpr.zip(index).map {
       case (e, i) =>
-        if (e.nullable) {
-          if (e.dataType.isInstanceOf[DecimalType]) {
-            // Can't call setNullAt on DecimalType, because we need to keep the offset
-            s"""
-              if (this.isNull_$i) {
-                ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
-              } else {
-                ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
-              }
-            """
-          } else {
-            s"""
-              if (this.isNull_$i) {
-                mutableRow.setNullAt($i);
-              } else {
-                ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
-              }
-            """
-          }
-        } else {
-          s"""
-            ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
-          """
-        }
+        val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i")
+        ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
     }
 
     val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
...
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution;
 
+import java.io.IOException;
+
 import scala.collection.Iterator;
 
 import org.apache.spark.sql.catalyst.InternalRow;
@@ -34,7 +36,7 @@ public class BufferedRowIterator {
   // used when there is no column in output
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
 
-  public boolean hasNext() {
+  public boolean hasNext() throws IOException {
     if (currentRow == null) {
       processNext();
     }
@@ -56,7 +58,7 @@ public class BufferedRowIterator {
    *
    * After it's called, if currentRow is still null, it means no more rows left.
    */
-  protected void processNext() {
+  protected void processNext() throws IOException {
     if (input.hasNext()) {
       currentRow = input.next();
     }
...
@@ -17,16 +17,18 @@
 package org.apache.spark.sql.execution.aggregate
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.unsafe.KVIterator
 
 case class TungstenAggregate(
     requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -114,22 +116,38 @@ case class TungstenAggregate(
     }
   }
 
+  // all the modes of the aggregate expressions
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
   override def supportCodegen: Boolean = {
-    groupingExpressions.isEmpty &&
-      // ImperativeAggregate is not supported right now
-      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+    // ImperativeAggregate is not supported right now
+    !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
-  // The variables used as aggregation buffer
-  private var bufVars: Seq[ExprCode] = _
-
-  private val modes = aggregateExpressions.map(_.mode).distinct
-
   override def upstream(): RDD[InternalRow] = {
     child.asInstanceOf[CodegenSupport].upstream()
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
+    if (groupingExpressions.isEmpty) {
+      doProduceWithoutKeys(ctx)
+    } else {
+      doProduceWithKeys(ctx)
+    }
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    if (groupingExpressions.isEmpty) {
+      doConsumeWithoutKeys(ctx, input)
+    } else {
+      doConsumeWithKeys(ctx, input)
+    }
+  }
+
+  // The variables used as aggregation buffer
+  private var bufVars: Seq[ExprCode] = _
+
+  private def doProduceWithoutKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
@@ -176,10 +194,10 @@ case class TungstenAggregate(
       (resultVars, resultVars.map(_.code).mkString("\n"))
     }
 
-    val doAgg = ctx.freshName("doAgg")
+    val doAgg = ctx.freshName("doAggregateWithoutKey")
     ctx.addNewFunction(doAgg,
       s"""
-        | private void $doAgg() {
+        | private void $doAgg() throws java.io.IOException {
         |   // initialize aggregation buffer
         |   ${bufVars.map(_.code).mkString("\n")}
         |
@@ -200,7 +218,7 @@ case class TungstenAggregate(
      """.stripMargin
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
     val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
@@ -212,7 +230,6 @@ case class TungstenAggregate(
           e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
       }
     }
-
     ctx.currentVars = bufVars ++ input
     // TODO: support subexpression elimination
     val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
@@ -232,6 +249,199 @@ case class TungstenAggregate(
     """.stripMargin
   }
 
+  private val groupingAttributes = groupingExpressions.map(_.toAttribute)
+  private val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+  private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+    .filter(_.isInstanceOf[DeclarativeAggregate])
+    .map(_.asInstanceOf[DeclarativeAggregate])
+  private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
+  private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+
+  // The name for the HashMap
+  private var hashMapTerm: String = _
+
+  /**
+   * This is called by the generated Java class, so it should be public.
+   */
+  def createHashMap(): UnsafeFixedWidthAggregationMap = {
+    // create initialized aggregate buffer
+    val initExpr = declFunctions.flatMap(f => f.initialValues)
+    val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+
+    // create hashMap
+    new UnsafeFixedWidthAggregationMap(
+      initialBuffer,
+      bufferSchema,
+      groupingKeySchema,
+      TaskContext.get().taskMemoryManager(),
+      1024 * 16, // initial capacity
+      TaskContext.get().taskMemoryManager().pageSizeBytes,
+      false // disable tracking of performance metrics
+    )
+  }
+
+  /**
+   * This is called by the generated Java class, so it should be public.
+   */
+  def createUnsafeJoiner(): UnsafeRowJoiner = {
+    GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+  }
+
+  /**
+   * Update peak execution memory, called in the generated Java class.
+   */
+  def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
+    val mapMemory = hashMap.getPeakMemoryUsedBytes
+    val metrics = TaskContext.get().taskMetrics()
+    metrics.incPeakExecutionMemory(mapMemory)
+  }
+
+  private def doProduceWithKeys(ctx: CodegenContext): String = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // create hashMap
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    hashMapTerm = ctx.freshName("hashMap")
+    val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+
+    // Create a name for the iterator from the HashMap
+    val iterTerm = ctx.freshName("mapIter")
+    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
+
+    // generate code for output
+    val keyTerm = ctx.freshName("aggKey")
+    val bufferTerm = ctx.freshName("aggBuffer")
+    val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
+      // generate output using resultExpressions
+      ctx.currentVars = null
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).gen(ctx)
+      }
+      ctx.INPUT_ROW = bufferTerm
+      val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).gen(ctx)
+      }
+      // evaluate the aggregation result
+      ctx.currentVars = bufferVars
+      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
+        BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+      }
+      // generate the final result
+      ctx.currentVars = keyVars ++ aggResults
+      val inputAttrs = groupingAttributes ++ aggregateAttributes
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).gen(ctx)
+      }
+      s"""
+       ${keyVars.map(_.code).mkString("\n")}
+       ${bufferVars.map(_.code).mkString("\n")}
+       ${aggResults.map(_.code).mkString("\n")}
+       ${resultVars.map(_.code).mkString("\n")}
+
+       ${consume(ctx, resultVars)}
+       """
+    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+      // This should be the last operator in a stage, we should output UnsafeRow directly
+      val joinerTerm = ctx.freshName("unsafeRowJoiner")
+      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
+        s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
+      val resultRow = ctx.freshName("resultRow")
+      s"""
+       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
+       ${consume(ctx, null, resultRow)}
+       """
+    } else {
+      // generate result based on grouping key
+      ctx.INPUT_ROW = keyTerm
+      ctx.currentVars = null
+      val eval = resultExpressions.map { e =>
+        BindReferences.bindReference(e, groupingAttributes).gen(ctx)
+      }
+      s"""
+       ${eval.map(_.code).mkString("\n")}
+       ${consume(ctx, eval)}
+       """
+    }
+
+    val doAgg = ctx.freshName("doAggregateWithKeys")
+    ctx.addNewFunction(doAgg,
+      s"""
+        private void $doAgg() throws java.io.IOException {
+          ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+
+          $iterTerm = $hashMapTerm.iterator();
+        }
+       """)
+
+    s"""
+     if (!$initAgg) {
+       $initAgg = true;
+       $doAgg();
+     }
+
+     // output the result
+     while ($iterTerm.next()) {
+       UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+       UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+       $outputCode
+     }
+
+     $thisPlan.updatePeakMemory($hashMapTerm);
+     $hashMapTerm.free();
+     """
+  }
+
+  private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    // create grouping key
+    ctx.currentVars = input
+    val keyCode = GenerateUnsafeProjection.createCode(
+      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+    val key = keyCode.value
+    val buffer = ctx.freshName("aggBuffer")
+
+    // only have DeclarativeAggregate
+    val updateExpr = aggregateExpressions.flatMap { e =>
+      e.mode match {
+        case Partial | Complete =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+        case PartialMerge | Final =>
+          e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+      }
+    }
+
+    val inputAttr = bufferAttributes ++ child.output
+    ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+    ctx.INPUT_ROW = buffer
+    // TODO: support subexpression elimination
+    val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+    val updates = evals.zipWithIndex.map { case (ev, i) =>
+      val dt = updateExpr(i).dataType
+      ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
+    }
+
+    s"""
+     // generate grouping key
+     ${keyCode.code}
+     UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+     if ($buffer == null) {
+       // failed to allocate the first page
+       throw new OutOfMemoryError("Not enough memory for aggregation");
+     }
+
+     // evaluate aggregate function
+     ${evals.map(_.code).mkString("\n")}
+
+     // update aggregate buffer
+     ${updates.mkString("\n")}
+     """
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
...
@@ -18,7 +18,12 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.Benchmark
 
 /**
@@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark
  * build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
  */
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
-  def testWholeStage(values: Int): Unit = {
-    val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
-    val sc = SparkContext.getOrCreate(conf)
-    val sqlContext = SQLContext.getOrCreate(sc)
-    val benchmark = new Benchmark("Single Int Column Scan", values)
+  lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+  lazy val sc = SparkContext.getOrCreate(conf)
+  lazy val sqlContext = SQLContext.getOrCreate(sc)
+
+  def testWholeStage(values: Int): Unit = {
+    val benchmark = new Benchmark("rang/filter/aggregate", values)
 
-    benchmark.addCase("Without whole stage codegen") { iter =>
+    benchmark.addCase("Without codegen") { iter =>
       sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
       sqlContext.range(values).filter("(id & 1) = 1").count()
     }
 
-    benchmark.addCase("With whole stage codegen") { iter =>
+    benchmark.addCase("With codegen") { iter =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
      sqlContext.range(values).filter("(id & 1) = 1").count()
     }
 
     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:        Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      rang/filter/aggregate:         Avg Time(ms)    Avg Rate(M/s)  Relative Rate
       -------------------------------------------------------------------------------
-      Without whole stage codegen         7775.53            26.97         1.00 X
-      With whole stage codegen             342.15           612.94        22.73 X
+      Without codegen                     7775.53            26.97         1.00 X
+      With codegen                         342.15           612.94        22.73 X
     */
     benchmark.run()
   }
 
-  ignore("benchmark") {
-    testWholeStage(1024 * 1024 * 200)
+  def testAggregateWithKey(values: Int): Unit = {
+    val benchmark = new Benchmark("Aggregate with keys", values)
+
+    benchmark.addCase("Aggregate w/o codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+    }
+    benchmark.addCase(s"Aggregate w codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+    }
+
+    /*
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      Aggregate with keys:           Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      Aggregate w/o codegen               4254.38             4.93         1.00 X
+      Aggregate w codegen                 2661.45             7.88         1.60 X
+    */
+    benchmark.run()
+  }
+
+  def testBytesToBytesMap(values: Int): Unit = {
+    val benchmark = new Benchmark("BytesToBytesMap", values)
+
+    benchmark.addCase("hash") { iter =>
+      var i = 0
+      val keyBytes = new Array[Byte](16)
+      val valueBytes = new Array[Byte](16)
+      val key = new UnsafeRow(1)
+      key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      val value = new UnsafeRow(2)
+      value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+      var s = 0
+      while (i < values) {
+        key.setInt(0, i % 1000)
+        val h = Murmur3_x86_32.hashUnsafeWords(
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+        s += h
+        i += 1
+      }
+    }
+
+    Seq("off", "on").foreach { heap =>
+      benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
+        val taskMemoryManager = new TaskMemoryManager(
+          new StaticMemoryManager(
+            new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}")
+              .set("spark.memory.offHeap.size", "102400000"),
+            Long.MaxValue,
+            Long.MaxValue,
+            1),
+          0)
+        val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L << 20)
+        val keyBytes = new Array[Byte](16)
+        val valueBytes = new Array[Byte](16)
+        val key = new UnsafeRow(1)
+        key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+        val value = new UnsafeRow(2)
+        value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+        var i = 0
+        while (i < values) {
+          key.setInt(0, i % 65536)
+          val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+          if (loc.isDefined) {
+            value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
+              loc.getValueLength)
+            value.setInt(0, value.getInt(0) + 1)
+            i += 1
+          } else {
+            loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+              value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
+          }
+        }
+      }
+    }
+
+    /*
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      BytesToBytesMap:               Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      hash                                 662.06            79.19         1.00 X
+      BytesToBytesMap (off Heap)          2209.42            23.73         0.30 X
+      BytesToBytesMap (on Heap)           2957.68            17.73         0.22 X
+    */
+    benchmark.run()
+  }
+
+  test("benchmark") {
+    // testWholeStage(1024 * 1024 * 200)
+    // testAggregateWithKey(20 << 20)
+    // testBytesToBytesMap(1024 * 1024 * 50)
   }
 }
@@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
     assert(df.collect() === Array(Row(9, 4.5)))
   }
+
+  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
+    val df = sqlContext.range(3).groupBy("id").count().orderBy("id")
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
+  }
 }