diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 7aa08fb63053b1da1684b9fab9d06f9f21b8dc07..c5b2b7d11893c35d75e1f3059c57f6afaf85aa38 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1775,7 +1775,7 @@ class DataFrame private[sql](
   private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
     try {
       df.queryExecution.executedPlan.foreach { plan =>
-        plan.metrics.valuesIterator.foreach(_.reset())
+        plan.resetMetrics()
       }
       val start = System.nanoTime()
       val result = action(df)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 3cc99d3c7b1b2b55a9f595d669048599d283d6b4..c72b8dc70708f96ca5a653340e870c85f93b081e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -77,6 +77,13 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
 
+  /**
+   * Reset all the metrics.
+   */
+  private[sql] def resetMetrics(): Unit = {
+    metrics.valuesIterator.foreach(_.reset())
+  }
+
   /**
    * Return a LongSQLMetric according to the name.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 30f74fc14f6c61e480408c8b1df319744d39320d..f35efb5b24b1f2d8815a70032a50f64c2a962eb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.TungstenAggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, LongSQLMetricValue, SQLMetric}
 
 /**
  * An interface for those physical operators that support codegen.
@@ -42,6 +42,19 @@ trait CodegenSupport extends SparkPlan {
     case _ => nodeName.toLowerCase
   }
 
+  /**
+   * Creates a metric using the specified name.
+   *
+   * @return name of the variable representing the metric
+   */
+  def metricTerm(ctx: CodegenContext, name: String): String = {
+    val metric = ctx.addReferenceObj(name, longMetric(name))
+    val value = ctx.freshName("metricValue")
+    val cls = classOf[LongSQLMetricValue].getName
+    ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
+    value
+  }
+
   /**
    * Whether this SparkPlan support whole stage codegen or not.
   */
@@ -316,6 +329,10 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }
 
+  private[sql] override def resetMetrics(): Unit = {
+    plan.foreach(_.resetMetrics())
+  }
+
   override def generateTreeString(
       depth: Int,
       lastChildren: Seq[Boolean],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a6950f805a113b27c8ff6807d2023965ed19c973..852203f3743dc67991f0712231ab385922ec3a19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -202,6 +202,7 @@ case class TungstenAggregate(
        |   }
      """.stripMargin)
 
+    val numOutput = metricTerm(ctx, "numOutputRows")
     s"""
        | if (!$initAgg) {
        |   $initAgg = true;
@@ -210,6 +211,7 @@ case class TungstenAggregate(
        | // output the result
        | ${genResult.trim}
        |
+       | $numOutput.add(1);
        | ${consume(ctx, resultVars).trim}
        | }
      """.stripMargin
@@ -297,6 +299,7 @@ case class TungstenAggregate(
     val peakMemory = Math.max(mapMemory, sorterMemory)
     val metrics = TaskContext.get().taskMetrics()
     metrics.incPeakExecutionMemory(peakMemory)
+    // TODO: update data size and spill size
 
     if (sorter == null) {
       // not spilled
@@ -456,6 +459,7 @@ case class TungstenAggregate(
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
     val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+    val numOutput = metricTerm(ctx, "numOutputRows")
 
     s"""
      if (!$initAgg) {
@@ -465,6 +469,7 @@ case class TungstenAggregate(
 
      // output the result
      while ($iterTerm.next()) {
+       $numOutput.add(1);
       UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
       UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
       $outputCode
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 949acb9aca7624d8d2143b6da9b10e01c39a8bc9..4b82d5563460b7262d8300679e0f39b1795fd330 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
 import org.apache.spark.sql.types.LongType
 import org.apache.spark.util.random.PoissonSampler
 
@@ -78,6 +78,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
     val expr = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(condition, child.output))
     ctx.currentVars = input
@@ -90,6 +91,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
     s"""
        | ${eval.code}
        | if ($nullCheck ${eval.value}) {
+       |   $numOutput.add(1);
        |   ${consume(ctx, ctx.currentVars)}
        | }
      """.stripMargin
@@ -159,6 +161,8 @@ case class Range(
   }
 
   protected override def doProduce(ctx: CodegenContext): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
     val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
     val partitionEnd = ctx.freshName("partitionEnd")
@@ -204,6 +208,8 @@ case class Range(
        |   } else {
        |     $partitionEnd = end.longValue();
        |   }
+       |
+       |   $numOutput.add(($partitionEnd - $number) / ${step}L);
        | }
       """.stripMargin)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 35c7963b48c4aed373de1d452402b8867888b1f4..985e74011daa7e0e9f2350e910548c68a7f200b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -163,6 +163,7 @@ case class BroadcastHashJoin(
       case BuildRight => input ++ buildColumns
     }
 
+    val numOutput = metricTerm(ctx, "numOutputRows")
     val outputCode = if (condition.isDefined) {
       // filter the output via condition
       ctx.currentVars = resultVars
@@ -170,11 +171,15 @@ case class BroadcastHashJoin(
       s"""
          | ${ev.code}
          | if (!${ev.isNull} && ${ev.value}) {
+         |   $numOutput.add(1);
          |   ${consume(ctx, resultVars)}
          | }
        """.stripMargin
     } else {
-      consume(ctx, resultVars)
+      s"""
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
+       """.stripMargin
     }
 
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index dc6c647a4a95f41fb125f207a16233892f7a71f3..1c7e69f30fb48e889ab55bff93db7434fa717070 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -63,7 +63,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     rang/filter/sum:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
     rang/filter/sum codegen=false          14332 / 16646         36.0          27.8       1.0X
-    rang/filter/sum codegen=true             845 /   940        620.0           1.6      17.0X
+    rang/filter/sum codegen=true             897 /  1022        584.6           1.7      16.4X
     */
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index d24625a5351ee37310f414f2e446b218d5dff7d4..f4bc9e501c21c77e471f049c86de2b30acdd828b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -298,24 +298,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
 
   test("save metrics") {
     withTempPath { file =>
-      withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
-        val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-        // Assume the execution plan is
-        // PhysicalRDD(nodeId = 0)
-        person.select('name).write.format("json").save(file.getAbsolutePath)
-        sparkContext.listenerBus.waitUntilEmpty(10000)
-        val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
-        assert(executionIds.size === 1)
-        val executionId = executionIds.head
-        val jobs = sqlContext.listener.getExecution(executionId).get.jobs
-        // Use "<=" because there is a race condition that we may miss some jobs
-        // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
-        assert(jobs.size <= 1)
-        val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
-        // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
-        // However, we still can check the value.
-        assert(metricValues.values.toSeq === Seq("2"))
-      }
+      val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+      // Assume the execution plan is
+      // PhysicalRDD(nodeId = 0)
+      person.select('name).write.format("json").save(file.getAbsolutePath)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+      assert(executionIds.size === 1)
+      val executionId = executionIds.head
+      val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+      // Use "<=" because there is a race condition that we may miss some jobs
+      // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
+      assert(jobs.size <= 1)
+      val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+      // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
+      // However, we still can check the value.
+      assert(metricValues.values.toSeq === Seq("2"))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index d3191d3aead95be3f35a077bcd7ea4d78eadba02..15a95623d1e5cc442d4adfeaa3c6f8c189fce112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark._
 import org.apache.spark.sql.{functions, QueryTest}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegen}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -92,17 +92,19 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
 
       override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
-        metrics += qe.executedPlan.longMetric("numOutputRows").value.value
+        val metric = qe.executedPlan match {
+          case w: WholeStageCodegen => w.plan.longMetric("numOutputRows")
+          case other => other.longMetric("numOutputRows")
+        }
+        metrics += metric.value.value
       }
     }
     sqlContext.listenerManager.register(listener)
 
-    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
-      val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
-      df.collect()
-      df.collect()
-      Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
-    }
+    val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+    df.collect()
+    df.collect()
+    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
 
     assert(metrics.length == 3)
     assert(metrics(0) === 1)