Commit 18066f2e authored by Liang-Chi Hsieh's avatar Liang-Chi Hsieh Committed by Wenchen Fan

[SPARK-21052][SQL] Add hash map metrics to join

## What changes were proposed in this pull request?

This adds an average hash map probe metric to join operators such as `BroadcastHashJoin` and `ShuffledHashJoin`.

This PR adds an API to `HashedRelation` for retrieving the average number of hash map probes per key lookup.
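
Condensed from the diff below (an excerpt for orientation, not a standalone compilable snippet), the wiring is: each join operator declares an `avgHashProbe` average metric, `HashedRelation` exposes the probe statistic, and a task completion listener copies that value into the metric once per task:

```scala
// 1. The join operator registers an average metric (BroadcastHashJoinExec / ShuffledHashJoinExec).
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe")

// 2. HashedRelation exposes the new probe statistic.
def getAverageProbesPerLookup(): Double

// 3. HashJoin sets the metric once, at the end of the task, via a completion listener.
TaskContext.get().addTaskCompletionListener(_ =>
  avgHashProbe.set(hashed.getAverageProbesPerLookup()))
```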

## How was this patch tested?

Related test cases are added.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #18301 from viirya/SPARK-21052.
parent 29bd251d
Showing with 296 additions and 60 deletions
......@@ -60,7 +60,7 @@ case class HashAggregateExec(
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"),
"avgHashmapProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hashmap probe"))
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
......@@ -94,7 +94,7 @@ case class HashAggregateExec(
val numOutputRows = longMetric("numOutputRows")
val peakMemory = longMetric("peakMemory")
val spillSize = longMetric("spillSize")
val avgHashmapProbe = longMetric("avgHashmapProbe")
val avgHashProbe = longMetric("avgHashProbe")
child.execute().mapPartitions { iter =>
......@@ -119,7 +119,7 @@ case class HashAggregateExec(
numOutputRows,
peakMemory,
spillSize,
avgHashmapProbe)
avgHashProbe)
if (!hasInput && groupingExpressions.isEmpty) {
numOutputRows += 1
Iterator.single[UnsafeRow](aggregationIterator.outputForEmptyGroupingKeyWithoutInput())
......@@ -344,7 +344,7 @@ case class HashAggregateExec(
sorter: UnsafeKVExternalSorter,
peakMemory: SQLMetric,
spillSize: SQLMetric,
avgHashmapProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
avgHashProbe: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
......@@ -355,8 +355,7 @@ case class HashAggregateExec(
metrics.incPeakExecutionMemory(maxMemory)
// Update average hashmap probe
val avgProbes = hashMap.getAverageProbesPerLookup()
avgHashmapProbe.add(avgProbes.ceil.toLong)
avgHashProbe.set(hashMap.getAverageProbesPerLookup())
if (sorter == null) {
// not spilled
......@@ -584,7 +583,7 @@ case class HashAggregateExec(
val doAgg = ctx.freshName("doAggregateWithKeys")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
val avgHashmapProbe = metricTerm(ctx, "avgHashmapProbe")
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
def generateGenerateCode(): String = {
if (isFastHashMapEnabled) {
......@@ -611,7 +610,7 @@ case class HashAggregateExec(
s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize,
$avgHashmapProbe);
$avgHashProbe);
}
""")
......
......@@ -89,7 +89,7 @@ class TungstenAggregationIterator(
numOutputRows: SQLMetric,
peakMemory: SQLMetric,
spillSize: SQLMetric,
avgHashmapProbe: SQLMetric)
avgHashProbe: SQLMetric)
extends AggregationIterator(
groupingExpressions,
originalInputAttributes,
......@@ -367,6 +367,22 @@ class TungstenAggregationIterator(
}
}
TaskContext.get().addTaskCompletionListener(_ => {
// At the end of the task, update the task's peak memory usage. Since we destroy
// the map to create the sorter, their memory usages should not overlap, so it is safe
// to just use the max of the two.
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val maxMemory = Math.max(mapMemory, sorterMemory)
val metrics = TaskContext.get().taskMetrics()
peakMemory.set(maxMemory)
spillSize.set(metrics.memoryBytesSpilled - spillSizeBefore)
metrics.incPeakExecutionMemory(maxMemory)
// Updating average hashmap probe
avgHashProbe.set(hashMap.getAverageProbesPerLookup())
})
///////////////////////////////////////////////////////////////////////////
// Part 7: Iterator's public methods.
///////////////////////////////////////////////////////////////////////////
......@@ -409,22 +425,6 @@ class TungstenAggregationIterator(
}
}
// If this is the last record, update the task's peak memory usage. Since we destroy
// the map to create the sorter, their memory usages should not overlap, so it is safe
// to just use the max of the two.
if (!hasNext) {
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val maxMemory = Math.max(mapMemory, sorterMemory)
val metrics = TaskContext.get().taskMetrics()
peakMemory += maxMemory
spillSize += metrics.memoryBytesSpilled - spillSizeBefore
metrics.incPeakExecutionMemory(maxMemory)
// Update average hashmap probe if this is the last record.
val averageProbes = hashMap.getAverageProbesPerLookup()
avgHashmapProbe.add(averageProbes.ceil.toLong)
}
numOutputRows += 1
res
} else {
......
......@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.TaskCompletionListener
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
......@@ -46,7 +47,8 @@ case class BroadcastHashJoinExec(
extends BinaryExecNode with HashJoin with CodegenSupport {
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
......@@ -60,12 +62,13 @@ case class BroadcastHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val avgHashProbe = longMetric("avgHashProbe")
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val hashed = broadcastRelation.value.asReadOnlyCopy()
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
join(streamedIter, hashed, numOutputRows)
join(streamedIter, hashed, numOutputRows, avgHashProbe)
}
}
......@@ -90,6 +93,23 @@ case class BroadcastHashJoinExec(
}
}
/**
* Returns the codes used to add a task completion listener to update avg hash probe
* at the end of the task.
*/
private def genTaskListener(avgHashProbe: String, relationTerm: String): String = {
val listenerClass = classOf[TaskCompletionListener].getName
val taskContextClass = classOf[TaskContext].getName
s"""
| $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() {
| @Override
| public void onTaskCompletion($taskContextClass context) {
| $avgHashProbe.set($relationTerm.getAverageProbesPerLookup());
| }
| });
""".stripMargin
}
/**
* Returns a tuple of Broadcast of HashedRelation and the variable name for it.
*/
......@@ -99,10 +119,16 @@ case class BroadcastHashJoinExec(
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
// At the end of the task, we update the avg hash probe.
val avgHashProbe = metricTerm(ctx, "avgHashProbe")
val addTaskListener = genTaskListener(avgHashProbe, relationTerm)
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
| incPeakExecutionMemory($relationTerm.estimatedSize());
| $addTaskListener
""".stripMargin)
(broadcastRelation, relationTerm)
}
......
......@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
......@@ -193,7 +194,8 @@ trait HashJoin {
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
numOutputRows: SQLMetric): Iterator[InternalRow] = {
numOutputRows: SQLMetric,
avgHashProbe: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
case _: InnerLike =>
......@@ -211,6 +213,10 @@ trait HashJoin {
s"BroadcastHashJoin should not take $x as the JoinType")
}
// At the end of the task, we update the avg hash probe.
TaskContext.get().addTaskCompletionListener(_ =>
avgHashProbe.set(hashed.getAverageProbesPerLookup()))
val resultProj = createResultProjection
joinedIter.map { r =>
numOutputRows += 1
......
......@@ -79,6 +79,11 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
* Release any used resources.
*/
def close(): Unit
/**
* Returns the average number of probes per key lookup.
*/
def getAverageProbesPerLookup(): Double
}
private[execution] object HashedRelation {
......@@ -242,7 +247,8 @@ private[joins] class UnsafeHashedRelation(
binaryMap = new BytesToBytesMap(
taskMemoryManager,
(nKeys * 1.5 + 1).toInt, // reduce hash collision
pageSizeBytes)
pageSizeBytes,
true)
var i = 0
var keyBuffer = new Array[Byte](1024)
......@@ -273,6 +279,8 @@ private[joins] class UnsafeHashedRelation(
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
read(in.readInt, in.readLong, in.readBytes)
}
override def getAverageProbesPerLookup(): Double = binaryMap.getAverageProbesPerLookup()
}
private[joins] object UnsafeHashedRelation {
......@@ -290,7 +298,8 @@ private[joins] object UnsafeHashedRelation {
taskMemoryManager,
// Only 70% of the slots can be used before growing, more capacity help to reduce collision
(sizeEstimate * 1.5 + 1).toInt,
pageSizeBytes)
pageSizeBytes,
true)
// Create a mapping of buildKeys -> rows
val keyGenerator = UnsafeProjection.create(key)
......@@ -344,7 +353,7 @@ private[joins] object UnsafeHashedRelation {
* determined by `key1 - minKey`.
*
* The map is created as sparse mode, then key-value could be appended into it. Once finish
* appending, caller could all optimize() to try to turn the map into dense mode, which is faster
* appending, caller could call optimize() to try to turn the map into dense mode, which is faster
* to probe.
*
* see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/
......@@ -385,6 +394,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
// The number of unique keys.
private var numKeys = 0L
// Tracking average number of probes per key lookup.
private var numKeyLookups = 0L
private var numProbes = 0L
// needed by serializer
def this() = {
this(
......@@ -469,6 +482,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
if (isDense) {
numKeyLookups += 1
numProbes += 1
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
......@@ -477,11 +492,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
numKeyLookups += 1
numProbes += 1
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return getRow(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
numProbes += 1
}
}
null
......@@ -509,6 +527,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
if (isDense) {
numKeyLookups += 1
numProbes += 1
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
......@@ -517,11 +537,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
numKeyLookups += 1
numProbes += 1
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return valueIter(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
numProbes += 1
}
}
null
......@@ -573,8 +596,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
private def updateIndex(key: Long, address: Long): Unit = {
var pos = firstSlot(key)
assert(numKeys < array.length / 2)
numKeyLookups += 1
numProbes += 1
while (array(pos) != key && array(pos + 1) != 0) {
pos = nextSlot(pos)
numProbes += 1
}
if (array(pos + 1) == 0) {
// this is the first value for this key, put the address in array.
......@@ -686,6 +712,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
writeLong(maxKey)
writeLong(numKeys)
writeLong(numValues)
writeLong(numKeyLookups)
writeLong(numProbes)
writeLong(array.length)
writeLongArray(writeBuffer, array, array.length)
......@@ -727,6 +755,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = readLong()
numKeys = readLong()
numValues = readLong()
numKeyLookups = readLong()
numProbes = readLong()
val length = readLong().toInt
mask = length - 2
......@@ -742,6 +772,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
override def read(kryo: Kryo, in: Input): Unit = {
read(in.readBoolean, in.readLong, in.readBytes)
}
/**
* Returns the average number of probes per key lookup.
*/
def getAverageProbesPerLookup(): Double = numProbes.toDouble / numKeyLookups
}
private[joins] class LongHashedRelation(
......@@ -793,6 +828,8 @@ private[joins] class LongHashedRelation(
resultRow = new UnsafeRow(nFields)
map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
}
override def getAverageProbesPerLookup(): Double = map.getAverageProbesPerLookup()
}
/**
......
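The counting scheme in `LongToUnsafeRowMap` above is: every lookup (`get`, `getValue`, `updateIndex`) bumps `numKeyLookups` once and `numProbes` once for the first slot, then bumps `numProbes` again for each extra slot visited while resolving a collision, so `numProbes / numKeyLookups` is 1.0 for a collision-free map and grows with clustering. Below is a minimal standalone sketch of the same idea using a hypothetical toy open-addressing map (`ProbeCountingDemo` is illustrative, not patch code):

```scala
// Toy illustration of the probe-counting technique used by the patch:
// one key lookup may touch several slots; the metric is total probes / total lookups.
object ProbeCountingDemo {
  private val capacity = 16                      // power of two, so masking works like the real map
  private val keys = new Array[Long](capacity)   // 0L marks an empty slot in this toy
  private var numKeyLookups = 0L
  private var numProbes = 0L

  private def firstSlot(key: Long): Int = key.hashCode & (capacity - 1)
  private def nextSlot(pos: Int): Int = (pos + 1) & (capacity - 1)

  def put(key: Long): Unit = {
    var pos = firstSlot(key)
    while (keys(pos) != 0L && keys(pos) != key) pos = nextSlot(pos)
    keys(pos) = key
  }

  def contains(key: Long): Boolean = {
    numKeyLookups += 1
    numProbes += 1                               // the first slot always costs one probe
    var pos = firstSlot(key)
    while (keys(pos) != 0L) {
      if (keys(pos) == key) return true
      pos = nextSlot(pos)
      numProbes += 1                             // each collision-chain step costs another probe
    }
    false
  }

  def averageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups

  def main(args: Array[String]): Unit = {
    Seq(1L, 17L, 33L).foreach(put)               // all three keys hash to the same slot here
    Seq(1L, 17L, 33L).foreach(contains)
    println(f"avg probes per lookup: $averageProbesPerLookup%.1f")  // prints 2.0
  }
}
```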
......@@ -42,7 +42,8 @@ case class ShuffledHashJoinExec(
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
......@@ -62,9 +63,10 @@ case class ShuffledHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val avgHashProbe = longMetric("avgHashProbe")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = buildHashedRelation(buildIter)
join(streamIter, hashed, numOutputRows)
join(streamIter, hashed, numOutputRows, avgHashProbe)
}
}
}
......@@ -57,6 +57,12 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
override def add(v: Long): Unit = _value += v
// We can set a double value to `SQLMetric` which stores only long value, if it is
// average metrics.
def set(v: Double): Unit = SQLMetrics.setDoubleForAverageMetrics(this, v)
def set(v: Long): Unit = _value = v
def +=(v: Long): Unit = _value += v
override def value: Long = _value
......@@ -74,6 +80,19 @@ object SQLMetrics {
private val TIMING_METRIC = "timing"
private val AVERAGE_METRIC = "average"
private val baseForAvgMetric: Int = 10
/**
* Converts a double value to long value by multiplying a base integer, so we can store it in
* `SQLMetrics`. It only works for average metrics. When showing the metrics on UI, we restore
* it back to a double value up to the decimal places bound by the base integer.
*/
private[sql] def setDoubleForAverageMetrics(metric: SQLMetric, v: Double): Unit = {
assert(metric.metricType == AVERAGE_METRIC,
s"Can't set a double to a metric of metrics type: ${metric.metricType}")
metric.set((v * baseForAvgMetric).toLong)
}
def createMetric(sc: SparkContext, name: String): SQLMetric = {
val acc = new SQLMetric(SUM_METRIC)
acc.register(sc, name = Some(name), countFailedValues = false)
......@@ -104,15 +123,14 @@ object SQLMetrics {
/**
* Create a metric to report the average information (including min, med, max) like
* avg hashmap probe. Because `SQLMetric` stores long values, we take the ceil of the average
* values before storing them. This metric is used to record an average value computed in the
* end of a task. It should be set once. The initial values (zeros) of this metrics will be
* excluded after.
* avg hash probe. As average metrics are double values, this kind of metrics should be
* only set with `SQLMetric.set` method instead of other methods like `SQLMetric.add`.
* The initial values (zeros) of this metrics will be excluded after.
*/
def createAverageMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may looks like:
// probe avg (min, med, max):
// (1, 2, 6)
// (1.2, 2.2, 6.3)
val acc = new SQLMetric(AVERAGE_METRIC)
acc.register(sc, name = Some(s"$name (min, med, max)"), countFailedValues = false)
acc
......@@ -127,7 +145,7 @@ object SQLMetrics {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
val validValues = values.filter(_ > 0)
val Seq(min, med, max) = {
......@@ -137,7 +155,7 @@ object SQLMetrics {
val sorted = validValues.sorted
Seq(sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
}
metric.map(numberFormat.format)
metric.map(v => numberFormat.format(v.toDouble / baseForAvgMetric))
}
s"\n($min, $med, $max)"
} else {
......
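The `SQLMetric.set(Double)` path above works by scaling: an average metric stores `(v * baseForAvgMetric).toLong` in the underlying long accumulator, and the UI divides by the same base when rendering, so with `baseForAvgMetric = 10` a value of 1.25 is stored as 12 and shown as 1.2. A tiny standalone sketch of that round trip (a hypothetical demo object, not the `SQLMetric` class itself):

```scala
// Toy round trip for the double-in-a-long encoding used by average metrics.
object AverageMetricEncodingDemo {
  private val baseForAvgMetric = 10  // same base the patch uses in SQLMetrics

  // Write path: the metric stores only longs, so scale the double up before storing.
  def encode(v: Double): Long = (v * baseForAvgMetric).toLong           // 1.25 -> 12

  // Display path: scale back down, keeping one decimal place for base 10.
  def decode(stored: Long): Double = stored.toDouble / baseForAvgMetric // 12 -> 1.2

  def main(args: Array[String]): Unit = {
    val avgProbes = 1.25
    val stored = encode(avgProbes)
    println(s"stored=$stored, shown=${decode(stored)}")  // prints: stored=12, shown=1.2
  }
}
```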
......@@ -47,9 +47,10 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
private def getSparkPlanMetrics(
df: DataFrame,
expectedNumOfJobs: Int,
expectedNodeIds: Set[Long]): Option[Map[Long, (String, Map[String, Any])]] = {
expectedNodeIds: Set[Long],
enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = {
val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) {
df.collect()
}
sparkContext.listenerBus.waitUntilEmpty(10000)
......@@ -110,6 +111,20 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
}
/**
* Generates a `DataFrame` by filling randomly generated bytes for hash collision.
*/
private def generateRandomBytesDF(numRows: Int = 65535): DataFrame = {
val random = new Random()
val manyBytes = (0 until numRows).map { _ =>
val byteArrSize = random.nextInt(100)
val bytes = new Array[Byte](byteArrSize)
random.nextBytes(bytes)
(bytes, random.nextInt(100))
}
manyBytes.toSeq.toDF("a", "b")
}
test("LocalTableScanExec computes metrics in collect and take") {
val df1 = spark.createDataset(Seq(1, 2, 3))
val logical = df1.queryExecution.logical
......@@ -151,9 +166,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
val df = testData2.groupBy().count() // 2 partitions
val expected1 = Seq(
Map("number of output rows" -> 2L,
"avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"),
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"),
Map("number of output rows" -> 1L,
"avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"))
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))
testSparkPlanMetrics(df, 1, Map(
2L -> ("HashAggregate", expected1(0)),
0L -> ("HashAggregate", expected1(1)))
......@@ -163,9 +178,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
val df2 = testData2.groupBy('a).count()
val expected2 = Seq(
Map("number of output rows" -> 4L,
"avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"),
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"),
Map("number of output rows" -> 3L,
"avg hashmap probe (min, med, max)" -> "\n(1, 1, 1)"))
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))
testSparkPlanMetrics(df2, 1, Map(
2L -> ("HashAggregate", expected2(0)),
0L -> ("HashAggregate", expected2(1)))
......@@ -173,19 +188,42 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
test("Aggregate metrics: track avg probe") {
val random = new Random()
val manyBytes = (0 until 65535).map { _ =>
val byteArrSize = random.nextInt(100)
val bytes = new Array[Byte](byteArrSize)
random.nextBytes(bytes)
(bytes, random.nextInt(100))
}
val df = manyBytes.toSeq.toDF("a", "b").repartition(1).groupBy('a).count()
val metrics = getSparkPlanMetrics(df, 1, Set(2L, 0L)).get
Seq(metrics(2L)._2("avg hashmap probe (min, med, max)"),
metrics(0L)._2("avg hashmap probe (min, med, max)")).foreach { probes =>
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
assert(probe.toInt > 1)
// The executed plan looks like:
// HashAggregate(keys=[a#61], functions=[count(1)], output=[a#61, count#71L])
// +- Exchange hashpartitioning(a#61, 5)
// +- HashAggregate(keys=[a#61], functions=[partial_count(1)], output=[a#61, count#76L])
// +- Exchange RoundRobinPartitioning(1)
// +- LocalTableScan [a#61]
//
// Assume the execution plan with node id is:
// Wholestage disabled:
// HashAggregate(nodeId = 0)
// Exchange(nodeId = 1)
// HashAggregate(nodeId = 2)
// Exchange (nodeId = 3)
// LocalTableScan(nodeId = 4)
//
// Wholestage enabled:
// WholeStageCodegen(nodeId = 0)
// HashAggregate(nodeId = 1)
// Exchange(nodeId = 2)
// WholeStageCodegen(nodeId = 3)
// HashAggregate(nodeId = 4)
// Exchange(nodeId = 5)
// LocalTableScan(nodeId = 6)
Seq(true, false).foreach { enableWholeStage =>
val df = generateRandomBytesDF().repartition(1).groupBy('a).count()
val nodeIds = if (enableWholeStage) {
Set(4L, 1L)
} else {
Set(2L, 0L)
}
val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
nodeIds.foreach { nodeId =>
val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
assert(probe.toDouble > 1.0)
}
}
}
}
......@@ -267,10 +305,120 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
val df = df1.join(broadcast(df2), "key")
testSparkPlanMetrics(df, 2, Map(
1L -> ("BroadcastHashJoin", Map(
"number of output rows" -> 2L)))
"number of output rows" -> 2L,
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))
)
}
test("BroadcastHashJoin metrics: track avg probe") {
// The executed plan looks like:
// Project [a#210, b#211, b#221]
// +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight
// :- Project [_1#207 AS a#210, _2#208 AS b#211]
// : +- Filter isnotnull(_1#207)
// : +- LocalTableScan [_1#207, _2#208]
// +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true]))
// +- Project [_1#217 AS a#220, _2#218 AS b#221]
// +- Filter isnotnull(_1#217)
// +- LocalTableScan [_1#217, _2#218]
//
// Assume the execution plan with node id is
// WholeStageCodegen disabled:
// Project(nodeId = 0)
// BroadcastHashJoin(nodeId = 1)
// ...(ignored)
//
// WholeStageCodegen enabled:
// WholeStageCodegen(nodeId = 0)
// Project(nodeId = 1)
// BroadcastHashJoin(nodeId = 2)
// Project(nodeId = 3)
// Filter(nodeId = 4)
// ...(ignored)
Seq(true, false).foreach { enableWholeStage =>
val df1 = generateRandomBytesDF()
val df2 = generateRandomBytesDF()
val df = df1.join(broadcast(df2), "a")
val nodeIds = if (enableWholeStage) {
Set(2L)
} else {
Set(1L)
}
val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get
nodeIds.foreach { nodeId =>
val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
assert(probe.toDouble > 1.0)
}
}
}
}
test("ShuffledHashJoin metrics") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40",
"spark.sql.shuffle.partitions" -> "2",
"spark.sql.join.preferSortMergeJoin" -> "false") {
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value")
// Assume the execution plan is
// ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0)
val df = df1.join(df2, "key")
val metrics = getSparkPlanMetrics(df, 1, Set(1L))
testSparkPlanMetrics(df, 1, Map(
1L -> ("ShuffledHashJoin", Map(
"number of output rows" -> 2L,
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))
)
}
}
test("ShuffledHashJoin metrics: track avg probe") {
// The executed plan looks like:
// Project [a#308, b#309, b#319]
// +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight
// :- Exchange hashpartitioning(a#308, 2)
// : +- Project [_1#305 AS a#308, _2#306 AS b#309]
// : +- Filter isnotnull(_1#305)
// : +- LocalTableScan [_1#305, _2#306]
// +- Exchange hashpartitioning(a#318, 2)
// +- Project [_1#315 AS a#318, _2#316 AS b#319]
// +- Filter isnotnull(_1#315)
// +- LocalTableScan [_1#315, _2#316]
//
// Assume the execution plan with node id is
// WholeStageCodegen disabled:
// Project(nodeId = 0)
// ShuffledHashJoin(nodeId = 1)
// ...(ignored)
//
// WholeStageCodegen enabled:
// WholeStageCodegen(nodeId = 0)
// Project(nodeId = 1)
// ShuffledHashJoin(nodeId = 2)
// ...(ignored)
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000",
"spark.sql.shuffle.partitions" -> "2",
"spark.sql.join.preferSortMergeJoin" -> "false") {
Seq(true, false).foreach { enableWholeStage =>
val df1 = generateRandomBytesDF(65535 * 5)
val df2 = generateRandomBytesDF(65535)
val df = df1.join(df2, "a")
val nodeIds = if (enableWholeStage) {
Set(2L)
} else {
Set(1L)
}
val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
nodeIds.foreach { nodeId =>
val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
assert(probe.toDouble > 1.0)
}
}
}
}
}
test("BroadcastHashJoin(outer) metrics") {
val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
......