From 0dbd6963d589a8f6ad344273f3da7df680ada515 Mon Sep 17 00:00:00 2001
From: zsxwing <zsxwing@gmail.com>
Date: Thu, 30 Jul 2015 15:39:46 -0700
Subject: [PATCH] [SPARK-9479] [STREAMING] [TESTS] Fix ReceiverTrackerSuite
 failure for maven build and other potential test failures in Streaming

See https://issues.apache.org/jira/browse/SPARK-9479 for the failure cause.

The PR includes the following changes:
1. Make ReceiverTrackerSuite create StreamingContext in the test body.
2. Fix places that don't stop StreamingContext. I verified no SparkContext was stopped in the shutdown hook locally after this fix.
3. Fix an issue that `ReceiverTracker.endpoint` may be null.
4. Make sure stopping SparkContext in non-main thread won't fail other tests.

Author: zsxwing <zsxwing@gmail.com>

Closes #7797 from zsxwing/fix-ReceiverTrackerSuite and squashes the following commits:

3a4bb98 [zsxwing] Fix another potential NPE
d7497df [zsxwing] Fix ReceiverTrackerSuite; make sure StreamingContext in tests is closed
---
 .../StreamingLogisticRegressionSuite.scala    | 21 +++++--
 .../clustering/StreamingKMeansSuite.scala     | 17 ++++--
 .../StreamingLinearRegressionSuite.scala      | 21 +++++--
 .../streaming/scheduler/ReceiverTracker.scala | 12 +++-
 .../apache/spark/streaming/JavaAPISuite.java  |  1 +
 .../streaming/BasicOperationsSuite.scala      | 58 ++++++++++---------
 .../spark/streaming/InputStreamsSuite.scala   | 38 ++++++------
 .../spark/streaming/MasterFailureTest.scala   |  8 ++-
 .../streaming/StreamingContextSuite.scala     | 22 +++++--
 .../streaming/StreamingListenerSuite.scala    | 13 ++++-
 .../scheduler/ReceiverTrackerSuite.scala      | 56 +++++++++---------
 .../StreamingJobProgressListenerSuite.scala   | 19 ++++--
 12 files changed, 183 insertions(+), 103 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index fd653296c9..d7b291d5a6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
 
 class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase {
 
   // use longer wait time to ensure job completion
   override def maxWaitTimeMillis: Int = 30000
 
+  var ssc: StreamingContext = _
+
+  override def afterFunction() {
+    super.afterFunction()
+    if (ssc != null) {
+      ssc.stop()
+    }
+  }
+
   // Test if we can accurately learn B for Y = logistic(BX) on streaming data
   test("parameter accuracy") {
 
@@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     }
 
     // apply model training to input stream
-    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
       model.trainOn(inputDStream)
       inputDStream.count()
     })
@@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
 
     // apply model training to input stream, storing the intermediate results
     // (we add a count to ensure the result is a DStream)
-    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
       model.trainOn(inputDStream)
       inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B)))
       inputDStream.count()
@@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     }
 
     // apply model predictions to test stream
-    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
       model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
     })
 
@@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     }
 
     // train and predict
-    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
       model.trainOn(inputDStream)
       model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
     })
@@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
       .setNumIterations(10)
     val numBatches = 10
     val emptyInput = Seq.empty[Seq[LabeledPoint]]
-    val ssc = setupStreams(emptyInput,
+    ssc = setupStreams(emptyInput,
       (inputDStream: DStream[LabeledPoint]) => {
         model.trainOn(inputDStream)
         model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
index ac01622b8a..3645d29dcc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
 
   override def maxWaitTimeMillis: Int = 30000
 
+  var ssc: StreamingContext = _
+
+  override def afterFunction() {
+    super.afterFunction()
+    if (ssc != null) {
+      ssc.stop()
+    }
+  }
+
   test("accuracy for single center and equivalence to grand average") {
     // set parameters
     val numBatches = 10
@@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
     val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
 
     // setup and run the model training
-    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
       model.trainOn(inputDStream)
       inputDStream.count()
     })
@@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
     val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
 
     // setup and run the model training
-    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
       kMeans.trainOn(inputDStream)
       inputDStream.count()
     })
@@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase {
       StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
 
     // setup and run the model training
-    val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
       kMeans.trainOn(inputDStream)
       inputDStream.count()
     })
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index a2a4c5f6b8..34c07ed170 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LinearDataGenerator
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
 import org.apache.spark.streaming.dstream.DStream
