From 9f6b3e65ccfa0daec31b58c5a6386b3a890c2149 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Thu, 29 Jun 2017 14:37:42 +0800
Subject: [PATCH] [SPARK-21238][SQL] allow nested SQL execution

## What changes were proposed in this pull request?

This is kind of another follow-up for https://github.com/apache/spark/pull/18064 .

In #18064 , we wrap every SQL command with SQL execution, which makes nested SQL execution very likely to happen. #18419 trid to improve it a little bit, by introduing `SQLExecition.ignoreNestedExecutionId`. However, this is not friendly to data source developers, they may need to update their code to use this `ignoreNestedExecutionId` API.

This PR proposes a new solution, to just allow nested execution. The downside is that, we may have multiple executions for one query. We can improve this by updating the data organization in SQLListener, to have 1-n mapping from query to execution, instead of 1-1 mapping. This can be done in a follow-up.

## How was this patch tested?

existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18450 from cloud-fan/execution-id.
---
 .../spark/sql/execution/SQLExecution.scala    | 88 ++++---------------
 .../command/AnalyzeTableCommand.scala         |  4 +-
 .../spark/sql/execution/command/cache.scala   | 16 ++--
 .../datasources/csv/CSVDataSource.scala       |  4 +-
 .../datasources/jdbc/JDBCRelation.scala       |  8 +-
 .../sql/execution/streaming/console.scala     | 12 +--
 .../sql/execution/streaming/memory.scala      | 32 ++++---
 .../sql/execution/SQLExecutionSuite.scala     | 24 -----
 8 files changed, 50 insertions(+), 138 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index ca8bed5214..e991da7df0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -22,15 +22,12 @@ import java.util.concurrent.atomic.AtomicLong
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd,
-  SparkListenerSQLExecutionStart}
+import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 
 object SQLExecution {
 
   val EXECUTION_ID_KEY = "spark.sql.execution.id"
 
-  private val IGNORE_NESTED_EXECUTION_ID = "spark.sql.execution.ignoreNestedExecutionId"
-
   private val _nextExecutionId = new AtomicLong(0)
 
   private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
@@ -45,10 +42,8 @@ object SQLExecution {
 
   private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = {
     val sc = sparkSession.sparkContext
-    val isNestedExecution = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null
-    val hasExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) != null
     // only throw an exception during tests. a missing execution ID should not fail a job.
-    if (testing && !isNestedExecution && !hasExecutionId) {
+    if (testing && sc.getLocalProperty(EXECUTION_ID_KEY) == null) {
       // Attention testers: when a test fails with this exception, it means that the action that
       // started execution of a query didn't call withNewExecutionId. The execution ID should be
       // set by calling withNewExecutionId in the action that begins execution, like
@@ -66,56 +61,27 @@ object SQLExecution {
       queryExecution: QueryExecution)(body: => T): T = {
     val sc = sparkSession.sparkContext
     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
-    if (oldExecutionId == null) {
-      val executionId = SQLExecution.nextExecutionId
-      sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
-      executionIdToQueryExecution.put(executionId, queryExecution)
-      try {
-        // sparkContext.getCallSite() would first try to pick up any call site that was previously
-        // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
-        // streaming queries would give us call site like "run at <unknown>:0"
-        val callSite = sparkSession.sparkContext.getCallSite()
-
-        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
-          executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
-          SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
-        try {
-          body
-        } finally {
-          sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
-            executionId, System.currentTimeMillis()))
-        }
-      } finally {
-        executionIdToQueryExecution.remove(executionId)
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
-      }
-    } else if (sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID) != null) {
-      // If `IGNORE_NESTED_EXECUTION_ID` is set, just ignore the execution id while evaluating the
-      // `body`, so that Spark jobs issued in the `body` won't be tracked.
+    val executionId = SQLExecution.nextExecutionId
+    sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
+    executionIdToQueryExecution.put(executionId, queryExecution)
+    try {
+      // sparkContext.getCallSite() would first try to pick up any call site that was previously
+      // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
+      // streaming queries would give us call site like "run at <unknown>:0"
+      val callSite = sparkSession.sparkContext.getCallSite()
+
+      sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
+        executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+        SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
       try {
-        sc.setLocalProperty(EXECUTION_ID_KEY, null)
         body
       } finally {
-        sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
+        sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
+          executionId, System.currentTimeMillis()))
       }
-    } else {
-      // Don't support nested `withNewExecutionId`. This is an example of the nested
-      // `withNewExecutionId`:
-      //
-      // class DataFrame {
-      //   def foo: T = withNewExecutionId { something.createNewDataFrame().collect() }
-      // }
-      //
-      // Note: `collect` will call withNewExecutionId
-      // In this case, only the "executedPlan" for "collect" will be executed. The "executedPlan"
-      // for the outer DataFrame won't be executed. So it's meaningless to create a new Execution
-      // for the outer DataFrame. Even if we track it, since its "executedPlan" doesn't run,
-      // all accumulator metrics will be 0. It will confuse people if we show them in Web UI.
-      //
-      // A real case is the `DataFrame.count` method.
-      throw new IllegalArgumentException(s"$EXECUTION_ID_KEY is already set, please wrap your " +
-        "action with SQLExecution.ignoreNestedExecutionId if you don't want to track the Spark " +
-        "jobs issued by the nested execution.")
+    } finally {
+      executionIdToQueryExecution.remove(executionId)
+      sc.setLocalProperty(EXECUTION_ID_KEY, oldExecutionId)
     }
   }
 
