From 15ff85b3163acbe8052d4489a00bcf1d2332fcf0 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@163.com>
Date: Tue, 13 Oct 2015 17:59:32 -0700
Subject: [PATCH] [SPARK-11068] [SQL] add callback to query execution

With this feature, we can track the query plan, time cost, and exceptions during query execution for Spark users.

Author: Wenchen Fan <cloud0fan@163.com>

Closes #9078 from cloud-fan/callback.
---
 .../org/apache/spark/sql/DataFrame.scala      |  46 +++++-
 .../spark/sql/QueryExecutionListener.scala    | 136 ++++++++++++++++++
 .../org/apache/spark/sql/SQLContext.scala     |   3 +
 .../spark/sql/DataFrameCallbackSuite.scala    |  82 +++++++++++
 4 files changed, 261 insertions(+), 6 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 01f60aba87..bfe8d3c8ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1344,7 +1344,9 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def head(n: Int): Array[Row] = limit(n).collect()
+  def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df =>
+    df.collect(needCallback = false)
+  }
 
   /**
    * Returns the first row.
@@ -1414,8 +1416,18 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def collect(): Array[Row] = withNewExecutionId {
-    queryExecution.executedPlan.executeCollectPublic()
+  def collect(): Array[Row] = collect(needCallback = true)
+
+  private def collect(needCallback: Boolean): Array[Row] = {
+    def execute(): Array[Row] = withNewExecutionId {
+      queryExecution.executedPlan.executeCollectPublic()
+    }
+
+    if (needCallback) {
+      withCallback("collect", this)(_ => execute())
+    } else {
+      execute()
+    }
   }
 
   /**
@@ -1423,8 +1435,10 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def collectAsList(): java.util.List[Row] = withNewExecutionId {
-    java.util.Arrays.asList(rdd.collect() : _*)
+  def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+    withNewExecutionId {
+      java.util.Arrays.asList(rdd.collect() : _*)
+    }
   }
 
   /**
@@ -1432,7 +1446,9 @@ class DataFrame private[sql](
    * @group action
    * @since 1.3.0
    */
-  def count(): Long = groupBy().count().collect().head.getLong(0)
+  def count(): Long = withCallback("count", groupBy().count()) { df =>
+    df.collect(needCallback = false).head.getLong(0)
+  }
 
   /**
    * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
@@ -1936,6 +1952,24 @@ class DataFrame private[sql](
     SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
   }
 
+  /**
+   * Wrap a DataFrame action to track the QueryExecution and time cost, then report them to the
+   * user-registered callback functions.
+   */
+  private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T) = {
+    try {
+      val start = System.nanoTime()
+      val result = action(df)
+      val end = System.nanoTime()
+      sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start)
+      result
+    } catch {
+      case e: Exception =>
+        sqlContext.listenerManager.onFailure(name, df.queryExecution, e)
+        throw e
+    }
+  }
+
   ////////////////////////////////////////////////////////////////////////////
   ////////////////////////////////////////////////////////////////////////////
   // End of deprecated methods
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala
new file mode 100644
index 0000000000..14fbebb45f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.locks.ReentrantReadWriteLock
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.Logging
+import org.apache.spark.sql.execution.QueryExecution
+
+
+/**
+ * The interface for query execution listeners that can be used to analyze execution metrics.
+ *
+ * Note that implementations should guarantee thread safety, as they will be invoked in a
+ * non-thread-safe way.
+ */
+@Experimental
+trait QueryExecutionListener {
+
+  /**
+   * A callback function that will be called when a query is executed successfully.
+   * Implementations should guarantee thread safety.
+   *
+   * @param funcName the name of the action that triggered this query.
+   * @param qe the QueryExecution object that carries detailed information like the logical plan,
+   *           physical plan, etc.
+   * @param duration the execution time for this query in nanoseconds.
+   */
+  @DeveloperApi
+  def onSuccess(funcName: String, qe: QueryExecution, duration: Long)
+
+  /**
+   * A callback function that will be called when a query execution fails.
+   * Implementations should guarantee thread safety.
+   *
+   * @param funcName the name of the action that triggered this query.
+   * @param qe the QueryExecution object that carries detailed information like the logical plan,
+   *           physical plan, etc.
+   * @param exception the exception that caused this query to fail.
+   */
+  @DeveloperApi
+  def onFailure(funcName: String, qe: QueryExecution, exception: Exception)
+}
+
+@Experimental
+class ExecutionListenerManager extends Logging {
+  private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
+  private[this] val lock = new ReentrantReadWriteLock()
+
+  /** Acquires a read lock on the listener list for the duration of `f`. */
+  private def readLock[A](f: => A): A = {
+    val rl = lock.readLock()
+    rl.lock()
+    try f finally {
+      rl.unlock()
+    }
+  }
+
+  /** Acquires a write lock on the listener list for the duration of `f`. */
+  private def writeLock[A](f: => A): A = {
+    val wl = lock.writeLock()
+    wl.lock()
+    try f finally {
+      wl.unlock()
+    }
+  }
+
+  /**
+   * Registers the specified QueryExecutionListener.
+   */
+  @DeveloperApi
+  def register(listener: QueryExecutionListener): Unit = writeLock {
+    listeners += listener
+  }
+
+  /**
+   * Unregisters the specified QueryExecutionListener.
+   */
+  @DeveloperApi
+  def unregister(listener: QueryExecutionListener): Unit = writeLock {
+    listeners -= listener
+  }
+
+  /**
+   * Clears out all registered QueryExecutionListeners.
+   */
+  @DeveloperApi
+  def clear(): Unit = writeLock {
+    listeners.clear()
+  }
+
+  private[sql] def onSuccess(
+      funcName: String,
+      qe: QueryExecution,
+      duration: Long): Unit = readLock {
+    withErrorHandling { listener =>
+      listener.onSuccess(funcName, qe, duration)
+    }
+  }
+
+  private[sql] def onFailure(
+      funcName: String,
+      qe: QueryExecution,
+      exception: Exception): Unit = readLock {
+    withErrorHandling { listener =>
+      listener.onFailure(funcName, qe, exception)
+    }
+  }
+
+  private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
+    for (listener <- listeners) {
+      try {
+        f(listener)
+      } catch {
+        case e: Exception => logWarning("error executing query execution listener", e)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cd937257d3..a835408f8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -177,6 +177,9 @@ class SQLContext private[sql](
    */
   def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
 
+  @transient
+  lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+
   @transient
   protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala
new file mode 100644
index 0000000000..4e286a0076
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.test.SharedSQLContext
+
+import scala.collection.mutable.ArrayBuffer
+
+class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
+  import functions._
+
+  test("execute callback functions when a DataFrame action finished successfully") {
+    val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
+    val listener = new QueryExecutionListener {
+      // Only test the successful case here, so no need to implement `onFailure`
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        metrics += ((funcName, qe, duration))
+      }
+    }
+    sqlContext.listenerManager.register(listener)
+
+    val df = Seq(1 -> "a").toDF("i", "j")
+    df.select("i").collect()
+    df.filter($"i" > 0).count()
+
+    assert(metrics.length == 2)
+
+    assert(metrics(0)._1 == "collect")
+    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
+    assert(metrics(0)._3 > 0)
+
+    assert(metrics(1)._1 == "count")
+    assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
+    assert(metrics(1)._3 > 0)
+  }
+
+  test("execute callback functions when a DataFrame action failed") {
+    val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)]
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        metrics += ((funcName, qe, exception))
+      }
+
+      // Only test the failed case here, so no need to implement `onSuccess`
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
+    }
+    sqlContext.listenerManager.register(listener)
+
+    val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") }
+    val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j")
+
+    // Ignore the log when we are expecting an exception.
+    sparkContext.setLogLevel("FATAL")
+    val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
+
+    assert(metrics.length == 1)
+    assert(metrics(0)._1 == "collect")
+    assert(metrics(0)._2.analyzed.isInstanceOf[Project])
+    assert(metrics(0)._3.getMessage == e.getMessage)
+  }
+}
--
GitLab
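The snippet below is a minimal usage sketch of the API added by this patch, assuming a Spark build that includes it. The object name ListenerExample, the app name, and the local master URL are illustrative placeholders; QueryExecutionListener, SQLContext.listenerManager, and the nanosecond duration argument come from the code above.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{QueryExecutionListener, SQLContext}
import org.apache.spark.sql.execution.QueryExecution

// Illustrative driver object; the app name and local master URL are placeholders.
object ListenerExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("listener-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Report the action name and wall-clock cost of every query.
    // `duration` is in nanoseconds, per the QueryExecutionListener scaladoc above.
    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        println(s"$funcName finished in ${duration / 1e6} ms, analyzed plan:\n${qe.analyzed}")
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
        println(s"$funcName failed: ${exception.getMessage}")
      }
    }
    sqlContext.listenerManager.register(listener)

    // Any action wrapped by withCallback (collect, collectAsList, count, head) now reports back.
    sqlContext.range(0, 100).filter("id > 10").count()

    sqlContext.listenerManager.unregister(listener)
    sc.stop()
  }
}

Exceptions thrown by a listener are caught and logged by ExecutionListenerManager.withErrorHandling, so a faulty listener cannot fail the query it observes.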