-import org.apache.spark.streaming.TestSuiteBase
 
 class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
 
   // use longer wait time to ensure job completion
   override def maxWaitTimeMillis: Int = 20000
 
+  var ssc: StreamingContext = _
+
+  override def afterFunction() {
+    super.afterFunction()
+    if (ssc != null) {
+      ssc.stop()
+    }
+  }
+
   // Assert that two values are equal within tolerance epsilon
   def assertEqual(v1: Double, v2: Double, epsilon: Double) {
     def errorMessage = v1.toString + " did not equal " + v2.toString
@@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     }
 
     // apply model training to input stream
-    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
       model.trainOn(inputDStream)
       inputDStream.count()
     })
@@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
 
     // apply model training to input stream, storing the intermediate results
     // (we add a count to ensure the result is a DStream)
-    val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
       model.trainOn(inputDStream)
       inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0)))
       inputDStream.count()
@@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     }
 
     // apply model predictions to test stream
-    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
       model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
     })
     // collect the output as (true, estimated) tuples
@@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     }
 
     // train and predict
-    val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
+    ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => {
       model.trainOn(inputDStream)
       model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
     })
@@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     val numBatches = 10
     val nPoints = 100
     val emptyInput = Seq.empty[Seq[LabeledPoint]]
-    val ssc = setupStreams(emptyInput,
+    ssc = setupStreams(emptyInput,
       (inputDStream: DStream[LabeledPoint]) => {
         model.trainOn(inputDStream)
         model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 6270137951..e076fb5ea1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -223,7 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     // Signal the receivers to delete old block data
     if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) {
       logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
-      endpoint.send(CleanupOldBlocks(cleanupThreshTime))
+      synchronized {
+        if (isTrackerStarted) {
+          endpoint.send(CleanupOldBlocks(cleanupThreshTime))
+        }
+      }
     }
   }
 
@@ -285,8 +289,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
   }
 
   /** Update a receiver's maximum ingestion rate */
-  def sendRateUpdate(streamUID: Int, newRate: Long): Unit = {
-    endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
+  def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized {
+    if (isTrackerStarted) {
+      endpoint.send(UpdateReceiverRateLimit(streamUID, newRate))
+    }
   }
 
   /** Add new blocks for the given stream */
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index a34f234758..e0718f73aa 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -1735,6 +1735,7 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
   @SuppressWarnings("unchecked")
   @Test
   public void testContextGetOrCreate() throws InterruptedException {
+    ssc.stop();
 
     final SparkConf conf = new SparkConf()
         .setMaster("local[2]")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 08faeaa58f..255376807c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase {
   test("repartition (more partitions)") {
     val input = Seq(1 to 100, 101 to 200, 201 to 300)
     val operation = (r: DStream[Int]) => r.repartition(5)
-    val ssc = setupStreams(input, operation, 2)
-    val output = runStreamsWithPartitions(ssc, 3, 3)
-    assert(output.size === 3)
-    val first = output(0)
-    val second = output(1)
-    val third = output(2)
-
-    assert(first.size === 5)
-    assert(second.size === 5)
-    assert(third.size === 5)
-
-    assert(first.flatten.toSet.equals((1 to 100).toSet) )
-    assert(second.flatten.toSet.equals((101 to 200).toSet))
-    assert(third.flatten.toSet.equals((201 to 300).toSet))
+    withStreamingContext(setupStreams(input, operation, 2)) { ssc =>
+      val output = runStreamsWithPartitions(ssc, 3, 3)
+      assert(output.size === 3)
+      val first = output(0)
+      val second = output(1)
+      val third = output(2)
+
+      assert(first.size === 5)
+      assert(second.size === 5)
+      assert(third.size === 5)
+
+      assert(first.flatten.toSet.equals((1 to 100).toSet))
+      assert(second.flatten.toSet.equals((101 to 200).toSet))
+      assert(third.flatten.toSet.equals((201 to 300).toSet))
+    }
   }
 
   test("repartition (fewer partitions)") {
     val input = Seq(1 to 100, 101 to 200, 201 to 300)
     val operation = (r: DStream[Int]) => r.repartition(2)
-    val ssc = setupStreams(input, operation, 5)
-    val output = runStreamsWithPartitions(ssc, 3, 3)
-    assert(output.size === 3)
-    val first = output(0)
-    val second = output(1)
-    val third = output(2)
-
-    assert(first.size === 2)
-    assert(second.size === 2)
-    assert(third.size === 2)
-
-    assert(first.flatten.toSet.equals((1 to 100).toSet))
-    assert(second.flatten.toSet.equals( (101 to 200).toSet))
-    assert(third.flatten.toSet.equals((201 to 300).toSet))
+    withStreamingContext(setupStreams(input, operation, 5)) { ssc =>
+      val output = runStreamsWithPartitions(ssc, 3, 3)
+      assert(output.size === 3)
+      val first = output(0)
+      val second = output(1)
+      val third = output(2)
+
+      assert(first.size === 2)
+      assert(second.size === 2)
+      assert(third.size === 2)
+
+      assert(first.flatten.toSet.equals((1 to 100).toSet))
+      assert(second.flatten.toSet.equals((101 to 200).toSet))
+      assert(third.flatten.toSet.equals((201 to 300).toSet))
+    }
   }
 
   test("groupByKey") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index b74d67c63a..ec2852d9a0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
   }
 
   test("test track the number of input stream") {
-    val ssc = new StreamingContext(conf, batchDuration)
+    withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
 
-    class TestInputDStream extends InputDStream[String](ssc) {
-      def start() { }
-      def stop() { }
-      def compute(validTime: Time): Option[RDD[String]] = None
-    }
+      class TestInputDStream extends InputDStream[String](ssc) {
+        def start() {}
 
-    class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) {
-      def getReceiver: Receiver[String] = null
-    }
+        def stop() {}
+
+        def compute(validTime: Time): Option[RDD[String]] = None
+      }
+
+      class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) {
+        def getReceiver: Receiver[String] = null
+      }
 
-    // Register input streams
-    val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream)
-    val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream)
+      // Register input streams
+      val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream)
+      val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream)
 
