diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index cdae8ea458949185dfd3815d550bb8656eac98bd..393925161fc7b7ad52feba2499bf1042cff898d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -25,9 +25,9 @@ import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
 import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.types.StructType
 
@@ -211,13 +211,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     assertNotBucketed("save")
-    val dataSource = DataSource(
-      df.sparkSession,
-      className = source,
-      partitionColumns = partitioningColumns.getOrElse(Nil),
-      options = extraOptions.toMap)
 
-    dataSource.write(mode, df)
+    runCommand(df.sparkSession, "save") {
+      SaveIntoDataSourceCommand(
+        query = df.logicalPlan,
+        provider = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap,
+        mode = mode)
+    }
   }
 
   /**
@@ -260,13 +262,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       )
     }
 
-    df.sparkSession.sessionState.executePlan(
+    runCommand(df.sparkSession, "insertInto") {
       InsertIntoTable(
         table = UnresolvedRelation(tableIdent),
         partition = Map.empty[String, Option[String]],
         query = df.logicalPlan,
         overwrite = mode == SaveMode.Overwrite,
-        ifNotExists = false)).toRdd
+        ifNotExists = false)
+    }
   }
 
   private def getBucketSpec: Option[BucketSpec] = {
@@ -389,10 +392,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       schema = new StructType,
       provider = Some(source),
       partitionColumnNames = partitioningColumns.getOrElse(Nil),
-      bucketSpec = getBucketSpec
-    )
-    df.sparkSession.sessionState.executePlan(
-      CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
+      bucketSpec = getBucketSpec)
+
+    runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
   }
 
   /**
@@ -573,6 +575,25 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     format("csv").save(path)
   }
 
+  /**
+   * Wrap a DataFrameWriter action to track the QueryExecution and time cost, then report to the
+   * user-registered callback functions.
+   */
+  private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = {
+    val qe = session.sessionState.executePlan(command)
+    try {
+      val start = System.nanoTime()
+      // call `QueryExecution.toRdd` to trigger the execution of commands.
+      qe.toRdd
+      val end = System.nanoTime()
+      session.listenerManager.onSuccess(name, qe, end - start)
+    } catch {
+      case e: Exception =>
+        session.listenerManager.onFailure(name, qe, e)
+        throw e
+    }
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6f19ea195c0cd736b81817e25020c1968046e47b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.command.RunnableCommand
+
+/**
+ * Saves the results of `query` into a data source.
+ *
+ * Note that this command is different from [[InsertIntoDataSourceCommand]]. This command will call
+ * `CreatableRelationProvider.createRelation` to write out the data, while
+ * [[InsertIntoDataSourceCommand]] calls `InsertableRelation.insert`. Ideally these two data source
+ * interfaces should do the same thing, but since both are already published and their
+ * implementations may have different logic, we have to keep these two commands separate.
+ */
+case class SaveIntoDataSourceCommand(
+    query: LogicalPlan,
+    provider: String,
+    partitionColumns: Seq[String],
+    options: Map[String, String],
+    mode: SaveMode) extends RunnableCommand {
+
+  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+
+  override def run(sparkSession: SparkSession): Seq[Row] = {
+    DataSource(
+      sparkSession,
+      className = provider,
+      partitionColumns = partitionColumns,
+      options = options).write(mode, Dataset.ofRows(sparkSession, query))
+
+    Seq.empty[Row]
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 3ae5ce610d2a6ac0cc137361739afddf6bb80b28..9f27d06dcb366aea5f53f71a29127e53959b66b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -159,4 +161,55 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
 
     spark.listenerManager.unregister(listener)
   }
+
+  test("execute callback functions for DataFrameWriter") {
+    val commands = ArrayBuffer.empty[(String, LogicalPlan)]
+    val exceptions = ArrayBuffer.empty[(String, Exception)]
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        exceptions += funcName -> exception
+      }
+
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+        commands += funcName -> qe.logical
+      }
+    }
+    spark.listenerManager.register(listener)
+
+    withTempPath { path =>
+      spark.range(10).write.format("json").save(path.getCanonicalPath)
+      assert(commands.length == 1)
+      assert(commands.head._1 == "save")
+      assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
+      assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
+    }
+
+    withTable("tab") {
+      sql("CREATE TABLE tab(i long) using parquet")
+      spark.range(10).write.insertInto("tab")
+      assert(commands.length == 2)
+      assert(commands(1)._1 == "insertInto")
+      assert(commands(1)._2.isInstanceOf[InsertIntoTable])
+      assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
+        .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
+    }
+
+    withTable("tab") {
+      spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+      assert(commands.length == 3)
+      assert(commands(2)._1 == "saveAsTable")
+      assert(commands(2)._2.isInstanceOf[CreateTable])
+      assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
+    }
+
+    withTable("tab") {
+      sql("CREATE TABLE tab(i long) using parquet")
+      val e = intercept[AnalysisException] {
+        spark.range(10).select($"id", $"id").write.insertInto("tab")
+      }
+      assert(exceptions.length == 1)
+      assert(exceptions.head._1 == "insertInto")
+      assert(exceptions.head._2 == e)
+    }
+  }
 }
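
For reviewers who want to try the change end to end, here is a minimal sketch of the callback path this patch wires up: a QueryExecutionListener registered on the session now fires for "save", "insertInto" and "saveAsTable", because runCommand reports the QueryExecution and the elapsed time to listenerManager. QueryExecutionListener and listenerManager are existing Spark APIs, not added here; the object name, master URL and output path below are illustrative assumptions only.

// Sketch (not part of the patch): observing the DataFrameWriter callbacks enabled by runCommand.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object WriterCallbackDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("writer-callback-demo").getOrCreate()

    spark.listenerManager.register(new QueryExecutionListener {
      // With this patch, runCommand calls onSuccess with funcName "save", "insertInto" or
      // "saveAsTable" and the wall-clock duration of the command in nanoseconds.
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"$funcName took ${durationNs / 1e6} ms, plan: ${qe.logical.getClass.getSimpleName}")

      // runCommand calls onFailure and then rethrows the original exception.
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"$funcName failed: ${exception.getMessage}")
    })

    // Reports funcName = "save" with a SaveIntoDataSourceCommand as the logical plan.
    spark.range(10).write.mode("overwrite").format("json").save("/tmp/writer-callback-demo")

    spark.stop()
  }
}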
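The SaveIntoDataSourceCommand scaladoc contrasts the two published write interfaces; the sketch below shows where each one is invoked. CreatableRelationProvider, InsertableRelation and BaseRelation are Spark's real source APIs, while the package and class names (com.example.noop, DefaultSource, NoopRelation) are hypothetical stand-ins.

// Sketch (not part of the patch): the two write paths behind the two commands.
package com.example.noop

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, InsertableRelation}
import org.apache.spark.sql.types.StructType

// SaveIntoDataSourceCommand (i.e. df.write.format("com.example.noop").save()) calls createRelation.
class DefaultSource extends CreatableRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    data.foreach(_ => ())  // a real provider would honour `mode` and write `data` out here
    new NoopRelation(sqlContext, data.schema)
  }
}

// InsertIntoDataSourceCommand (INSERT INTO a table backed by this relation) calls insert() instead.
class NoopRelation(val sqlContext: SQLContext, val schema: StructType)
  extends BaseRelation with InsertableRelation {
  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.foreach(_ => ())  // a real relation would append or overwrite its storage here
  }
}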