diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index dfdaaaae872e4a0aa7bfee9dc2de0111b524b2b8..b00223a86d4d4112fc077b30e9bbd5fcf32f71e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -565,13 +565,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum longMetric("dataSize") += dataSize - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. - if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) rows } }(SubqueryExec.executionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index ce5013daeb1f90dc39f759ef48868e9c2ab21669..7a2505ca86e3eb80128a4f0755155873fba6e39d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -97,13 +97,7 @@ case class BroadcastExchangeExec( val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` - // directly without setting an execution id. We should be tolerant to it. - if (executionId != null) { - sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates( - executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq)) - } - + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) broadcasted } catch { case oe: OutOfMemoryError => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index dbc27d8b237f3e72d561ad4211ae30e64dfa8470..ef982a4ebd10d14461b77bb0aac97e02d7c12c1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -22,9 +22,15 @@ import java.util.Locale import org.apache.spark.SparkContext import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} +/** + * A metric used in a SQL query plan. This is implemented as an [[AccumulatorV2]]. Updates on + * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates + * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. + */ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will @@ -126,4 +132,18 @@ object SQLMetrics { s"\n$sum ($min, $med, $max)" } } + + /** + * Updates metrics based on the driver side value. This is useful for certain metrics that + * are only updated on the driver, e.g. subquery execution time, or number of files. + */ + def postDriverMetricUpdates( + sc: SparkContext, executionId: String, metrics: Seq[SQLMetric]): Unit = { + // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` + // directly without setting an execution id. We should be tolerant to it. + if (executionId != null) { + sc.listenerBus.post( + SparkListenerDriverAccumUpdates(executionId.toLong, metrics.map(m => m.id -> m.value))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 12d3bc9281f354a9846109ea9438f8918c48dea0..b4a91230a0012cd34657f30766971f0c27947b78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,6 +47,13 @@ case class SparkListenerSQLExecutionStart( case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long) extends SparkListenerEvent +/** + * A message used to update SQL metric value for driver-side updates (which doesn't get reflected + * automatically). + * + * @param executionId The execution id for a query, so we can find the query plan. + * @param accumUpdates Map from accumulator id to the metric value (metrics are always 64-bit ints). + */ @DeveloperApi case class SparkListenerDriverAccumUpdates( executionId: Long, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index cf867309f86662895b206f74da6c4b75d8ed71d9..508da1b01b3fcd6a5ace2271d77b50bece943d89 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -477,9 +477,11 @@ private case class MyPlan(sc: SparkContext, expectedValue: Long) extends LeafExe override def doExecute(): RDD[InternalRow] = { longMetric("dummy") += expectedValue - sc.listenerBus.post(SparkListenerDriverAccumUpdates( - sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY).toLong, - metrics.values.map(m => m.id -> m.value).toSeq)) + + SQLMetrics.postDriverMetricUpdates( + sc, + sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY), + metrics.values.toSeq) sc.emptyRDD } }