-    assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length)
-    assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length)
-    assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams)
-    assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i))
-    assert(receiverInputStreams.map(_.id) === Array(0, 1))
+      assert(ssc.graph.getInputStreams().length ==
+        receiverInputStreams.length + inputStreams.length)
+      assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length)
+      assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams)
+      assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i))
+      assert(receiverInputStreams.map(_.id) === Array(0, 1))
+    }
   }
 
   def testFileStream(newFilesOnly: Boolean) {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
index 6e9d443109..0e64b57e0f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala
@@ -244,7 +244,13 @@ object MasterFailureTest extends Logging {
       } catch {
         case e: Exception => logError("Error running streaming context", e)
       }
-      if (killingThread.isAlive) killingThread.interrupt()
+      if (killingThread.isAlive) {
+        killingThread.interrupt()
+        // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is
+        // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env
+        // to null after the next test creates the new SparkContext and fail the test.
+        killingThread.join()
+      }
       ssc.stop()
 
       logInfo("Has been killed = " + killed)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 4bba9691f8..84a5fbb3d9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
 
     val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
     myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory)
-    val ssc = new StreamingContext(myConf, batchDuration)
+    ssc = new StreamingContext(myConf, batchDuration)
     assert(ssc.checkpointDir != null)
   }
 
@@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
     }
     assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop")
 
+    var t: Thread = null
     // test whether wait exits if context is stopped
     failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown
-      new Thread() {
+      t = new Thread() {
         override def run() {
           Thread.sleep(500)
           ssc.stop()
         }
-      }.start()
+      }
+      t.start()
       ssc.awaitTermination()
     }
+    // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped
+    // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after
+    // the next test creates the new SparkContext and fail the test.
+    t.join()
   }
 
   test("awaitTermination after stop") {
@@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
       assert(ssc.awaitTerminationOrTimeout(500) === false)
     }
 
+    var t: Thread = null
     // test whether awaitTerminationOrTimeout() return true if context is stopped
     failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown
-      new Thread() {
+      t = new Thread() {
         override def run() {
           Thread.sleep(500)
           ssc.stop()
         }
-      }.start()
+      }
+      t.start()
       assert(ssc.awaitTerminationOrTimeout(10000) === true)
     }
+    // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped
+    // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after
+    // the next test creates the new SparkContext and fail the test.
+    t.join()
   }
 
   test("getOrCreate") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 4bc1dd4a30..d840c349bb 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
   val input = (1 to 4).map(Seq(_)).toSeq
   val operation = (d: DStream[Int]) => d.map(x => x)
 
