diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index 0ed9d4e84d54da6a1d69dba6a2e074fb285cfbcf..5e9ae35b3f008d2fba08e2c5c6a615dd99514aa6 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -85,12 +85,10 @@ private[kafka010] object KafkaWriter extends Logging {
       topic: Option[String] = None): Unit = {
     val schema = queryExecution.analyzed.output
     validateQuery(queryExecution, kafkaParameters, topic)
-    SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
-      queryExecution.toRdd.foreachPartition { iter =>
-        val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
-        Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
-          finallyBlock = writeTask.close())
-      }
+    queryExecution.toRdd.foreachPartition { iter =>
+      val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
+      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
+        finallyBlock = writeTask.close())
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d3f822bf7eb0e99c29b73e7a261d2306f496d4a3..5ba043e17a12847c3e3f07c295c8e7753896570a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -357,7 +357,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     })
   }
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = subqueries
+  override def innerChildren: Seq[QueryPlan[_]] = subqueries
 
   /**
    * Returns a plan where a best effort attempt has been made to transform `this` in a way
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala
index 38f47081b6f55c4cc4d01f4a1d1df1ec989febe7..ec5766e1f67f261ae0e673c71684ff51661d9c07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
  * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are
  * eagerly executed.
  */
-trait Command extends LeafNode {
+trait Command extends LogicalPlan {
   override def output: Seq[Attribute] = Seq.empty
+  override def children: Seq[LogicalPlan] = Seq.empty
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index a64562b5dbd93ae6c779b91478e0e8b7751b8491..ae5f1d1fc4f8318079ced9466fa21932ad653b08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -22,7 +22,8 @@ import java.math.{MathContext, RoundingMode}
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -243,9 +244,9 @@ object ColumnStat extends Logging {
     }
 
     col.dataType match {
-      case _: IntegralType => fixedLenTypeStruct(LongType)
+      case dt: IntegralType => fixedLenTypeStruct(dt)
       case _: DecimalType => fixedLenTypeStruct(col.dataType)
-      case DoubleType | FloatType => fixedLenTypeStruct(DoubleType)
+      case dt @ (DoubleType | FloatType) => fixedLenTypeStruct(dt)
       case BooleanType => fixedLenTypeStruct(col.dataType)
       case DateType => fixedLenTypeStruct(col.dataType)
       case TimestampType => fixedLenTypeStruct(col.dataType)
@@ -264,14 +265,12 @@ object ColumnStat extends Logging {
   }
 
   /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. */
-  def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = {
+  def rowToColumnStat(row: InternalRow, attr: Attribute): ColumnStat = {
     ColumnStat(
       distinctCount = BigInt(row.getLong(0)),
       // for string/binary min/max, get should return null
-      min = Option(row.get(1))
-        .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
-      max = Option(row.get(2))
-        .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply),
+      min = Option(row.get(1, attr.dataType)),
+      max = Option(row.get(2, attr.dataType)),
       nullCount = BigInt(row.getLong(3)),
       avgLen = row.getLong(4),
       maxLen = row.getLong(5)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index b71c5eb843eec6a61fcc3f82fe4921d97bce8f50..255c4064eb574e67672d0aa7dd3a8bab3fd2729e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation, SaveIntoDataSourceCommand}
 import org.apache.spark.sql.sources.BaseRelation
@@ -231,12 +232,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     assertNotBucketed("save")
 
     runCommand(df.sparkSession, "save") {
-      SaveIntoDataSourceCommand(
-        query = df.logicalPlan,
-        provider = source,
+      DataSource(
+        sparkSession = df.sparkSession,
+        className = source,
         partitionColumns = partitioningColumns.getOrElse(Nil),
-        options = extraOptions.toMap,
-        mode = mode)
+        options = extraOptions.toMap).planForWriting(mode, df.logicalPlan)
     }
   }
 
@@ -607,7 +607,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     try {
       val start = System.nanoTime()
       // call `QueryExecution.toRDD` to trigger the execution of commands.
-      qe.toRdd
+      SQLExecution.withNewExecutionId(session, qe)(qe.toRdd)
       val end = System.nanoTime()
       session.listenerManager.onSuccess(name, qe, end - start)
     } catch {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1cd6fda5edc8797374a79bfefbb93800fc82624d..d5b4c82c3558b11752cd7c1e25f320aeae94a842 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -179,9 +179,9 @@ class Dataset[T] private[sql](
     // to happen right away to let these side effects take place eagerly.
     queryExecution.analyzed match {
       case c: Command =>
-        LocalRelation(c.output, queryExecution.executedPlan.executeCollect())
+        LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect()))
       case u @ Union(children) if children.forall(_.isInstanceOf[Command]) =>
-        LocalRelation(u.output, queryExecution.executedPlan.executeCollect())
+        LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect()))
       case _ =>
         queryExecution.analyzed
     }
@@ -248,8 +248,13 @@ class Dataset[T] private[sql](
       _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
     val numRows = _numRows.max(0)
     val takeResult = toDF().take(numRows + 1)
-    val hasMoreData = takeResult.length > numRows
-    val data = takeResult.take(numRows)
+    showString(takeResult, numRows, truncate, vertical)
+  }
+
+  private def showString(
+      dataWithOneMoreRow: Array[Row], numRows: Int, truncate: Int, vertical: Boolean): String = {
+    val hasMoreData = dataWithOneMoreRow.length > numRows
+    val data = dataWithOneMoreRow.take(numRows)
 
     lazy val timeZone =
       DateTimeUtils.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
@@ -684,6 +689,18 @@ class Dataset[T] private[sql](
   } else {
     println(showString(numRows, truncate = 0))
   }
+
+  // An internal version of `show`, which won't set execution id and trigger listeners.
+  private[sql] def showInternal(_numRows: Int, truncate: Boolean): Unit = {
+    val numRows = _numRows.max(0)
+    val takeResult = toDF().takeInternal(numRows + 1)
+
+    if (truncate) {
+      println(showString(takeResult, numRows, truncate = 20, vertical = false))
+    } else {
+      println(showString(takeResult, numRows, truncate = 0, vertical = false))
+    }
+  }
   // scalastyle:on println
 
   /**
@@ -2453,6 +2470,11 @@ class Dataset[T] private[sql](
    */
   def take(n: Int): Array[T] = head(n)
 
+  // An internal version of `take`, which won't set execution id and trigger listeners.
+  private[sql] def takeInternal(n: Int): Array[T] = {
+    collectFromPlan(limit(n).queryExecution.executedPlan)
+  }
+
   /**
    * Returns the first `n` rows in the Dataset as a list.
    *
@@ -2477,6 +2499,11 @@ class Dataset[T] private[sql](
    */
   def collect(): Array[T] = withAction("collect", queryExecution)(collectFromPlan)
 
+  // An internal version of `collect`, which won't set execution id and trigger listeners.
+  private[sql] def collectInternal(): Array[T] = {
+    collectFromPlan(queryExecution.executedPlan)
+  }
+
   /**
    * Returns a Java list that contains all rows in this Dataset.
    *
@@ -2518,6 +2545,11 @@ class Dataset[T] private[sql](
     plan.executeCollect().head.getLong(0)
   }
 
+  // An internal version of `count`, which won't set execution id and trigger listeners.
+  private[sql] def countInternal(): Long = {
+    groupBy().count().queryExecution.executedPlan.executeCollect().head.getLong(0)
+  }
+
   /**
    * Returns a new Dataset that has exactly `numPartitions` partitions.
    *
@@ -2763,7 +2795,7 @@ class Dataset[T] private[sql](
     createTempViewCommand(viewName, replace = true, global = true)
   }
 
-  private def createTempViewCommand(
+  private[spark] def createTempViewCommand(
       viewName: String,
       replace: Boolean,
       global: Boolean): CreateViewCommand = {
@@ -2954,17 +2986,17 @@ class Dataset[T] private[sql](
   }
 
   /** A convenient function to wrap a logical plan and produce a DataFrame. */
-  @inline private def withPlan(logicalPlan: => LogicalPlan): DataFrame = {
+  @inline private def withPlan(logicalPlan: LogicalPlan): DataFrame = {
     Dataset.ofRows(sparkSession, logicalPlan)
   }
 
   /** A convenient function to wrap a logical plan and produce a Dataset. */
-  @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+  @inline private def withTypedPlan[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = {
     Dataset(sparkSession, logicalPlan)
   }
 
   /** A convenient function to wrap a set based logical plan and produce a Dataset. */
-  @inline private def withSetOperator[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = {
+  @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = {
     if (classTag.runtimeClass.isAssignableFrom(classOf[Row])) {
       // Set operators widen types (change the schema), so we cannot reuse the row encoder.
       Dataset.ofRows(sparkSession, logicalPlan).asInstanceOf[Dataset[U]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 2e05e5d65923c6d6f8e35606c6c52518016fd77b..1ba9a79446aad1695f57ed74769fa05a99b2526a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -113,10 +113,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
 
 
   /**
-   * Returns the result as a hive compatible sequence of strings. This is for testing only.
+   * Returns the result as a hive compatible sequence of strings. This is used in tests and
+   * `SparkSQLDriver` for CLI applications.
    */
   def hiveResultString(): Seq[String] = executedPlan match {
-    case ExecutedCommandExec(desc: DescribeTableCommand) =>
+    case ExecutedCommandExec(desc: DescribeTableCommand, _) =>
       // If it is a describe command for a Hive table, we want to have the output format
       // be similar with Hive.
       desc.run(sparkSession).map {
@@ -127,7 +128,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
         .mkString("\t")
       }
     // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp.
-    case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended =>
+    case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended =>
       command.executeCollect().map(_.getString(1))
     case other =>
       val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index be35916e3447eababdb1e80dfddf919c51142539..bb206e84325fd62d442cb29d8589de03c9a6842e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -39,6 +39,19 @@ object SQLExecution {
     executionIdToQueryExecution.get(executionId)
   }
 
+  private val testing = sys.props.contains("spark.testing")
+
+  private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
+    // Only throw an exception during tests. A missing execution ID should not fail a job.
+    if (testing && sparkSession.sparkContext.getLocalProperty(EXECUTION_ID_KEY) == null) {
+      // Attention testers: when a test fails with this exception, it means that the action that
+      // started execution of a query didn't call withNewExecutionId. The execution ID should be
+      // set by calling withNewExecutionId in the action that begins execution, like
+      // Dataset.collect or DataFrameWriter.insertInto.
+      throw new IllegalStateException("Execution ID should be set")
+    }
+  }
+
   /**
    * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that
    * we can connect them with an execution.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 843ce63161220d9d1e9a5f578a2d549c0d8e0d20..f13294c925e366a84e52cd3501735fc1d1261378 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -346,7 +346,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
+      case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil
 
       case MemoryPlan(sink, output) =>
         val encoder = RowEncoder(sink.schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 3486a6bce81800e92409ffadb55299f20bde690a..456a8f3b20f3055cda32d289e9907dc36d919de0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -64,7 +64,7 @@ case class InMemoryRelation(
     val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
   extends logical.LeafNode with MultiInstanceRelation {
 
-  override protected def innerChildren: Seq[SparkPlan] = Seq(child)
+  override def innerChildren: Seq[SparkPlan] = Seq(child)
 
   override def producedAttributes: AttributeSet = outputSet
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 7063b08f7c64406c70ec43afbc46a1eb9e9b0ea8..1d601374de135026e9ea128ef44628951aad5fbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -34,7 +34,7 @@ case class InMemoryTableScanExec(
     @transient relation: InMemoryRelation)
   extends LeafExecNode {
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 0d8db2ff5d5a03a931d70beb0de34789032b1502..2de14c90ec7574473061687ca606db46b4f438c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableTyp
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.QueryExecution
 
 
 /**
@@ -96,11 +97,13 @@ case class AnalyzeColumnCommand(
       attributesToAnalyze.map(ColumnStat.statExprs(_, ndvMaxErr))
 
     val namedExpressions = expressions.map(e => Alias(e, e.toString)())
-    val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head()
+    val statsRow = new QueryExecution(sparkSession, Aggregate(Nil, namedExpressions, relation))
+      .executedPlan.executeTake(1).head
 
     val rowCount = statsRow.getLong(0)
     val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) =>
-      (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr))
+      // according to `ColumnStat.statExprs`, the stats struct always has 6 fields.
+      (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1, 6), attr))
     }.toMap
     (rowCount, columnStats)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index d2ea0cdf61aa6ed3037152abc18a4d5c8007ca7e..3183c7911b1fb84d3d025e6b88a8ef69b1e0f663 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -56,7 +56,7 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = sparkSession.table(tableIdentWithDB).count()
+      val newRowCount = sparkSession.table(tableIdentWithDB).countInternal()
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 336f14dd97aeacd7b389508df94f6134e4936b49..184d0387ebfa9250678e2f017c99897db7968843 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -30,19 +30,19 @@ case class CacheTableCommand(
   require(plan.isEmpty || tableIdent.database.isEmpty,
     "Database name is not allowed in CACHE TABLE AS SELECT")
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = {
-    plan.toSeq
-  }
+  override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     plan.foreach { logicalPlan =>
-      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+      Dataset.ofRows(sparkSession, logicalPlan)
+        .createTempViewCommand(tableIdent.quotedString, replace = false, global = false)
+        .run(sparkSession)
     }
     sparkSession.catalog.cacheTable(tableIdent.quotedString)
 
     if (!isLazy) {
       // Performs eager caching
-      sparkSession.table(tableIdent).count()
+      sparkSession.table(tableIdent).countInternal()
     }
 
     Seq.empty[Row]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 41d91d877d4c2ae1b0d9b98eeeb6a26c255158cd..99d81c49f1e3bb5ff90d895ecbcf4c67458f1fb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -22,8 +22,7 @@ import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.debug._
@@ -36,14 +35,20 @@ import org.apache.spark.sql.types._
  * wrapped in `ExecutedCommand` during execution.
  */
 trait RunnableCommand extends logical.Command {
-  def run(sparkSession: SparkSession): Seq[Row]
+  def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
+    throw new NotImplementedError
+  }
+
+  def run(sparkSession: SparkSession): Seq[Row] = {
+    throw new NotImplementedError
+  }
 }
 
 /**
  * A physical operator that executes the run method of a `RunnableCommand` and
  * saves the result to prevent multiple executions.
  */
-case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
+case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan {
   /**
    * A concrete command should override this lazy field to wrap up any side effects caused by the
   * command or any other computation that should be evaluated exactly once. The value of this field
@@ -55,14 +60,19 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
   */
  protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
     val converter = CatalystTypeConverters.createToCatalystConverter(schema)
-    cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow])
+    val rows = if (children.isEmpty) {
+      cmd.run(sqlContext.sparkSession)
+    } else {
+      cmd.run(sqlContext.sparkSession, children)
+    }
+    rows.map(converter(_).asInstanceOf[InternalRow])
   }
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil
+  override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren
 
   override def output: Seq[Attribute] = cmd.output
 
-  override def children: Seq[SparkPlan] = Nil
+  override def nodeName: String = cmd.nodeName
 
   override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 2d890118ae0a5d07ec4099951902043f43011a63..729bd39d821c97b3f5d654c72db5a3fffb1909c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -122,7 +122,7 @@ case class CreateDataSourceTableAsSelectCommand(
     query: LogicalPlan)
   extends RunnableCommand {
 
-  override protected def innerChildren: Seq[LogicalPlan] = Seq(query)
+  override def innerChildren: Seq[LogicalPlan] = Seq(query)
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     assert(table.tableType != CatalogTableType.VIEW)
@@ -195,7 +195,7 @@ case class CreateDataSourceTableAsSelectCommand(
       catalogTable = if (tableExists) Some(table) else None)
 
     try {
-      dataSource.writeAndRead(mode, Dataset.ofRows(session, query))
+      dataSource.writeAndRead(mode, query)
     } catch {
       case ex: AnalysisException =>
        logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
-      val partitionAttributes = partitionColumns.map { name =>
-        val plan = data.logicalPlan
-        plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
-          throw new AnalysisException(
-            s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
-        }.asInstanceOf[Attribute]
-      }
     val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
       sparkSession.table(tableIdent).queryExecution.analyzed.collect {
         case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
@@ -426,36 +418,35 @@ case class DataSource(
     // For partitioned relation r, r.schema's column ordering can be different from the column
     // ordering of data.logicalPlan (partition columns are all moved after data column). This
     // will be adjusted within InsertIntoHadoopFsRelation.
-    val plan =
-      InsertIntoHadoopFsRelationCommand(
-        outputPath = outputPath,
-        staticPartitions = Map.empty,
-        ifPartitionNotExists = false,
-        partitionColumns = partitionAttributes,
-        bucketSpec = bucketSpec,
-        fileFormat = format,
-        options = options,
-        query = data.logicalPlan,
-        mode = mode,
-        catalogTable = catalogTable,
-        fileIndex = fileIndex)
-    sparkSession.sessionState.executePlan(plan).toRdd
+    InsertIntoHadoopFsRelationCommand(
+      outputPath = outputPath,
+      staticPartitions = Map.empty,
+      ifPartitionNotExists = false,
+      partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted),
+      bucketSpec = bucketSpec,
+      fileFormat = format,
+      options = options,
+      query = data,
+      mode = mode,
+      catalogTable = catalogTable,
+      fileIndex = fileIndex)
   }
 
   /**
-   * Writes the given [[DataFrame]] out to this [[DataSource]] and returns a [[BaseRelation]] for
+   * Writes the given [[LogicalPlan]] out to this [[DataSource]] and returns a [[BaseRelation]] for
    * the following reading.
    */
-  def writeAndRead(mode: SaveMode, data: DataFrame): BaseRelation = {
+  def writeAndRead(mode: SaveMode, data: LogicalPlan): BaseRelation = {
     if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
       throw new AnalysisException("Cannot save interval data type into external storage.")
     }
 
     providingClass.newInstance() match {
       case dataSource: CreatableRelationProvider =>
-        dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data)
+        dataSource.createRelation(
+          sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data))
       case format: FileFormat =>
-        writeInFileFormat(format, mode, data)
+        sparkSession.sessionState.executePlan(planForWritingFileFormat(format, mode, data)).toRdd
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
         copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
       case _ =>
@@ -464,18 +455,18 @@ case class DataSource(
   }
 
   /**
-   * Writes the given [[DataFrame]] out to this [[DataSource]].
+   * Returns a logical plan to write the given [[LogicalPlan]] out to this [[DataSource]].
    */
-  def write(mode: SaveMode, data: DataFrame): Unit = {
+  def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = {
     if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
       throw new AnalysisException("Cannot save interval data type into external storage.")
     }
 
     providingClass.newInstance() match {
       case dataSource: CreatableRelationProvider =>
-        dataSource.createRelation(sparkSession.sqlContext, mode, caseInsensitiveOptions, data)
+        SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
       case format: FileFormat =>
-        writeInFileFormat(format, mode, data)
+        planForWritingFileFormat(format, mode, data)
       case _ =>
         sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 4ec09bff429c5a8fe0f12892a5ada0f7cf9fcb90..afe454f714c47cc123083709b38be1e84593632b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -38,8 +38,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
-import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
+import org.apache.spark.sql.types.StringType
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 
@@ -96,7 +96,7 @@ object FileFormatWriter extends Logging {
    */
   def write(
       sparkSession: SparkSession,
-      queryExecution: QueryExecution,
+      plan: SparkPlan,
       fileFormat: FileFormat,
       committer: FileCommitProtocol,
       outputSpec: OutputSpec,
@@ -111,9 +111,9 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
-    val allColumns = queryExecution.logical.output
+    val allColumns = plan.output
     val partitionSet = AttributeSet(partitionColumns)
-    val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
+    val dataColumns = allColumns.filterNot(partitionSet.contains)
 
     val bucketIdExpression = bucketSpec.map { spec =>
       val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
@@ -151,7 +151,7 @@ object FileFormatWriter extends Logging {
     // We should first sort by partition columns, then bucket id, and finally sorting columns.
     val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
     // the sort order doesn't matter
-    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val actualOrdering = plan.outputOrdering.map(_.child)
     val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
       false
     } else {
@@ -161,50 +161,50 @@ object FileFormatWriter extends Logging {
       }
     }
 
-    SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
-      // This call shouldn't be put into the `try` block below because it only initializes and
-      // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
-      committer.setupJob(job)
-
-      try {
-        val rdd = if (orderingMatched) {
-          queryExecution.toRdd
-        } else {
-          SortExec(
-            requiredOrdering.map(SortOrder(_, Ascending)),
-            global = false,
-            child = queryExecution.executedPlan).execute()
-        }
-        val ret = new Array[WriteTaskResult](rdd.partitions.length)
-        sparkSession.sparkContext.runJob(
-          rdd,
-          (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
-            executeTask(
-              description = description,
-              sparkStageId = taskContext.stageId(),
-              sparkPartitionId = taskContext.partitionId(),
-              sparkAttemptNumber = taskContext.attemptNumber(),
-              committer,
-              iterator = iter)
-          },
-          0 until rdd.partitions.length,
-          (index, res: WriteTaskResult) => {
-            committer.onTaskCommit(res.commitMsg)
-            ret(index) = res
-          })
-
-        val commitMsgs = ret.map(_.commitMsg)
-        val updatedPartitions = ret.flatMap(_.updatedPartitions)
-          .distinct.map(PartitioningUtils.parsePathFragment)
-
-        committer.commitJob(job, commitMsgs)
-        logInfo(s"Job ${job.getJobID} committed.")
-        refreshFunction(updatedPartitions)
-      } catch { case cause: Throwable =>
-        logError(s"Aborting job ${job.getJobID}.", cause)
-        committer.abortJob(job)
-        throw new SparkException("Job aborted.", cause)
+    SQLExecution.checkSQLExecutionId(sparkSession)
+
+    // This call shouldn't be put into the `try` block below because it only initializes and
+    // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
+    committer.setupJob(job)
+
+    try {
+      val rdd = if (orderingMatched) {
+        plan.execute()
+      } else {
+        SortExec(
+          requiredOrdering.map(SortOrder(_, Ascending)),
+          global = false,
+          child = plan).execute()
       }
+      val ret = new Array[WriteTaskResult](rdd.partitions.length)
+      sparkSession.sparkContext.runJob(
+        rdd,
+        (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
+          executeTask(
+            description = description,
+            sparkStageId = taskContext.stageId(),
+            sparkPartitionId = taskContext.partitionId(),
+            sparkAttemptNumber = taskContext.attemptNumber(),
+            committer,
+            iterator = iter)
+        },
+        0 until rdd.partitions.length,
+        (index, res: WriteTaskResult) => {
+          committer.onTaskCommit(res.commitMsg)
+          ret(index) = res
+        })
+
+      val commitMsgs = ret.map(_.commitMsg)
+      val updatedPartitions = ret.flatMap(_.updatedPartitions)
+        .distinct.map(PartitioningUtils.parsePathFragment)
+
+      committer.commitJob(job, commitMsgs)
+      logInfo(s"Job ${job.getJobID} committed.")
+      refreshFunction(updatedPartitions)
+    } catch { case cause: Throwable =>
+      logError(s"Aborting job ${job.getJobID}.", cause)
+      committer.abortJob(job)
+      throw new SparkException("Job aborted.", cause)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
index a813829d50cb1ac7142b0fd6b4841c976fc5f74a..08b2f4f31170ffa80ad88e932b2bf6fe2b5b8230 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala
@@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand(
     overwrite: Boolean)
   extends RunnableCommand {
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
     val relation = logicalRelation.relation.asInstanceOf[InsertableRelation]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index c9d31449d362944b64ffaf17e0a010e2ae238edc..00aa1240886e44047c72f27310b3c91ef75f028c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command._
 
 /**
@@ -53,12 +54,13 @@ case class InsertIntoHadoopFsRelationCommand(
     catalogTable: Option[CatalogTable],
     fileIndex: Option[FileIndex])
   extends RunnableCommand {
-
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
 
-  override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
+  override def children: Seq[LogicalPlan] = query :: Nil
+
+  override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
+    assert(children.length == 1)
 
-  override def run(sparkSession: SparkSession): Seq[Row] = {
     // Most formats don't do well with duplicate columns, so lets not allow that
     if (query.schema.fieldNames.length != query.schema.fieldNames.distinct.length) {
       val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect {
@@ -144,7 +146,7 @@ case class InsertIntoHadoopFsRelationCommand(
 
     FileFormatWriter.write(
       sparkSession = sparkSession,
-      queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
+      plan = children.head,
       fileFormat = fileFormat,
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 6f19ea195c0cd736b81817e25020c1968046e47b..5eb6a8471be0d0a2d817616a8cffd43ec7649b5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.sources.CreatableRelationProvider
 
 /**
  * Saves the results of `query` in to a data source.
@@ -33,19 +34,15 @@ import org.apache.spark.sql.execution.command.RunnableCommand
  */
 case class SaveIntoDataSourceCommand(
     query: LogicalPlan,
-    provider: String,
-    partitionColumns: Seq[String],
+    dataSource: CreatableRelationProvider,
     options: Map[String, String],
     mode: SaveMode) extends RunnableCommand {
 
-  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
+  override def innerChildren: Seq[QueryPlan[_]] = Seq(query)
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    DataSource(
-      sparkSession,
-      className = provider,
-      partitionColumns = partitionColumns,
-      options = options).write(mode, Dataset.ofRows(sparkSession, query))
+    dataSource.createRelation(
+      sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))
 
     Seq.empty[Row]
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 83bdf6fe224be20a7950f0d371d175c1f9516837..76f121c0c955fbd3a5facc029ada1b7311d67386 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -144,7 +144,8 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
+    val maybeFirstLine =
+      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).takeInternal(1).headOption
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index f8d4a9bb5b81ab3a5536a7e2ee72a937cf9223ed..fdc5e85f3c2eacfdf62c507320f35a84d06f457c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -76,7 +76,7 @@ case class CreateTempViewUsing(
     CatalogUtils.maskCredentials(options)
   }
 
-  def run(sparkSession: SparkSession): Seq[Row] = {
+  override def run(sparkSession: SparkSession): Seq[Row] = {
     if (provider.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
       throw new AnalysisException("Hive data source can only be used with tables, " +
         "you can't use it with CREATE TEMP VIEW USING")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 8b45dba04d29e09ae28d1b14efdcc2ba86196182..a06f1ce3287e676e5f587bbc20cdcf6355a7a8c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -129,12 +129,14 @@ private[sql] case class JDBCRelation(
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    val url = jdbcOptions.url
-    val table = jdbcOptions.table
-    val properties = jdbcOptions.asProperties
-    data.write
-      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-      .jdbc(url, table, properties)
+    import scala.collection.JavaConverters._
+
+    val options = jdbcOptions.asProperties.asScala +
+      ("url" -> jdbcOptions.url, "dbtable" -> jdbcOptions.table)
+    val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
+
+    new JdbcRelationProvider().createRelation(
+      data.sparkSession.sqlContext, mode, options.toMap, data)
   }
 
   override def toString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 71eaab119d75dc3eae70099d9859aee7e12cc44d..ca61c2efe2ddfcebabb1fac7656298b71532ba29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -788,7 +788,7 @@ object JdbcUtils extends Logging {
       case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
       case _ => df
     }
-    repartitionedDF.foreachPartition(iterator => savePartition(
+    repartitionedDF.rdd.foreachPartition(iterator => savePartition(
       getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
     )
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index 6885d0bf67ccb8de11446b234288441cc1712dd4..96225ecffad48e8ff63f18e3e61104b7e84386c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -122,7 +122,7 @@ class FileStreamSink(
 
       FileFormatWriter.write(
         sparkSession = sparkSession,
-        queryExecution = data.queryExecution,
+        plan = data.queryExecution.executedPlan,
         fileFormat = fileFormat,
         committer = committer,
         outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index b6ddf7437ea13d9a28e7a7067c64edeb384cdbb7..ab8608563c4fba9d3bf2a8b702b1f876f64c93ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.StreamingExplainCommand
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
@@ -655,7 +655,9 @@ class StreamExecution(
       new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
 
     reportTimeTaken("addBatch") {
-      sink.addBatch(currentBatchId, nextBatch)
+      SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
+        sink.addBatch(currentBatchId, nextBatch)
+      }
     }
 
     awaitBatchLock.lock()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index e8b9712d19cd559745dd98ce7174642ef0328bb8..38c63191106d037e1719f88fabbfb67054641bca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -46,8 +46,8 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println("-------------------------------------------")
     // scalastyle:off println
     data.sparkSession.createDataFrame(
-      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
-      .show(numRowsToShow, isTruncated)
+      data.sparkSession.sparkContext.parallelize(data.collectInternal()), data.schema)
+      .showInternal(numRowsToShow, isTruncated)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 971ce5afb177805c46452bd472085037b0d7b042..7eaa803a9ecb4f006879bed22e91cb73fe4050cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -196,11 +196,11 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
       logDebug(s"Committing batch $batchId to $this")
       outputMode match {
         case Append | Update =>
-          val rows = AddedData(batchId, data.collect())
+          val rows = AddedData(batchId, data.collectInternal())
           synchronized { batches += rows }
 
         case Complete =>
-          val rows = AddedData(batchId, data.collect())
+          val rows = AddedData(batchId, data.collectInternal())
           synchronized {
             batches.clear()
            batches += rows
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index e544245588f461231cb2027936c72a84764b518a..a4e62f1d16792686c4209006f1a9c3aea9c14571 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -290,10 +290,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
 
   test("save metrics") {
     withTempPath { file =>
+      // person creates a temporary view. get the DF before listing previous execution IDs
+      val data = person.select('name)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
       val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
       // Assume the execution plan is
       // PhysicalRDD(nodeId = 0)
-      person.select('name).write.format("json").save(file.getAbsolutePath)
+      data.write.format("json").save(file.getAbsolutePath)
       sparkContext.listenerBus.waitUntilEmpty(10000)
       val executionIds =
         spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 7c9ea7d3936305670d0a4c518da81d7b3a641f81..a239e39d9c5a39acd4e6d71d7d182276faf87eb7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
 import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
-import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
+import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
+import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -178,26 +179,28 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       spark.range(10).write.format("json").save(path.getCanonicalPath)
       assert(commands.length == 1)
       assert(commands.head._1 == "save")
-      assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
-      assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
+      assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand])
+      assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand]
+        .fileFormat.isInstanceOf[JsonFileFormat])
     }
 
     withTable("tab") {
-      sql("CREATE TABLE tab(i long) using parquet")
+      sql("CREATE TABLE tab(i long) using parquet") // adds commands(1) via onSuccess
       spark.range(10).write.insertInto("tab")
-      assert(commands.length == 2)
-      assert(commands(1)._1 == "insertInto")
-      assert(commands(1)._2.isInstanceOf[InsertIntoTable])
-      assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
+      assert(commands.length == 3)
+      assert(commands(2)._1 == "insertInto")
+      assert(commands(2)._2.isInstanceOf[InsertIntoTable])
+      assert(commands(2)._2.asInstanceOf[InsertIntoTable].table
         .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
     }
 
+    // exiting withTable adds commands(3) via onSuccess (drops tab)
     withTable("tab") {
       spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
-      assert(commands.length == 3)
-      assert(commands(2)._1 == "saveAsTable")
-      assert(commands(2)._2.isInstanceOf[CreateTable])
-      assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
+      assert(commands.length == 5)
+      assert(commands(4)._1 == "saveAsTable")
+      assert(commands(4)._2.isInstanceOf[CreateTable])
+      assert(commands(4)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
     }
 
     withTable("tab") {
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
index 0d5dc7af5f5222577bb5eb44e5fa1c1bcf0ec694..6775902173444794dd7ffc89c5acc0b67028317a 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SQLContext}
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 
 
 private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
@@ -60,7 +60,9 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont
     try {
       context.sparkContext.setJobDescription(command)
       val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
-      hiveResponse = execution.hiveResultString()
+      hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) {
+        execution.hiveResultString()
+      }
       tableSchema = getResultSetSchema(execution)
       new CommandProcessorResponse(0)
     } catch {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 10ce8e3730a0d29cc1b3bec4fb0f127af213b32b..392b7cfaa8eff113c1bc3cb50f640f48268cf441 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -32,10 +32,11 @@ import org.apache.hadoop.hive.ql.ErrorMsg
 import org.apache.hadoop.hive.ql.plan.TableDesc
 
 import org.apache.spark.internal.io.FileCommitProtocol
-import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
 import org.apache.spark.sql.execution.datasources.FileFormatWriter
 import org.apache.spark.sql.hive._
@@ -81,7 +82,7 @@ case class InsertIntoHiveTable(
     overwrite: Boolean,
     ifPartitionNotExists: Boolean) extends RunnableCommand {
 
-  override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
+  override def children: Seq[LogicalPlan] = query :: Nil
 
   var createdTempDir: Option[Path] = None
 
@@ -230,7 +231,9 @@ case class InsertIntoHiveTable(
    * `org.apache.hadoop.hive.serde2.SerDe` and the
   * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 10ce8e3730a0d29cc1b3bec4fb0f127af213b32b..392b7cfaa8eff113c1bc3cb50f640f48268cf441 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -32,10 +32,11 @@ import org.apache.hadoop.hive.ql.ErrorMsg
 import org.apache.hadoop.hive.ql.plan.TableDesc
 
 import org.apache.spark.internal.io.FileCommitProtocol
-import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
 import org.apache.spark.sql.execution.datasources.FileFormatWriter
 import org.apache.spark.sql.hive._
@@ -81,7 +82,7 @@ case class InsertIntoHiveTable(
     overwrite: Boolean,
     ifPartitionNotExists: Boolean) extends RunnableCommand {
 
-  override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
+  override def children: Seq[LogicalPlan] = query :: Nil
 
   var createdTempDir: Option[Path] = None
 
@@ -230,7 +231,9 @@ case class InsertIntoHiveTable(
    * `org.apache.hadoop.hive.serde2.SerDe` and the
    * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
    */
-  override def run(sparkSession: SparkSession): Seq[Row] = {
+  override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
+    assert(children.length == 1)
+
     val sessionState = sparkSession.sessionState
     val externalCatalog = sparkSession.sharedState.externalCatalog
     val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version
@@ -344,7 +347,7 @@ case class InsertIntoHiveTable(
 
     FileFormatWriter.write(
       sparkSession = sparkSession,
-      queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
+      plan = children.head,
       fileFormat = new HiveFileFormat(fileSinkConf),
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
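With run() now receiving the planned physical children, the Hive insert writes children.head instead of re-planning its query through Dataset.ofRows, whose fresh QueryExecution was detached from the execution being tracked in the UI. A rough sketch of the resulting contract, with an illustrative trait name that is not an actual Spark trait:

  import org.apache.spark.sql.{Row, SparkSession}
  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  import org.apache.spark.sql.execution.SparkPlan

  // Sketch only: the command exposes its query as a logical child; the planner converts that
  // child to a SparkPlan and hands it back, so run() never plans anything on its own.
  trait WritesPlannedChildren {
    def query: LogicalPlan
    def children: Seq[LogicalPlan] = query :: Nil
    def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row]
  }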
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index e1534c797d55bcc1e59112922081c3c5d76c4d9b..4e1792321c89bcd83ef046d9ac7deff939b8dd53 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -34,8 +34,8 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.CacheTableCommand
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.client.HiveClient
@@ -294,23 +294,23 @@ private[hive] class TestHiveSparkSession(
       "CREATE TABLE src1 (key INT, value STRING)".cmd,
       s"LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
     TestTable("srcpart", () => {
-      sql(
-        "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
+      "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)"
+        .cmd.apply()
       for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
-        sql(
-          s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
-             |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
-           """.stripMargin)
+        s"""
+           |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
+           |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
+         """.stripMargin.cmd.apply()
       }
     }),
     TestTable("srcpart1", () => {
-      sql(
-        "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
+      "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)"
+        .cmd.apply()
       for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
-        sql(
-          s"""LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
-             |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
-           """.stripMargin)
+        s"""
+           |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/kv1.txt")}'
+           |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
+         """.stripMargin.cmd.apply()
      }
     }),
     TestTable("src_thrift", () => {
@@ -318,8 +318,7 @@ private[hive] class TestHiveSparkSession(
       import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat}
       import org.apache.thrift.protocol.TBinaryProtocol
 
-      sql(
-        s"""
+      s"""
          |CREATE TABLE src_thrift(fake INT)
          |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}'
          |WITH SERDEPROPERTIES(
@@ -329,13 +328,12 @@ private[hive] class TestHiveSparkSession(
          |STORED AS
          |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}'
          |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}'
-        """.stripMargin)
+        """.stripMargin.cmd.apply()
 
-      sql(
-        s"""
-           |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}'
-           |INTO TABLE src_thrift
-         """.stripMargin)
+      s"""
+         |LOAD DATA LOCAL INPATH '${quoteHiveFile("data/files/complex.seq")}'
+         |INTO TABLE src_thrift
+       """.stripMargin.cmd.apply()
     }),
     TestTable("serdeins",
       s"""CREATE TABLE serdeins (key INT, value STRING)
@@ -458,7 +456,17 @@ private[hive] class TestHiveSparkSession(
     logDebug(s"Loading test table $name")
     val createCmds =
       testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name"))
-    createCmds.foreach(_())
+
+    // test tables are loaded lazily, so they may be loaded in the middle of a query execution
+    // which has already set the execution id.
+    if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) {
+      // We don't actually have a `QueryExecution` here, use a fake one instead.
+      SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation)) {
+        createCmds.foreach(_())
+      }
+    } else {
+      createCmds.foreach(_())
+    }
 
     if (cacheTables) {
       new SQLContext(self).cacheTable(name)
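The lazy-loading guard above can be read as a small reusable helper. The sketch below is ours rather than part of the patch; it relies on the same Spark 2.x internals (SQLExecution, OneRowRelation as a case object) and, like TestHive itself, only compiles inside Spark's own org.apache.spark.sql packages:

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
  import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}

  // Run `body` under a fresh execution id unless the calling thread already has one set,
  // in which case run it inline to avoid nesting.
  def withExecutionIdIfAbsent[T](spark: SparkSession)(body: => T): T = {
    if (spark.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) {
      // No real QueryExecution is available, so use a trivial one, as the patch does.
      SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, OneRowRelation))(body)
    } else {
      body
    }
  }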
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 98aa92a9bb88fd75c5aa615179f952c3f9e90bfc..cee82cda4628a9ff8a408ef1a0260fefae7ae368 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveQueryExecution}
 
@@ -341,7 +342,10 @@ abstract class HiveComparisonTest
         // Run w/ catalyst
         val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
           val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath))
-          try { (query, prepareAnswer(query, query.hiveResultString())) } catch {
+          def getResult(): Seq[String] = {
+            SQLExecution.withNewExecutionId(query.sparkSession, query)(query.hiveResultString())
+          }
+          try { (query, prepareAnswer(query, getResult())) } catch {
             case e: Throwable =>
               val errorMessage =
                 s"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index c944f28d10ef408a5fba060a92cac6e51515183f..da7a0645dbbebb25c43404d7072698050ac6d37a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -965,14 +965,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   }
 
   test("sanity test for SPARK-6618") {
-    (1 to 100).par.map { i =>
-      val tableName = s"SPARK_6618_table_$i"
-      sql(s"CREATE TABLE $tableName (col1 string)")
-      sessionState.catalog.lookupRelation(TableIdentifier(tableName))
-      table(tableName)
-      tables()
-      sql(s"DROP TABLE $tableName")
+    val threads: Seq[Thread] = (1 to 10).map { i =>
+      new Thread("test-thread-" + i) {
+        override def run(): Unit = {
+          val tableName = s"SPARK_6618_table_$i"
+          sql(s"CREATE TABLE $tableName (col1 string)")
+          sessionState.catalog.lookupRelation(TableIdentifier(tableName))
+          table(tableName)
+          tables()
+          sql(s"DROP TABLE $tableName")
+        }
+      }
     }
+    threads.foreach(_.start())
+    threads.foreach(_.join(10000))
   }
 
   test("SPARK-5203 union with different decimal precision") {
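The rewritten SPARK-6618 sanity test reduces the concurrency from 100 parallel-collection tasks to 10 explicit, named threads and joins each with a timeout, so a stuck thread cannot hang the suite indefinitely. The same pattern, pulled out as a generic helper (the helper is ours, not part of the patch):

  // Spawn `n` named threads, start them together, and cap each join at 10 seconds,
  // mirroring the thread naming and join timeout used in the test above.
  def runConcurrently(n: Int)(body: Int => Unit): Unit = {
    val threads = (1 to n).map { i =>
      new Thread(s"test-thread-$i") {
        override def run(): Unit = body(i)
      }
    }
    threads.foreach(_.start())
    threads.foreach(_.join(10000))
  }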