From 10e526e7e63bbf19e464bde2f6c4e581cf6c7c45 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Tue, 30 May 2017 20:12:32 -0700
Subject: [PATCH] [SPARK-20213][SQL] Fix DataFrameWriter operations in SQL UI tab

## What changes were proposed in this pull request?

Currently the `DataFrameWriter` operations have several problems:

1. A non-file-format data source write action doesn't show up in the SQL tab of the Spark UI.
2. A file-format data source write action shows a scan node in the SQL tab, without saying anything about writing. (Streaming also has this issue, but it is not fixed in this PR.)
3. Spark SQL CLI actions don't show up in the SQL tab.

This PR fixes all of them by refactoring `ExecutedCommandExec` to make it have children.

Closes https://github.com/apache/spark/pull/17540

## How was this patch tested?

Existing tests. The UI was also checked manually.

For a simple command: `Seq(1 -> "a").toDF("i", "j").write.parquet("/tmp/qwe")`

Before this PR:

<img width="266" alt="qq20170523-035840 2x" src="https://cloud.githubusercontent.com/assets/3182036/26326050/24e18ba2-3f6c-11e7-8817-6dd275bf6ac5.png">

After this PR:

<img width="287" alt="qq20170523-035708 2x" src="https://cloud.githubusercontent.com/assets/3182036/26326054/2ad7f460-3f6c-11e7-8053-d68325beb28f.png">

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18064 from cloud-fan/execution.
---
 .../spark/sql/kafka010/KafkaWriter.scala | 10 +-
 .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +-
 .../sql/catalyst/plans/logical/Command.scala | 3 +-
 .../catalyst/plans/logical/Statistics.scala | 15 ++-
 .../apache/spark/sql/DataFrameWriter.scala | 12 +--
 .../scala/org/apache/spark/sql/Dataset.scala | 48 +++++++--
 .../spark/sql/execution/QueryExecution.scala | 7 +-
 .../spark/sql/execution/SQLExecution.scala | 13 +++
 .../spark/sql/execution/SparkStrategies.scala | 2 +-
 .../execution/columnar/InMemoryRelation.scala | 2 +-
 .../columnar/InMemoryTableScanExec.scala | 2 +-
 .../command/AnalyzeColumnCommand.scala | 7 +-
 .../command/AnalyzeTableCommand.scala | 2 +-
 .../spark/sql/execution/command/cache.scala | 10 +-
 .../sql/execution/command/commands.scala | 24 +++--
 .../command/createDataSourceTables.scala | 4 +-
 .../spark/sql/execution/command/views.scala | 4 +-
 .../execution/datasources/DataSource.scala | 61 +++++-------
 .../datasources/FileFormatWriter.scala | 98 +++++++++----------
 .../InsertIntoDataSourceCommand.scala | 2 +-
 .../InsertIntoHadoopFsRelationCommand.scala | 10 +-
 .../SaveIntoDataSourceCommand.scala | 13 +--
 .../datasources/csv/CSVDataSource.scala | 3 +-
 .../spark/sql/execution/datasources/ddl.scala | 2 +-
 .../datasources/jdbc/JDBCRelation.scala | 14 +--
 .../datasources/jdbc/JdbcUtils.scala | 2 +-
 .../execution/streaming/FileStreamSink.scala | 2 +-
 .../execution/streaming/StreamExecution.scala | 6 +-
 .../sql/execution/streaming/console.scala | 4 +-
 .../sql/execution/streaming/memory.scala | 4 +-
 .../execution/metric/SQLMetricsSuite.scala | 5 +-
 .../sql/util/DataFrameCallbackSuite.scala | 27 ++---
 .../hive/thriftserver/SparkSQLDriver.scala | 6 +-
 .../hive/execution/InsertIntoHiveTable.scala | 11 ++-
 .../apache/spark/sql/hive/test/TestHive.scala | 54 +++++-----
 .../hive/execution/HiveComparisonTest.scala | 6 +-
 .../sql/hive/execution/SQLQuerySuite.scala | 20 ++--
 37 files changed, 299 insertions(+), 218 deletions(-)

diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index 
0ed9d4e84d..5e9ae35b3f 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -85,12 +85,10 @@ private[kafka010] object KafkaWriter extends Logging { topic: Option[String] = None): Unit = { val schema = queryExecution.analyzed.output validateQuery(queryExecution, kafkaParameters, topic) - SQLExecution.withNewExecutionId(sparkSession, queryExecution) { - queryExecution.toRdd.foreachPartition { iter => - val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) - Utils.tryWithSafeFinally(block = writeTask.execute(iter))( - finallyBlock = writeTask.close()) - } + queryExecution.toRdd.foreachPartition { iter => + val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) + Utils.tryWithSafeFinally(block = writeTask.execute(iter))( + finallyBlock = writeTask.close()) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d3f822bf7e..5ba043e17a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -357,7 +357,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }) } - override protected def innerChildren: Seq[QueryPlan[_]] = subqueries + override def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Returns a plan where a best effort attempt has been made to transform `this` in a way diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 38f47081b6..ec5766e1f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. 
*/ -trait Command extends LeafNode { +trait Command extends LogicalPlan { override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index a64562b5db..ae5f1d1fc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -22,7 +22,8 @@ import java.math.{MathContext, RoundingMode} import scala.util.control.NonFatal import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -243,9 +244,9 @@ object ColumnStat extends Logging { } col.dataType match { - case _: IntegralType => fixedLenTypeStruct(LongType) + case dt: IntegralType => fixedLenTypeStruct(dt) case _: DecimalType => fixedLenTypeStruct(col.dataType) - case DoubleType | FloatType => fixedLenTypeStruct(DoubleType) + case dt @ (DoubleType | FloatType) => fixedLenTypeStruct(dt) case BooleanType => fixedLenTypeStruct(col.dataType) case DateType => fixedLenTypeStruct(col.dataType) case TimestampType => fixedLenTypeStruct(col.dataType) @@ -264,14 +265,12 @@ object ColumnStat extends Logging { } /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */ - def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = { + def rowToColumnStat(row: InternalRow, attr: Attribute): ColumnStat = { ColumnStat( distinctCount = BigInt(row.getLong(0)), // for string/binary min/max, get should return null - min = Option(row.get(1)) - .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), - max = Option(row.get(2)) - .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), + min = Option(row.get(1, attr.dataType)), + max = Option(row.get(2, attr.dataType)), nullCount = BigInt(row.getLong(3)), avgLen = row.getLong(4), maxLen = row.getLong(5) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b71c5eb843..255c4064eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand} import org.apache.spark.sql.sources.BaseRelation @@ -231,12 +232,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") runCommand(df.sparkSession, "save") { - SaveIntoDataSourceCommand( - query = df.logicalPlan, - provider = 
source, + DataSource( + sparkSession = df.sparkSession, + className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap, - mode = mode) + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) } } @@ -607,7 +607,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { try { val start = System.nanoTime() // call `QueryExecution.toRDD` to trigger the execution of commands. - qe.toRdd + SQLExecution.withNewExecutionId(session, qe)(qe.toRdd) val end = System.nanoTime() session.listenerManager.onSuccess(name, qe, end - start) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1cd6fda5ed..d5b4c82c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -179,9 +179,9 @@ class Dataset[T] private[sql]( // to happen right away to let these side effects take place eagerly. queryExecution.analyzed match { case c: Command => - LocalRelation(c.output, queryExecution.executedPlan.executeCollect()) + LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) case u @ Union(children) if children.forall(_.isInstanceOf[Command]) => - LocalRelation(u.output, queryExecution.executedPlan.executeCollect()) + LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect())) case _ => queryExecution.analyzed } @@ -248,8 +248,13 @@ class Dataset[T] private[sql]( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { val numRows = _numRows.max(0) val takeResult = toDF().take(numRows + 1) - val hasMoreData = takeResult.length > numRows - val data = takeResult.take(numRows) + showString(takeResult, numRows, truncate, vertical) + } + + private def showString( + dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = { + val hasMoreData = dataWithOneMoreRow.length > numRows + val data = dataWithOneMoreRow.take(numRows) lazy val timeZone = DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone) @@ -684,6 +689,18 @@ class Dataset[T] private[sql]( } else { println(showString(numRows, truncate = 0)) } + + // An internal version of `show`, which won't set execution id and trigger listeners. + private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = { + val numRows = _numRows.max(0) + val takeResult = toDF().takeInternal(numRows + 1) + + if (truncate) { + println(showString(takeResult, numRows, truncate = 20, vertical = false)) + } else { + println(showString(takeResult, numRows, truncate = 0, vertical = false)) + } + } // scalastyle:on println /** @@ -2453,6 +2470,11 @@ class Dataset[T] private[sql]( */ def take(n: Int): Array[T] = head(n) + // An internal version of `take`, which won't set execution id and trigger listeners. + private[sql] def takeInternal(n: Int): Array[T] = { + collectFromPlan(limit(n).queryExecution.executedPlan) + } + /** * Returns the first `n` rows in the Dataset as a list. * @@ -2477,6 +2499,11 @@ class Dataset[T] private[sql]( */ def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan) + // An internal version of `collect`, which won't set execution id and trigger listeners. + private[sql] def collectInternal(): Array[T] = { + collectFromPlan(queryExecution.executedPlan) + } + /** * Returns a Java list that contains all rows in this Dataset. 
* @@ -2518,6 +2545,11 @@ class Dataset[T] private[sql]( plan.executeCollect().head.getLong(0) } + // An internal version of `count`, which won't set execution id and trigger listeners. + private[sql] def countInternal(): Long = { + groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0) + } + /** * Returns a new Dataset that has exactly `numPartitions` partitions. * @@ -2763,7 +2795,7 @@ class Dataset[T] private[sql]( createTempViewCommand(viewName, replace = true, global = true) } - private def createTempViewCommand( + private[spark] def createTempViewCommand( viewName: String, replace: Boolean, global: Boolean): CreateViewCommand = { @@ -2954,17 +2986,17 @@ class Dataset[T] private[sql]( } /** A convenient function to wrap a logical plan and produce a DataFrame. */ - @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = { + @inline private def withPlan(logicalPlan: LogicalPlan): DataFrame = { Dataset.ofRows(sparkSession, logicalPlan) } /** A convenient function to wrap a logical plan and produce a Dataset. */ - @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + @inline private def withTypedPlan[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = { Dataset(sparkSession, logicalPlan) } /** A convenient function to wrap a set based logical plan and produce a Dataset. */ - @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = { if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) { // Set operators widen types (change the schema), so we cannot reuse the row encoder. Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 2e05e5d659..1ba9a79446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -113,10 +113,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** - * Returns the result as a hive compatible sequence of strings. This is for testing only. + * Returns the result as a hive compatible sequence of strings. This is used in tests and + * `SparkSQLDriver` for CLI applications. */ def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand) => + case ExecutedCommandExec(desc: DescribeTableCommand, _) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. desc.run(sparkSession).map { @@ -127,7 +128,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. 
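// A minimal sketch (not part of this patch) of why the *Internal variants added to Dataset
// above exist: code that already runs inside a SQLExecution.withNewExecutionId block -- for
// example a streaming sink invoked by StreamExecution, as the console and memory sinks later
// in this patch do -- must not start a second execution or fire listener callbacks, so it
// collects through collectInternal() instead of collect(). The sink class below is made up,
// and it sits in a sub-package of org.apache.spark.sql only so the private[sql] helper is
// visible.
package org.apache.spark.sql.sketch

import org.apache.spark.sql.DataFrame

class LoggingSinkSketch {
  def addBatch(batchId: Long, data: DataFrame): Unit = {
    // reuses the execution id already set by the caller; no nested SQL UI entry is created
    val rows = data.collectInternal()
    rows.foreach(row => println(s"batch $batchId: $row"))
  }
}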
- case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => + case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index be35916e34..bb206e8432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -39,6 +39,19 @@ object SQLExecution { executionIdToQueryExecution.get(executionId) } + private val testing = sys.props.contains("spark.testing") + + private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { + // only throw an exception during tests. a missing execution ID should not fail a job. + if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) { + // Attention testers: when a test fails with this exception, it means that the action that + // started execution of a query didn't call withNewExecutionId. The execution ID should be + // set by calling withNewExecutionId in the action that begins execution, like + // Dataset.collect or DataFrameWriter.insertInto. + throw new IllegalStateException("Execution ID should be set") + } + } + /** * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that * we can connect them with an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 843ce63161..f13294c925 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,7 +346,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? 
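// A minimal sketch (not part of this patch) of the contract that checkSQLExecutionId above
// enforces. The action that starts a query (DataFrameWriter.runCommand in this patch) wraps
// the work in SQLExecution.withNewExecutionId, which sets EXECUTION_ID_KEY as a local
// property; code that merely submits jobs on behalf of that action (FileFormatWriter.write
// later in this patch) only asserts that the id is present instead of opening a nested
// execution. The object and method names here are made up, and the sketch sits in a
// sub-package of org.apache.spark.sql only so the private[sql] helper is visible.
package org.apache.spark.sql.execution.sketch

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}

object ExecutionIdContractSketch {
  // The "driver" side of an action: allocate one execution id for the whole write.
  def runAction(session: SparkSession, qe: QueryExecution): Unit = {
    SQLExecution.withNewExecutionId(session, qe) {
      submitJobs(session, qe)
    }
  }

  // The "callee" side: under spark.testing this throws if the id is missing, which is how a
  // forgotten withNewExecutionId call gets caught by the test suite.
  private def submitJobs(session: SparkSession, qe: QueryExecution): Unit = {
    SQLExecution.checkSQLExecutionId(session)
    qe.toRdd.foreachPartition(_ => ())
  }
}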
object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommandExec(r) :: Nil + case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 3486a6bce8..456a8f3b20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -64,7 +64,7 @@ case class InMemoryRelation( val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override protected def innerChildren: Seq[SparkPlan] = Seq(child) + override def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 7063b08f7c..1d601374de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -34,7 +34,7 @@ case class InMemoryTableScanExec( @transient relation: InMemoryRelation) extends LeafExecNode { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 0d8db2ff5d..2de14c90ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableTyp import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution /** @@ -96,11 +97,13 @@ case class AnalyzeColumnCommand( attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr)) val namedExpressions = expressions.map(e => Alias(e, e.toString)()) - val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() + val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation)) + .executedPlan.executeTake(1).head val rowCount = statsRow.getLong(0) val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => - (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr)) + // according to `ColumnStat.statExprs`, the stats struct always have 6 fields. 
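// A hedged sketch (not part of this patch) of what the reworked RunnableCommand contract
// enables. A command can now expose the query it writes as a real child; the strategy above
// plans that child (ExecutedCommandExec(r, r.children.map(planLater))) and hands the physical
// plan to run(sparkSession, children), so the write and its input query show up together in
// the SQL UI. The command name and body below are hypothetical; the pattern mirrors
// InsertIntoHadoopFsRelationCommand and InsertIntoHiveTable later in this patch.
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.RunnableCommand

case class WriteOutCommandSketch(query: LogicalPlan) extends RunnableCommand {
  // Declaring the query as a child (rather than an innerChild) is what makes it visible to
  // the planner and, ultimately, to the SQL tab.
  override def children: Seq[LogicalPlan] = query :: Nil

  override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
    assert(children.length == 1)
    val physicalQuery = children.head // already planned by the query planner
    physicalQuery.execute().foreachPartition(_ => ()) // stand-in for the real write
    Seq.empty[Row]
  }
}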
+ (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 6), attr)) }.toMap (rowCount, columnStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index d2ea0cdf61..3183c7911b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -56,7 +56,7 @@ case class AnalyzeTableCommand( // 2. when total size is changed, `oldRowCount` becomes invalid. // This is to make sure that we only record the right statistics. if (!noscan) { - val newRowCount = sparkSession.table(tableIdentWithDB).count() + val newRowCount = sparkSession.table(tableIdentWithDB).countInternal() if (newRowCount >= 0 && newRowCount != oldRowCount) { newStats = if (newStats.isDefined) { newStats.map(_.copy(rowCount = Some(BigInt(newRowCount)))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 336f14dd97..184d0387eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -30,19 +30,19 @@ case class CacheTableCommand( require(plan.isEmpty || tableIdent.database.isEmpty, "Database name is not allowed in CACHE TABLE AS SELECT") - override protected def innerChildren: Seq[QueryPlan[_]] = { - plan.toSeq - } + override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => - Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString) + Dataset.ofRows(sparkSession, logicalPlan) + .createTempViewCommand(tableIdent.quotedString, replace = false, global = false) + .run(sparkSession) } sparkSession.catalog.cacheTable(tableIdent.quotedString) if (!isLazy) { // Performs eager caching - sparkSession.table(tableIdent).count() + sparkSession.table(tableIdent).countInternal() } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 41d91d877d..99d81c49f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -22,8 +22,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.debug._ @@ -36,14 +35,20 @@ import org.apache.spark.sql.types._ * wrapped in `ExecutedCommand` during execution. 
*/ trait RunnableCommand extends logical.Command { - def run(sparkSession: SparkSession): Seq[Row] + def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + throw new NotImplementedError + } + + def run(sparkSession: SparkSession): Seq[Row] = { + throw new NotImplementedError + } } /** * A physical operator that executes the run method of a `RunnableCommand` and * saves the result to prevent multiple executions. */ -case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field @@ -55,14 +60,19 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan { */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) + val rows = if (children.isEmpty) { + cmd.run(sqlContext.sparkSession) + } else { + cmd.run(sqlContext.sparkSession, children) + } + rows.map(converter(_).asInstanceOf[InternalRow]) } - override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil + override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren override def output: Seq[Attribute] = cmd.output - override def children: Seq[SparkPlan] = Nil + override def nodeName: String = cmd.nodeName override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 2d890118ae..729bd39d82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -122,7 +122,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override protected def innerChildren: Seq[LogicalPlan] = Seq(query) + override def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) @@ -195,7 +195,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, Dataset.ofRows(session, query)) + dataSource.writeAndRead(mode, query) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 00f0acab21..1945d68241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -97,7 +97,7 @@ case class CreateViewCommand( import ViewHelper._ - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override def innerChildren: Seq[QueryPlan[_]] = Seq(child) if (viewType == PersistedView) { require(originalText.isDefined, "'originalText' must be provided to create permanent view") @@ -264,7 +264,7 @@ case class AlterViewAsCommand( import 
ViewHelper._ - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(session: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 9fce29b06b..958715eefa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -28,8 +28,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider @@ -388,9 +389,10 @@ case class DataSource( } /** - * Writes the given [[DataFrame]] out in this [[FileFormat]]. + * Writes the given [[LogicalPlan]] out in this [[FileFormat]]. */ - private def writeInFileFormat(format: FileFormat, mode: SaveMode, data: DataFrame): Unit = { + private def planForWritingFileFormat( + format: FileFormat, mode: SaveMode, data: LogicalPlan): LogicalPlan = { // Don't glob path for the write path. The contracts here are: // 1. Only one output path can be specified on the write path; // 2. Output path must be a legal HDFS style file system path; @@ -408,16 +410,6 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) - // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does - // not need to have the query as child, to avoid to analyze an optimized query, - // because InsertIntoHadoopFsRelationCommand will be optimized first. - val partitionAttributes = partitionColumns.map { name => - val plan = data.logicalPlan - plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]") - }.asInstanceOf[Attribute] - } val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _) => t.location @@ -426,36 +418,35 @@ case class DataSource( // For partitioned relation r, r.schema's column ordering can be different from the column // ordering of data.logicalPlan (partition columns are all moved after data column). This // will be adjusted within InsertIntoHadoopFsRelation. 
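// A minimal sketch (not part of this patch) of how the "return a plan, don't execute" shape
// above is driven from the write path: planForWritingFileFormat/planForWriting now hand back
// a logical command (here an InsertIntoHadoopFsRelationCommand), and the caller -- see
// DataFrameWriter.runCommand earlier in this patch -- turns it into a QueryExecution and runs
// it under a single execution id, so the whole write is one entry in the SQL UI. The object
// and method names are illustrative, and the sketch sits in a sub-package of
// org.apache.spark.sql only so the private[sql] members it touches are visible.
package org.apache.spark.sql.execution.sketch

import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.DataSource

object PlanForWritingSketch {
  def save(df: DataFrame, format: String, path: String, mode: SaveMode): Unit = {
    // Build the logical write plan; nothing is executed yet.
    val writePlan = DataSource(
      sparkSession = df.sparkSession,
      className = format,
      options = Map("path" -> path)).planForWriting(mode, df.logicalPlan)

    // Plan and execute it as one tracked SQL execution.
    val qe = df.sparkSession.sessionState.executePlan(writePlan)
    SQLExecution.withNewExecutionId(df.sparkSession, qe)(qe.toRdd)
  }
}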
- val plan = - InsertIntoHadoopFsRelationCommand( - outputPath = outputPath, - staticPartitions = Map.empty, - ifPartitionNotExists = false, - partitionColumns = partitionAttributes, - bucketSpec = bucketSpec, - fileFormat = format, - options = options, - query = data.logicalPlan, - mode = mode, - catalogTable = catalogTable, - fileIndex = fileIndex) - sparkSession.sessionState.executePlan(plan).toRdd + InsertIntoHadoopFsRelationCommand( + outputPath = outputPath, + staticPartitions = Map.empty, + ifPartitionNotExists = false, + partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), + bucketSpec = bucketSpec, + fileFormat = format, + options = options, + query = data, + mode = mode, + catalogTable = catalogTable, + fileIndex = fileIndex) } /** - * Writes the given [[DataFrame]] out to this [[DataSource]] and returns a [[BaseRelation]] for + * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for * the following reading. */ - def writeAndRead(mode: SaveMode, data: DataFrame): BaseRelation = { + def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = { if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data) + dataSource.createRelation( + sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - writeInFileFormat(format, mode, data) + sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation() case _ => @@ -464,18 +455,18 @@ case class DataSource( } /** - * Writes the given [[DataFrame]] out to this [[DataSource]]. + * Returns a logical plan to write the given [[LogicalPlan]] out to this [[DataSource]]. 
*/ - def write(mode: SaveMode, data: DataFrame): Unit = { + def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = { if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } providingClass.newInstance() match { case dataSource: CreatableRelationProvider => - dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data) + SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => - writeInFileFormat(format, mode, data) + planForWritingFileFormat(format, mode, data) case _ => sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 4ec09bff42..afe454f714 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -96,7 +96,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - queryExecution: QueryExecution, + plan: SparkPlan, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -111,9 +111,9 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = queryExecution.logical.output + val allColumns = plan.output val partitionSet = AttributeSet(partitionColumns) - val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains) + val dataColumns = allColumns.filterNot(partitionSet.contains) val bucketIdExpression = bucketSpec.map { spec => val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) @@ -151,7 +151,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) + val actualOrdering = plan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -161,50 +161,50 @@ object FileFormatWriter extends Logging { } } - SQLExecution.withNewExecutionId(sparkSession, queryExecution) { - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
- committer.setupJob(job) - - try { - val rdd = if (orderingMatched) { - queryExecution.toRdd - } else { - SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), - global = false, - child = queryExecution.executedPlan).execute() - } - val ret = new Array[WriteTaskResult](rdd.partitions.length) - sparkSession.sparkContext.runJob( - rdd, - (taskContext: TaskContext, iter: Iterator[InternalRow]) => { - executeTask( - description = description, - sparkStageId = taskContext.stageId(), - sparkPartitionId = taskContext.partitionId(), - sparkAttemptNumber = taskContext.attemptNumber(), - committer, - iterator = iter) - }, - 0 until rdd.partitions.length, - (index, res: WriteTaskResult) => { - committer.onTaskCommit(res.commitMsg) - ret(index) = res - }) - - val commitMsgs = ret.map(_.commitMsg) - val updatedPartitions = ret.flatMap(_.updatedPartitions) - .distinct.map(PartitioningUtils.parsePathFragment) - - committer.commitJob(job, commitMsgs) - logInfo(s"Job ${job.getJobID} committed.") - refreshFunction(updatedPartitions) - } catch { case cause: Throwable => - logError(s"Aborting job ${job.getJobID}.", cause) - committer.abortJob(job) - throw new SparkException("Job aborted.", cause) + SQLExecution.checkSQLExecutionId(sparkSession) + + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + committer.setupJob(job) + + try { + val rdd = if (orderingMatched) { + plan.execute() + } else { + SortExec( + requiredOrdering.map(SortOrder(_, Ascending)), + global = false, + child = plan).execute() } + val ret = new Array[WriteTaskResult](rdd.partitions.length) + sparkSession.sparkContext.runJob( + rdd, + (taskContext: TaskContext, iter: Iterator[InternalRow]) => { + executeTask( + description = description, + sparkStageId = taskContext.stageId(), + sparkPartitionId = taskContext.partitionId(), + sparkAttemptNumber = taskContext.attemptNumber(), + committer, + iterator = iter) + }, + 0 until rdd.partitions.length, + (index, res: WriteTaskResult) => { + committer.onTaskCommit(res.commitMsg) + ret(index) = res + }) + + val commitMsgs = ret.map(_.commitMsg) + val updatedPartitions = ret.flatMap(_.updatedPartitions) + .distinct.map(PartitioningUtils.parsePathFragment) + + committer.commitJob(job, commitMsgs) + logInfo(s"Job ${job.getJobID} committed.") + refreshFunction(updatedPartitions) + } catch { case cause: Throwable => + logError(s"Aborting job ${job.getJobID}.", cause) + committer.abortJob(job) + throw new SparkException("Job aborted.", cause) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index a813829d50..08b2f4f311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand( overwrite: Boolean) extends RunnableCommand { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index c9d31449d3..00aa124088 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ /** @@ -53,12 +54,13 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex]) extends RunnableCommand { - import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override def children: Seq[LogicalPlan] = query :: Nil + + override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + assert(children.length == 1) - override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) { val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { @@ -144,7 +146,7 @@ case class InsertIntoHadoopFsRelationCommand( FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = children.head, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 6f19ea195c..5eb6a8471b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.sources.CreatableRelationProvider /** * Saves the results of `query` in to a data source. 
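// A hedged sketch (not part of this patch) of the provider type that SaveIntoDataSourceCommand
// now carries directly: a CreatableRelationProvider, i.e. the non-file-format write path. With
// this change the provider instance is resolved up front by DataSource.planForWriting, and the
// save appears in the SQL UI as a SaveIntoDataSourceCommand node. The provider class below is
// a made-up example, not a Spark built-in.
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types.StructType

class NoopRelationProviderSketch extends CreatableRelationProvider {
  override def createRelation(
      ctx: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // A real provider would write `data` out here (honouring `mode` and `parameters`);
    // this sketch only echoes the schema back as a relation.
    new BaseRelation {
      override def sqlContext: SQLContext = ctx
      override def schema: StructType = data.schema
    }
  }
}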
@@ -33,19 +34,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand */ case class SaveIntoDataSourceCommand( query: LogicalPlan, - provider: String, - partitionColumns: Seq[String], + dataSource: CreatableRelationProvider, options: Map[String, String], mode: SaveMode) extends RunnableCommand { - override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { - DataSource( - sparkSession, - className = provider, - partitionColumns = partitionColumns, - options = options).write(mode, Dataset.ofRows(sparkSession, query)) + dataSource.createRelation( + sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query)) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 83bdf6fe22..76f121c0c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -144,7 +144,8 @@ object TextInputCSVDataSource extends CSVDataSource { inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): StructType = { val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption + val maybeFirstLine = + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index f8d4a9bb5b..fdc5e85f3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -76,7 +76,7 @@ case class CreateTempViewUsing( CatalogUtils.maskCredentials(options) } - def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession): Seq[Row] = { if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) { throw new AnalysisException("Hive data source can only be used with tables, " + "you can't use it with CREATE TEMP VIEW USING") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8b45dba04d..a06f1ce328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -129,12 +129,14 @@ private[sql] case class JDBCRelation( } override def insert(data: DataFrame, overwrite: Boolean): Unit = { - val url = jdbcOptions.url - val table = jdbcOptions.table - val properties = jdbcOptions.asProperties - data.write - .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) - .jdbc(url, table, properties) + import scala.collection.JavaConverters._ + + val options = jdbcOptions.asProperties.asScala + + ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table) + val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append + + new JdbcRelationProvider().createRelation( + data.sparkSession.sqlContext, mode, 
options.toMap, data) } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 71eaab119d..ca61c2efe2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -788,7 +788,7 @@ object JdbcUtils extends Logging { case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n) case _ => df } - repartitionedDF.foreachPartition(iterator => savePartition( + repartitionedDF.rdd.foreachPartition(iterator => savePartition( getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 6885d0bf67..96225ecffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -122,7 +122,7 @@ class FileStreamSink( FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = data.queryExecution, + plan = data.queryExecution.executedPlan, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index b6ddf7437e..ab8608563c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ @@ -655,7 +655,9 @@ class StreamExecution( new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) reportTimeTaken("addBatch") { - sink.addBatch(currentBatchId, nextBatch) + SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { + sink.addBatch(currentBatchId, nextBatch) + } } awaitBatchLock.lock() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index e8b9712d19..38c6319110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -46,8 +46,8 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging { println("-------------------------------------------") // scalastyle:off println data.sparkSession.createDataFrame( - data.sparkSession.sparkContext.parallelize(data.collect()), data.schema) - 
.show(numRowsToShow, isTruncated) + data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema) + .showInternal(numRowsToShow, isTruncated) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 971ce5afb1..7eaa803a9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -196,11 +196,11 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi logDebug(s"Committing batch $batchId to $this") outputMode match { case Append | Update => - val rows = AddedData(batchId, data.collect()) + val rows = AddedData(batchId, data.collectInternal()) synchronized { batches += rows } case Complete => - val rows = AddedData(batchId, data.collect()) + val rows = AddedData(batchId, data.collectInternal()) synchronized { batches.clear() batches += rows diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index e544245588..a4e62f1d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -290,10 +290,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => + // person creates a temporary view. get the DF before listing previous execution IDs + val data = person.select('name) + sparkContext.listenerBus.waitUntilEmpty(10000) val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) - person.select('name).write.format("json").save(file.getAbsolutePath) + data.write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) val executionIds = spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 7c9ea7d393..a239e39d9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.{functions, AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project} import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec} -import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand} +import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand} +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat import org.apache.spark.sql.test.SharedSQLContext class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { @@ -178,26 +179,28 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { spark.range(10).write.format("json").save(path.getCanonicalPath) assert(commands.length == 1) assert(commands.head._1 == "save") - assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand]) - 
assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json") + assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] + .fileFormat.isInstanceOf[JsonFileFormat]) } withTable("tab") { - sql("CREATE TABLE tab(i long) using parquet") + sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess spark.range(10).write.insertInto("tab") - assert(commands.length == 2) - assert(commands(1)._1 == "insertInto") - assert(commands(1)._2.isInstanceOf[InsertIntoTable]) - assert(commands(1)._2.asInstanceOf[InsertIntoTable].table + assert(commands.length == 3) + assert(commands(2)._1 == "insertInto") + assert(commands(2)._2.isInstanceOf[InsertIntoTable]) + assert(commands(2)._2.asInstanceOf[InsertIntoTable].table .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab") } + // exiting withTable adds commands(3) via onSuccess (drops tab) withTable("tab") { spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab") - assert(commands.length == 3) - assert(commands(2)._1 == "saveAsTable") - assert(commands(2)._2.isInstanceOf[CreateTable]) - assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) + assert(commands.length == 5) + assert(commands(4)._1 == "saveAsTable") + assert(commands(4)._2.isInstanceOf[CreateTable]) + assert(commands(4)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p")) } withTable("tab") { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 0d5dc7af5f..6775902173 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext) @@ -60,7 +60,9 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont try { context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) - hiveResponse = execution.hiveResultString() + hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { + execution.hiveResultString() + } tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) } catch { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 10ce8e3730..392b7cfaa8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -32,10 +32,11 @@ import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.internal.io.FileCommitProtocol -import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} +import 
org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive._ @@ -81,7 +82,7 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends RunnableCommand { - override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override def children: Seq[LogicalPlan] = query :: Nil var createdTempDir: Option[Path] = None @@ -230,7 +231,9 @@ case class InsertIntoHiveTable( * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. */ - override def run(sparkSession: SparkSession): Seq[Row] = { + override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { + assert(children.length == 1) + val sessionState = sparkSession.sessionState val externalCatalog = sparkSession.sharedState.externalCatalog val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version @@ -344,7 +347,7 @@ case class InsertIntoHiveTable( FileFormatWriter.write( sparkSession = sparkSession, - queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, + plan = children.head, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index e1534c797d..4e1792321c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -34,8 +34,8 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} import org.apache.spark.sql.execution.command.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.client.HiveClient @@ -294,23 +294,23 @@ private[hive] class TestHiveSparkSession( "CREATE TABLE src1 (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { - sql( - "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") + "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)" + .cmd.apply() for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') + """.stripMargin.cmd.apply() } }), TestTable("srcpart1", () => { - sql( - "CREATE TABLE srcpart1 (key INT, value STRING) 
PARTITIONED BY (ds STRING, hr INT)") + "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)" + .cmd.apply() for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - sql( - s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' - |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}' + |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') + """.stripMargin.cmd.apply() } }), TestTable("src_thrift", () => { @@ -318,8 +318,7 @@ private[hive] class TestHiveSparkSession( import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol - sql( - s""" + s""" |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' |WITH SERDEPROPERTIES( @@ -329,13 +328,12 @@ private[hive] class TestHiveSparkSession( |STORED AS |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}' |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' - """.stripMargin) + """.stripMargin.cmd.apply() - sql( - s""" - |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' - |INTO TABLE src_thrift - """.stripMargin) + s""" + |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}' + |INTO TABLE src_thrift + """.stripMargin.cmd.apply() }), TestTable("serdeins", s"""CREATE TABLE serdeins (key INT, value STRING) @@ -458,7 +456,17 @@ private[hive] class TestHiveSparkSession( logDebug(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) - createCmds.foreach(_()) + + // test tables are loaded lazily, so they may be loaded in the middle a query execution which + // has already set the execution id. + if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) { + // We don't actually have a `QueryExecution` here, use a fake one instead. 
+ SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation)) { + createCmds.foreach(_()) + } + } else { + createCmds.foreach(_()) + } if (cacheTables) { new SQLContext(self).cacheTable(name) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 98aa92a9bb..cee82cda46 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution} @@ -341,7 +342,10 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) - try { (query, prepareAnswer(query, query.hiveResultString())) } catch { + def getResult(): Seq[String] = { + SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString()) + } + try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable => val errorMessage = s""" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c944f28d10..da7a0645db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -965,14 +965,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("sanity test for SPARK-6618") { - (1 to 100).par.map { i => - val tableName = s"SPARK_6618_table_$i" - sql(s"CREATE TABLE $tableName (col1 string)") - sessionState.catalog.lookupRelation(TableIdentifier(tableName)) - table(tableName) - tables() - sql(s"DROP TABLE $tableName") + val threads: Seq[Thread] = (1 to 10).map { i => + new Thread("test-thread-" + i) { + override def run(): Unit = { + val tableName = s"SPARK_6618_table_$i" + sql(s"CREATE TABLE $tableName (col1 string)") + sessionState.catalog.lookupRelation(TableIdentifier(tableName)) + table(tableName) + tables() + sql(s"DROP TABLE $tableName") + } + } } + threads.foreach(_.start()) + threads.foreach(_.join(10000)) } test("SPARK-5203 union with different decimal precision") { -- GitLab
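The `InsertIntoHiveTable` hunk above moves the command from rebuilding a `Dataset`/`QueryExecution` inside `run()` to exposing the query as a real `children` plan and receiving the already-planned child via `run(sparkSession, children)`, which it hands to `FileFormatWriter.write`. A minimal, Spark-free sketch of that shape, assuming made-up placeholder names (`Plan`, `Scan`, `WriteCommand`), not Spark APIs:

```scala
// Spark-free sketch of the refactoring: the command keeps its query as a child
// and receives the already-planned child at execution time, instead of planning
// the query itself inside run(). All names here are illustrative placeholders.
sealed trait Plan { def rows: Seq[String] }
case class Scan(rows: Seq[String]) extends Plan

case class WriteCommand(child: Plan) {
  // Mirrors the run(sparkSession, children) shape: the caller hands in the
  // planned children rather than the command building them from scratch.
  def run(plannedChildren: Seq[Plan]): Seq[String] = {
    assert(plannedChildren.length == 1)
    plannedChildren.head.rows.map(r => s"wrote: $r")
  }
}

object CommandChildrenDemo {
  def main(args: Array[String]): Unit = {
    val cmd = WriteCommand(Scan(Seq("a", "b")))
    // A real planner would first convert cmd.child into a physical plan;
    // here the child already plays that role.
    cmd.run(Seq(cmd.child)).foreach(println)
  }
}
```

Because the command now has children, the executed write plan shows up under the command node in the SQL tab instead of appearing as a bare scan.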
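The `TestHive` change only wraps the lazy test-table commands in `SQLExecution.withNewExecutionId` when `SQLExecution.EXECUTION_ID_KEY` is not already set, so a table loaded in the middle of a running query reuses that query's execution id rather than opening a nested one. A minimal sketch of the same guard, assuming a plain `ThreadLocal` in place of SparkContext local properties and made-up names (`ExecutionIdGuardDemo`, `loadTestTable`), not Spark APIs:

```scala
// Spark-free analogue of the guard added to TestHive: only open a new "execution"
// scope when the current thread is not already inside one, so lazily-loaded test
// tables do not clobber the execution id of the query that triggered the load.
object ExecutionIdGuardDemo {
  private val currentExecutionId = new ThreadLocal[Option[Long]] {
    override def initialValue(): Option[Long] = None
  }

  private def withNewExecutionId[T](body: => T): T = {
    currentExecutionId.set(Some(System.nanoTime()))
    try body finally currentExecutionId.set(None)
  }

  def loadTestTable(createCmds: Seq[() => Unit]): Unit = {
    if (currentExecutionId.get().isEmpty) {
      // No execution in progress: wrap the commands in a fresh (fake) execution.
      withNewExecutionId { createCmds.foreach(_.apply()) }
    } else {
      // Already inside an execution triggered by a real query: reuse it.
      createCmds.foreach(_.apply())
    }
  }

  def main(args: Array[String]): Unit = {
    loadTestTable(Seq(() => println("CREATE TABLE src (key INT, value STRING)")))
  }
}
```

The guard matters because a nested `withNewExecutionId` is rejected when the execution id property is already set, which is why the patch checks the property before wrapping.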
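The `SQLQuerySuite` hunk swaps the `.par` collection in the SPARK-6618 sanity test for explicitly named threads joined with a timeout. A standalone sketch of that pattern, where `doWork` is a stand-in for the per-table SQL statements in the real test:

```scala
// Standalone sketch of the test rewrite: replace a `.par` collection with explicit,
// named threads and bound the join so a hung thread cannot stall the suite forever.
object ThreadedSanityCheck {
  def main(args: Array[String]): Unit = {
    // Placeholder for the CREATE TABLE / lookupRelation / DROP TABLE work per table.
    def doWork(i: Int): Unit = println(s"processing SPARK_6618_table_$i")

    val threads: Seq[Thread] = (1 to 10).map { i =>
      new Thread("test-thread-" + i) {
        override def run(): Unit = doWork(i)
      }
    }
    threads.foreach(_.start())
    // Wait at most 10 seconds per thread instead of blocking indefinitely.
    threads.foreach(_.join(10000))
  }
}
```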