From 44de108d743ddbec11905f0fc86fb3fccdbac90e Mon Sep 17 00:00:00 2001
From: jinxing <jinxing6042@126.com>
Date: Tue, 6 Jun 2017 11:14:39 +0100
Subject: [PATCH] [SPARK-20985] Stop SparkContext using
 LocalSparkContext.withSpark

## What changes were proposed in this pull request?
SparkContext should always be stopped after using, thus other tests won't complain that there's only one `SparkContext` can exist.

Author: jinxing <jinxing6042@126.com>

Closes #18204 from jinxing64/SPARK-20985.
---
 .../org/apache/spark/MapOutputTrackerSuite.scala   |  7 ++-----
 .../KryoSerializerResizableOutputSuite.scala       | 14 +++++++-------
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 71bedda5ac..4fe5c5e4fe 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -23,6 +23,7 @@ import org.mockito.Matchers.any
 import org.mockito.Mockito._
 
 import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.LocalSparkContext._
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
 import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
 import org.apache.spark.shuffle.FetchFailedException
@@ -245,8 +246,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     newConf.set("spark.shuffle.mapOutput.minSizeForBroadcast", "10240") // 10 KB << 1MB framesize
 
     // needs TorrentBroadcast so need a SparkContext
-    val sc = new SparkContext("local", "MapOutputTrackerSuite", newConf)
-    try {
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
       val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
       val rpcEnv = sc.env.rpcEnv
       val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
@@ -271,9 +271,6 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       assert(1 == masterTracker.getNumCachedSerializedBroadcast)
       masterTracker.unregisterShuffle(20)
       assert(0 == masterTracker.getNumCachedSerializedBroadcast)
-
-    } finally {
-      LocalSparkContext.stop(sc)
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
index 21251f0b93..cf01f79f49 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.serializer
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.LocalSparkContext
+import org.apache.spark.LocalSparkContext._
 import org.apache.spark.SparkContext
 import org.apache.spark.SparkException
 
@@ -32,9 +32,9 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite {
     conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
     conf.set("spark.kryoserializer.buffer", "1m")
     conf.set("spark.kryoserializer.buffer.max", "1m")
-    val sc = new SparkContext("local", "test", conf)
-    intercept[SparkException](sc.parallelize(x).collect())
-    LocalSparkContext.stop(sc)
+    withSpark(new SparkContext("local", "test", conf)) { sc =>
+      intercept[SparkException](sc.parallelize(x).collect())
+    }
   }
 
   test("kryo with resizable output buffer should succeed on large array") {
@@ -42,8 +42,8 @@ class KryoSerializerResizableOutputSuite extends SparkFunSuite {
     conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
     conf.set("spark.kryoserializer.buffer", "1m")
     conf.set("spark.kryoserializer.buffer.max", "2m")
-    val sc = new SparkContext("local", "test", conf)
-    assert(sc.parallelize(x).collect() === x)
-    LocalSparkContext.stop(sc)
+    withSpark(new SparkContext("local", "test", conf)) { sc =>
+      assert(sc.parallelize(x).collect() === x)
+    }
   }
 }
-- 
GitLab