diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index ca8bed5214f87e44ac54fcc8f5e03801baef364b..e991da7df0bde219b00570210971aa60aefd37f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
-  SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 
 object SQLExecution {
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
 
-  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
-
   private val _nextExecutionId = new AtomicLong(0)
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -45,10 +42,8 @@ object SQLExecution {
 
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
     val sc = sparkSession.sparkContext
-    val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
-    val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
     // only throw an exception during tests. a missing execution ID should not fail a job.
-    if (testing && !isNestedExecution && !hasExecutionId) {
+    if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) {
       // Attention testers: when a test fails with this exception, it means that the action that
       // started execution of a query didn't call withNewExecutionId. The execution ID should be
       // set by calling withNewExecutionId in the action that begins execution, like
@@ -66,56 +61,27 @@ object SQLExecution {
       queryExecution: QueryExecution)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
-    if (oldExecutionId == null) {
-      val executionId = SQLExecution.nextExecutionId
-      sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
-      executionIdToQueryExecution.put(executionId, queryExecution)
-      try {
-        // sparkContext.getCallSite() would first try to pick up any call site that was previously
-        // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
-        // streaming queries would give us call site like "run at <unknown>:0"
-        val callSite = sparkSession.sparkContext.getCallSite()
-
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-        try {
-          body
-        } finally {
-          sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
-        }
-      } finally {
-        executionIdToQueryExecution.remove(executionId)
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
-      }
-    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
-      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
-      // `body`, so that Spark jobs issued in the `body` won't be tracked.
+    val executionId = SQLExecution.nextExecutionId
+    sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
+    executionIdToQueryExecution.put(executionId, queryExecution)
+    try {
+      // sparkContext.getCallSite() would first try to pick up any call site that was previously
+      // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
+      // streaming queries would give us call site like "run at <unknown>:0"
+      val callSite = sparkSession.sparkContext.getCallSite()
+
+      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
       try {
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
         body
       } finally {
-        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+          executionId, System.currentTimeMillis()))
       }
-    } else {
-      // Don't support nested `withNewExecutionId`. This is an example of the nested
-      // `withNewExecutionId`:
-      //
-      // class DataFrame {
-      //   def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
-      // }
-      //
-      // Note: `collect` will call withNewExecutionId
-      // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
-      // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
-      // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
-      // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
-      //
-      // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
-        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
-        "jobs issued by the nested execution.")
+    } finally {
+      executionIdToQueryExecution.remove(executionId)
+      sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
     }
   }
 
@@ -133,20 +99,4 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
-
-  /**
-   * Wrap an action which may have nested execution id. This method can be used to run an execution
-   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
-   * all Spark jobs issued in the body won't be tracked in UI.
-   */
-  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
-    val sc = sparkSession.sparkContext
-    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
-    try {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
-      body
-    } finally {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
-    }
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index d780ef42f3fae0c1347825a386611ae8bfffd1f5..42e2a9ca5c4e281efda3be37df19937adb026414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -51,9 +51,7 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-        sparkSession.table(tableIdentWithDB).count()
-      }
+      val newRowCount = sparkSession.table(tableIdentWithDB).count()
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index d36eb7587a3ef783442c5dd191d6c7336e2447fe..47952f2f227a3b27c68134dd48be277875c12e34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -34,16 +34,14 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      plan.foreach { logicalPlan =>
-        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
-      }
-      sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    plan.foreach { logicalPlan =>
+      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+    }
+    sparkSession.catalog.cacheTable(tableIdent.quotedString)
 
-      if (!isLazy) {
-        // Performs eager caching
-        sparkSession.table(tableIdent).count()
-      }
+    if (!isLazy) {
+      // Performs eager caching
+      sparkSession.table(tableIdent).count()
     }
 
     Seq.empty[Row]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 99133bd70989a77d1849958a1433ed3bbb1fa252..2031381dd2e100d955e4ffe35754675e5889899d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index b11da7045de22eec586d1ef15815c5293996bbd4..a521fd132385218b0d7f4a60183b86a0c028b73e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -130,11 +130,9 @@ private[sql] case class JDBCRelation(
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.write
-        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
-    }
+    data.write
+      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+      .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
   }
 
   override def toString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 6fa7c113defaa86480a0ab6a0c0e9763d61d7cd9..3baea6376069fb2a9d064959cf2838bd8c09fa5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.sparkSession.createDataFrame(
-        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
-        .show(numRowsToShow, isTruncated)
-    }
+    data.sparkSession.createDataFrame(
+      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+      .show(numRowsToShow, isTruncated)
   }
 }
 
@@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider
     // Truncate the displayed data if it is too long, by default it is true
    val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
 
-    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
-      data.show(numRowsToShow, isTruncated)
-    }
+    data.show(numRowsToShow, isTruncated)
     ConsoleRelation(sqlContext, data)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 198a342582804d176eb33caa645592abe8c98cee..4979873ee3c7feed7f8e4e004b4cd9340c230bc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-        outputMode match {
-          case Append | Update =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized { batches += rows }
-
-          case Complete =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized {
-              batches.clear()
-              batches += rows
-            }
-
-          case _ =>
-            throw new IllegalArgumentException(
-              s"Output mode $outputMode is not supported by MemorySink")
-        }
+      outputMode match {
+        case Append | Update =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized { batches += rows }
+
+        case Complete =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized {
+            batches.clear()
+            batches += rows
+          }
+
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Output mode $outputMode is not supported by MemorySink")
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index fe78a76568837fdcd309f69ca072e746dc9791eb..f6b006b98edd15dcc10e446500f60b23444ae1d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession
 class SQLExecutionSuite extends SparkFunSuite {
 
   test("concurrent query execution (SPARK-10548)") {
-    // Try to reproduce the issue with the old SparkContext
     val conf = new SparkConf()
       .setMaster("local[*]")
       .setAppName("test")
-    val badSparkContext = new BadSparkContext(conf)
-    try {
-      testConcurrentQueryExecution(badSparkContext)
-      fail("unable to reproduce SPARK-10548")
-    } catch {
-      case e: IllegalArgumentException =>
-        assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
-    } finally {
-      badSparkContext.stop()
-    }
-
-    // Verify that the issue is fixed with the latest SparkContext
     val goodSparkContext = new SparkContext(conf)
     try {
       testConcurrentQueryExecution(goodSparkContext)
@@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite {
     }
   }
 
-/**
- * A bad [[SparkContext]] that does not clone the inheritable thread local properties
- * when passing them to children threads.
- */
-private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
-  protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
-    override protected def childValue(parent: Properties): Properties = new Properties(parent)
-    override protected def initialValue(): Properties = new Properties()
-  }
-}
-
 object SQLExecutionSuite {
   @volatile var canProgress = false
 }
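
For reference, the key behavioral change in this patch is in SQLExecution.withNewExecutionId: instead of either failing on a nested call or silently untracking it via IGNORE_NESTED_EXECUTION_ID, it now always allocates a fresh execution id and restores the previous one in a finally block, which is why the ignoreNestedExecutionId call sites above can simply be dropped. The standalone Scala sketch below illustrates that save/restore pattern under simplified assumptions; NestedExecutionIdSketch, localProperty, and the println calls are illustrative stand-ins for SparkContext local properties and the listener-bus events, not Spark APIs.

import java.util.concurrent.atomic.AtomicLong

// Standalone sketch, not Spark code: `localProperty` stands in for
// SparkContext.getLocalProperty/setLocalProperty, and println stands in for the
// SparkListenerSQLExecutionStart/End events posted to the listener bus.
object NestedExecutionIdSketch {
  private val nextId = new AtomicLong(0)
  // Holds the current execution id for this thread, or null if none is set.
  private val localProperty = new ThreadLocal[String]

  def withNewExecutionId[T](body: => T): T = {
    val oldId = localProperty.get()              // id of an enclosing execution, or null
    val id = nextId.getAndIncrement().toString   // always allocate a fresh id, even when nested
    localProperty.set(id)
    try {
      println(s"start execution $id (enclosing: $oldId)")
      body
    } finally {
      println(s"end execution $id")
      localProperty.set(oldId)                   // restore the enclosing id instead of clearing it
    }
  }

  def main(args: Array[String]): Unit = {
    // A nested call no longer throws; each level gets its own id and the outer id survives.
    withNewExecutionId {
      withNewExecutionId {
        // e.g. the eager count() that CacheTableCommand issues
      }
    }
  }
}

Running main prints a start/end pair for each nesting level and shows the outer id becoming current again once the inner block finishes, mirroring how nested query executions are now reported to the listener bus independently rather than being suppressed.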