From 94624eacb0fdbbe210894151a956f8150cdf527e Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Wed, 18 Nov 2015 11:53:28 -0800
Subject: [PATCH] [SPARK-11739][SQL] clear the instantiated SQLContext

Currently, if the first SQLContext is not removed after stopping SparkContext, a SQLContext could set there forever. This patch make this more robust.

Author: Davies Liu <davies@databricks.com>

Closes #9706 from davies/clear_context.
---
 .../scala/org/apache/spark/sql/SQLContext.scala | 17 +++++++++++------
 .../spark/sql/MultiSQLContextsSuite.scala       |  5 ++---
 .../execution/ExchangeCoordinatorSuite.scala    |  2 +-
 3 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cd1fdc4edb..39471d2fb7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1229,7 +1229,7 @@ class SQLContext private[sql](
   // construction of the instance.
   sparkContext.addSparkListener(new SparkListener {
     override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = {
-      SQLContext.clearInstantiatedContext(self)
+      SQLContext.clearInstantiatedContext()
     }
   })
 
@@ -1270,13 +1270,13 @@ object SQLContext {
    */
   def getOrCreate(sparkContext: SparkContext): SQLContext = {
     val ctx = activeContext.get()
-    if (ctx != null) {
+    if (ctx != null && !ctx.sparkContext.isStopped) {
       return ctx
     }
 
     synchronized {
       val ctx = instantiatedContext.get()
-      if (ctx == null) {
+      if (ctx == null || ctx.sparkContext.isStopped) {
         new SQLContext(sparkContext)
       } else {
         ctx
@@ -1284,12 +1284,17 @@ object SQLContext {
     }
   }
 
-  private[sql] def clearInstantiatedContext(sqlContext: SQLContext): Unit = {
-    instantiatedContext.compareAndSet(sqlContext, null)
+  private[sql] def clearInstantiatedContext(): Unit = {
+    instantiatedContext.set(null)
   }
 
   private[sql] def setInstantiatedContext(sqlContext: SQLContext): Unit = {
-    instantiatedContext.compareAndSet(null, sqlContext)
+    synchronized {
+      val ctx = instantiatedContext.get()
+      if (ctx == null || ctx.sparkContext.isStopped) {
+        instantiatedContext.set(sqlContext)
+      }
+    }
   }
 
   private[sql] def getInstantiatedContextOption(): Option[SQLContext] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
index 0e8fcb6a85..34c5c68fd1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala
@@ -31,7 +31,7 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
     originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
 
     SQLContext.clearActive()
-    originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx))
+    SQLContext.clearInstantiatedContext()
     sparkConf =
       new SparkConf(false)
         .setMaster("local[*]")
@@ -89,10 +89,9 @@ class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll {
         testNewSession(rootSQLContext)
         testNewSession(rootSQLContext)
         testCreatingNewSQLContext(allowMultipleSQLContexts)
-
-        SQLContext.clearInstantiatedContext(rootSQLContext)
       } finally {
         sc.stop()
+        SQLContext.clearInstantiatedContext()
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 25f2f5caee..b96d50a70b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -34,7 +34,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
     originalInstantiatedSQLContext = SQLContext.getInstantiatedContextOption()
 
     SQLContext.clearActive()
-    originalInstantiatedSQLContext.foreach(ctx => SQLContext.clearInstantiatedContext(ctx))
+    SQLContext.clearInstantiatedContext()
   }
 
   override protected def afterAll(): Unit = {
-- 
GitLab