+  var ssc: StreamingContext = _
+
+  override def afterFunction() {
+    super.afterFunction()
+    if (ssc != null) {
+      ssc.stop()
+    }
+  }
+
   // To make sure that the processing start and end times in collected
   // information are different for successive batches
   override def batchDuration: Duration = Milliseconds(100)
   override def actuallyWait: Boolean = true
 
   test("batch info reporting") {
-    val ssc = setupStreams(input, operation)
+    ssc = setupStreams(input, operation)
     val collector = new BatchInfoCollector
     ssc.addStreamingListener(collector)
     runStreams(ssc, input.size, input.size)
@@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
   }
 
   test("receiver info reporting") {
-    val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
+    ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
     val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
     inputStream.foreachRDD(_.count)
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index aff8b53f75..afad5f16db 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -29,36 +29,40 @@ import org.apache.spark.storage.StorageLevel
 /** Testsuite for receiver scheduling */
 class ReceiverTrackerSuite extends TestSuiteBase {
   val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
-  val ssc = new StreamingContext(sparkConf, Milliseconds(100))
 
-  ignore("Receiver tracker - propagates rate limit") {
-    object ReceiverStartedWaiter extends StreamingListener {
-      @volatile
-      var started = false
+  test("Receiver tracker - propagates rate limit") {
+    withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc =>
+      object ReceiverStartedWaiter extends StreamingListener {
+        @volatile
+        var started = false
 
-      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
-        started = true
+        override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
+          started = true
+        }
       }
-    }
-
-    ssc.addStreamingListener(ReceiverStartedWaiter)
-    ssc.scheduler.listenerBus.start(ssc.sc)
-    SingletonTestRateReceiver.reset()
-
-    val newRateLimit = 100L
-    val inputDStream = new RateLimitInputDStream(ssc)
-    val tracker = new ReceiverTracker(ssc)
-    tracker.start()
 
-    // we wait until the Receiver has registered with the tracker,
-    // otherwise our rate update is lost
-    eventually(timeout(5 seconds)) {
-      assert(ReceiverStartedWaiter.started)
-    }
-    tracker.sendRateUpdate(inputDStream.id, newRateLimit)
-    // this is an async message, we need to wait a bit for it to be processed
-    eventually(timeout(3 seconds)) {
-      assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+      ssc.addStreamingListener(ReceiverStartedWaiter)
+      ssc.scheduler.listenerBus.start(ssc.sc)
+      SingletonTestRateReceiver.reset()
+
+      val newRateLimit = 100L
+      val inputDStream = new RateLimitInputDStream(ssc)
+      val tracker = new ReceiverTracker(ssc)
+      tracker.start()
+      try {
+        // we wait until the Receiver has registered with the tracker,
+        // otherwise our rate update is lost
+        eventually(timeout(5 seconds)) {
+          assert(ReceiverStartedWaiter.started)
+        }
+        tracker.sendRateUpdate(inputDStream.id, newRateLimit)
+        // this is an async message, we need to wait a bit for it to be processed
+        eventually(timeout(3 seconds)) {
+          assert(inputDStream.getCurrentRateLimit.get === newRateLimit)
+        }
+      } finally {
+        tracker.stop(false)
+      }
     }
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
index 0891309f95..995f1197cc 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala
@@ -22,15 +22,24 @@ import java.util.Properties
 import org.scalatest.Matchers
 
 import org.apache.spark.scheduler.SparkListenerJobStart
+import org.apache.spark.streaming._
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase}
 
 class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
 
   val input = (1 to 4).map(Seq(_)).toSeq
   val operation = (d: DStream[Int]) => d.map(x => x)
 
+  var ssc: StreamingContext = _
+
+  override def afterFunction() {
+    super.afterFunction()
+    if (ssc != null) {
+      ssc.stop()
+    }
+  }
+
   private def createJobStart(
       batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = {
     val properties = new Properties()
@@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
 
   test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " +
     "onReceiverStarted, onReceiverError, onReceiverStopped") {
-    val ssc = setupStreams(input, operation)
+    ssc = setupStreams(input, operation)
     val listener = new StreamingJobProgressListener(ssc)
 
     val streamIdToInputInfo = Map(
@@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
   }
 
   test("Remove the old completed batches when exceeding the limit") {
-    val ssc = setupStreams(input, operation)
+    ssc = setupStreams(input, operation)
     val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000)
     val listener = new StreamingJobProgressListener(ssc)
 
@@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
   }
 
   test("out-of-order onJobStart and onBatchXXX") {
-    val ssc = setupStreams(input, operation)
+    ssc = setupStreams(input, operation)
     val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000)
     val listener = new StreamingJobProgressListener(ssc)
 
@@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers {
   }
 
   test("detect memory leak") {
-    val ssc = setupStreams(input, operation)
+    ssc = setupStreams(input, operation)
     val listener = new StreamingJobProgressListener(ssc)
 
     val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000)
-- 
GitLab