@@ -133,20 +99,4 @@ object SQLExecution {
       sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
     }
   }
-
-  /**
-   * Wrap an action which may have nested execution id. This method can be used to run an execution
-   * inside another execution, e.g., `CacheTableCommand` need to call `Dataset.collect`. Note that,
-   * all Spark jobs issued in the body won't be tracked in UI.
-   */
-  def ignoreNestedExecutionId[T](sparkSession: SparkSession)(body: => T): T = {
-    val sc = sparkSession.sparkContext
-    val allowNestedPreviousValue = sc.getLocalProperty(IGNORE_NESTED_EXECUTION_ID)
-    try {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, "true")
-      body
-    } finally {
-      sc.setLocalProperty(IGNORE_NESTED_EXECUTION_ID, allowNestedPreviousValue)
-    }
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index d780ef42f3..42e2a9ca5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -51,9 +51,7 @@ case class AnalyzeTableCommand(
     // 2. when total size is changed, `oldRowCount` becomes invalid.
     // This is to make sure that we only record the right statistics.
     if (!noscan) {
-      val newRowCount = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-        sparkSession.table(tableIdentWithDB).count()
-      }
+      val newRowCount = sparkSession.table(tableIdentWithDB).count()
       if (newRowCount >= 0 && newRowCount != oldRowCount) {
         newStats = if (newStats.isDefined) {
           newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index d36eb7587a..47952f2f22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -34,16 +34,14 @@ case class CacheTableCommand(
   override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
-    SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      plan.foreach { logicalPlan =>
-        Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
-      }
-      sparkSession.catalog.cacheTable(tableIdent.quotedString)
+    plan.foreach { logicalPlan =>
+      Dataset.ofRows(sparkSession, logicalPlan).createTempView(tableIdent.quotedString)
+    }
+    sparkSession.catalog.cacheTable(tableIdent.quotedString)
 
-      if (!isLazy) {
-        // Performs eager caching
-        sparkSession.table(tableIdent).count()
-      }
+    if (!isLazy) {
+      // Performs eager caching
+      sparkSession.table(tableIdent).count()
     }
 
     Seq.empty[Row]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 99133bd709..2031381dd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -145,9 +145,7 @@ object TextInputCSVDataSource extends CSVDataSource {
       inputPaths: Seq[FileStatus],
       parsedOptions: CSVOptions): StructType = {
     val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions)
-    val maybeFirstLine = SQLExecution.ignoreNestedExecutionId(sparkSession) {
-      CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
-    }
+    val maybeFirstLine = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption
     inferFromDataset(sparkSession, csv, maybeFirstLine, parsedOptions)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index b11da7045d..a521fd1323 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -130,11 +130,9 @@ private[sql] case class JDBCRelation(
   }
 
   override def insert(data: DataFrame, overwrite: Boolean): Unit = {
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.write
-        .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
-        .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
-    }
+    data.write
+      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
+      .jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
   }
 
   override def toString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 6fa7c113de..3baea63760 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -48,11 +48,9 @@ class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
     println(batchIdStr)
     println("-------------------------------------------")
     // scalastyle:off println
-    SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-      data.sparkSession.createDataFrame(
-        data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
-        .show(numRowsToShow, isTruncated)
-    }
+    data.sparkSession.createDataFrame(
+      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
+      .show(numRowsToShow, isTruncated)
   }
 }
 
@@ -82,9 +80,7 @@ class ConsoleSinkProvider extends StreamSinkProvider
 
     // Truncate the displayed data if it is too long, by default it is true
     val isTruncated = parameters.get("truncate").map(_.toBoolean).getOrElse(true)
-    SQLExecution.ignoreNestedExecutionId(sqlContext.sparkSession) {
-      data.show(numRowsToShow, isTruncated)
-    }
+    data.show(numRowsToShow, isTruncated)
 
     ConsoleRelation(sqlContext, data)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 198a342582..4979873ee3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -194,23 +194,21 @@ class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink wi
     }
     if (notCommitted) {
       logDebug(s"Committing batch $batchId to $this")
-      SQLExecution.ignoreNestedExecutionId(data.sparkSession) {
-        outputMode match {
-          case Append | Update =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized { batches += rows }
-
-          case Complete =>
-            val rows = AddedData(batchId, data.collect())
-            synchronized {
-              batches.clear()
-              batches += rows
-            }
-
-          case _ =>
-            throw new IllegalArgumentException(
-              s"Output mode $outputMode is not supported by MemorySink")
-        }
+      outputMode match {
+        case Append | Update =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized { batches += rows }
+
+        case Complete =>
+          val rows = AddedData(batchId, data.collect())
+          synchronized {
+            batches.clear()
+            batches += rows
+          }
+
+        case _ =>
+          throw new IllegalArgumentException(
+            s"Output mode $outputMode is not supported by MemorySink")
       }
     } else {
       logDebug(s"Skipping already committed batch: $batchId")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
index fe78a76568..f6b006b98e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala
@@ -26,22 +26,9 @@ import org.apache.spark.sql.SparkSession
 class SQLExecutionSuite extends SparkFunSuite {
 
   test("concurrent query execution (SPARK-10548)") {
-    // Try to reproduce the issue with the old SparkContext
     val conf = new SparkConf()
       .setMaster("local[*]")
       .setAppName("test")
-    val badSparkContext = new BadSparkContext(conf)
-    try {
-      testConcurrentQueryExecution(badSparkContext)
-      fail("unable to reproduce SPARK-10548")
-    } catch {
-      case e: IllegalArgumentException =>
-        assert(e.getMessage.contains(SQLExecution.EXECUTION_ID_KEY))
-    } finally {
-      badSparkContext.stop()
-    }
-
-    // Verify that the issue is fixed with the latest SparkContext
     val goodSparkContext = new SparkContext(conf)
     try {
       testConcurrentQueryExecution(goodSparkContext)
@@ -134,17 +121,6 @@ class SQLExecutionSuite extends SparkFunSuite {
   }
 }
 
-/**
- * A bad [[SparkContext]] that does not clone the inheritable thread local properties
- * when passing them to children threads.
- */
-private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) {
-  protected[spark] override val localProperties = new InheritableThreadLocal[Properties] {
-    override protected def childValue(parent: Properties): Properties = new Properties(parent)
-    override protected def initialValue(): Properties = new Properties()
-  }
-}
-
 object SQLExecutionSuite {
   @volatile var canProgress = false
 }
-- 
GitLab