Commit 07ced434 authored by Wenchen Fan, committed by Yin Huai

[SPARK-11253] [SQL] Reset all accumulators in physical operators before executing an action

With this change, our query execution listener can get the metrics correctly.

The UI still looks good after this change.
<img width="257" alt="screen shot 2015-10-23 at 11 25 14 am" src="https://cloud.githubusercontent.com/assets/3182036/10683834/d516f37e-7978-11e5-8118-343ed40eb824.png">
<img width="494" alt="screen shot 2015-10-23 at 11 25 01 am" src="https://cloud.githubusercontent.com/assets/3182036/10683837/e1fa60da-7978-11e5-8ec8-178b88f27764.png">
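For illustration only, a minimal sketch of how a listener can now read per-action metrics, mirroring the new test added below. The metric name, the data, and the `sqlContext` in scope are assumptions for the sketch, not part of this patch:

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
import sqlContext.implicits._

val listener = new QueryExecutionListener {
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
  override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
    // Because all SQLMetrics are reset before the action runs, this value covers only
    // the action that just finished, not everything accumulated since the plan was built.
    println(s"$funcName: numInputRows = " + qe.executedPlan.longMetric("numInputRows").value.value)
  }
}
sqlContext.listenerManager.register(listener)
Seq(1 -> "a").toDF("i", "j").groupBy("i").count().collect()  // triggers onSuccess("collect", ...)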

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9215 from cloud-fan/metric.
parent 87f82a5f
@@ -1974,6 +1974,9 @@ class DataFrame private[sql](
   */
  private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
    try {
      df.queryExecution.executedPlan.foreach { plan =>
        plan.metrics.valuesIterator.foreach(_.reset())
      }
      val start = System.nanoTime()
      val result = action(df)
      val end = System.nanoTime()
......
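This reset is what keeps repeated actions from reporting cumulative values to the listener. A short sketch of the resulting behavior, taken from the `numInputRows` test added below (SQLContext implicits assumed in scope):

val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
df.collect()  // the listener sees numInputRows == 1
df.collect()  // still 1: the accumulators were reset, so values do not carry over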
@@ -28,7 +28,12 @@ import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
 */
private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
    name: String, val param: SQLMetricParam[R, T])
  extends Accumulable[R, T](param.zero, param, Some(name), true)
  extends Accumulable[R, T](param.zero, param, Some(name), true) {

  def reset(): Unit = {
    this.value = param.zero
  }
}

/**
 * Create a layer for specialized metric. We cannot add `@specialized` to
......
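A hedged sketch of what `reset()` does from the caller's side; the metric name is illustrative, and `qe` stands for a `QueryExecution` as in the tests below:

val metric = qe.executedPlan.longMetric("numOutputRows")
metric.value.value  // driver-side value merged from the tasks of the last action
metric.reset()      // back to param.zero, so the next action starts from a clean slate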
@@ -17,14 +17,14 @@
package org.apache.spark.sql.util

import org.apache.spark.SparkException
import scala.collection.mutable.ArrayBuffer

import org.apache.spark._
import org.apache.spark.sql.{functions, QueryTest}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

import scala.collection.mutable.ArrayBuffer

class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
  import testImplicits._
  import functions._
@@ -54,6 +54,8 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
    assert(metrics(1)._1 == "count")
    assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
    assert(metrics(1)._3 > 0)

    sqlContext.listenerManager.unregister(listener)
  }

  test("execute callback functions when a DataFrame action failed") {
@@ -79,5 +81,78 @@
    assert(metrics(0)._1 == "collect")
    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
    assert(metrics(0)._3.getMessage == e.getMessage)

    sqlContext.listenerManager.unregister(listener)
  }

  test("get numRows metrics by callback") {
    val metrics = ArrayBuffer.empty[Long]
    val listener = new QueryExecutionListener {
      // Only test successful case here, so no need to implement `onFailure`
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}

      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        metrics += qe.executedPlan.longMetric("numInputRows").value.value
      }
    }
    sqlContext.listenerManager.register(listener)

    val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
    df.collect()
    df.collect()
    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()

    assert(metrics.length == 3)
    assert(metrics(0) == 1)
    assert(metrics(1) == 1)
    assert(metrics(2) == 2)

    sqlContext.listenerManager.unregister(listener)
  }
  // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never
  // updated, we can filter it out later. However, when we aggregate(sum) accumulator values at
  // driver side for SQL physical operators, these -1 values will make our result smaller.
  // An easy fix is to create a new SQLMetric(including new MetricValue, MetricParam, etc.), but we
  // can do it later because the impact is just too small (1048576 tasks for 1 MB).
ignore("get size metrics by callback") {
val metrics = ArrayBuffer.empty[Long]
val listener = new QueryExecutionListener {
// Only test successful case here, so no need to implement `onFailure`
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
metrics += qe.executedPlan.longMetric("dataSize").value.value
val bottomAgg = qe.executedPlan.children(0).children(0)
metrics += bottomAgg.longMetric("dataSize").value.value
}
}
sqlContext.listenerManager.register(listener)
val sparkListener = new SaveInfoListener
sqlContext.sparkContext.addSparkListener(sparkListener)
val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j")
df.groupBy("i").count().collect()
def getPeakExecutionMemory(stageId: Int): Long = {
val peakMemoryAccumulator = sparkListener.getCompletedStageInfos(stageId).accumulables
.filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
assert(peakMemoryAccumulator.size == 1)
peakMemoryAccumulator.head._2.value.toLong
}
assert(sparkListener.getCompletedStageInfos.length == 2)
val bottomAggDataSize = getPeakExecutionMemory(0)
val topAggDataSize = getPeakExecutionMemory(1)
// For this simple case, the peakExecutionMemory of a stage should be the data size of the
// aggregate operator, as we only have one memory consuming operator per stage.
assert(metrics.length == 2)
assert(metrics(0) == topAggDataSize)
assert(metrics(1) == bottomAggDataSize)
sqlContext.listenerManager.unregister(listener)
}
}
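The TODO above (about -1 initial values) can be made concrete with a small worked example; the numbers are purely illustrative:

// Some LongSQLMetrics start at -1 so that never-updated copies can be filtered out later.
// A plain driver-side sum therefore loses 1 per task that never touched the metric:
val perTaskValues = Seq(500L, 300L, -1L, -1L, -1L)  // 2 tasks did work, 3 never updated the metric
val naiveSum = perTaskValues.sum                    // 797: under-counts by 3
val intended = perTaskValues.filter(_ >= 0).sum     // 800: what the metric should report

Scaled up, 1048576 idle tasks would make a byte-valued metric like dataSize read 1 MB low, which is why the comment treats the impact as negligible for now.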