diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 4e702bbb92061f76d8cae0edcae5edb1ccabd867..a3062ac94614ba5e9719dcd0fb32526a8c712202 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.streaming
 
 import java.util.concurrent.ConcurrentLinkedQueue
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 
+import org.scalatest.concurrent.Eventually.eventually
+
 import org.apache.spark.{SparkConf, SparkException}
-import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
@@ -657,48 +657,57 @@ class BasicOperationsSuite extends TestSuiteBase {
        .window(Seconds(4), Seconds(2))
     }
 
-    val operatedStream = runCleanupTest(conf, operation _,
-      numExpectedOutput = cleanupTestInput.size / 2, rememberDuration = Seconds(3))
-    val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]]
-    val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]]
-    val mappedStream = windowedStream1.dependencies.head
-
-    // Checkpoint remember durations
-    assert(windowedStream2.rememberDuration === rememberDuration)
-    assert(windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration)
-    assert(mappedStream.rememberDuration ===
-      rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration)
-
-    // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7
-    // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4
-    // MappedStream should remember till 2 seconds:    10, 9, 8, 7, 6, 5, 4, 3, 2
-
-    // WindowedStream2
-    assert(windowedStream2.generatedRDDs.contains(Time(10000)))
-    assert(windowedStream2.generatedRDDs.contains(Time(8000)))
-    assert(!windowedStream2.generatedRDDs.contains(Time(6000)))
-
-    // WindowedStream1
-    assert(windowedStream1.generatedRDDs.contains(Time(10000)))
-    assert(windowedStream1.generatedRDDs.contains(Time(4000)))
-    assert(!windowedStream1.generatedRDDs.contains(Time(3000)))
-
-    // MappedStream
-    assert(mappedStream.generatedRDDs.contains(Time(10000)))
-    assert(mappedStream.generatedRDDs.contains(Time(2000)))
-    assert(!mappedStream.generatedRDDs.contains(Time(1000)))
+    runCleanupTest(
+        conf,
+        operation _,
+        numExpectedOutput = cleanupTestInput.size / 2,
+        rememberDuration = Seconds(3)) { operatedStream =>
+      eventually(eventuallyTimeout) {
+        val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]]
+        val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]]
+        val mappedStream = windowedStream1.dependencies.head
+
+        // Checkpoint remember durations
+        assert(windowedStream2.rememberDuration === rememberDuration)
+        assert(
+          windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration)
+        assert(mappedStream.rememberDuration ===
+          rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration)
+
+        // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7
+        // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4
+        // MappedStream should remember till 2 seconds:    10, 9, 8, 7, 6, 5, 4, 3, 2
+
+        // WindowedStream2
+        assert(windowedStream2.generatedRDDs.contains(Time(10000)))
+        assert(windowedStream2.generatedRDDs.contains(Time(8000)))
+        assert(!windowedStream2.generatedRDDs.contains(Time(6000)))
+
+        // WindowedStream1
+        assert(windowedStream1.generatedRDDs.contains(Time(10000)))
+        assert(windowedStream1.generatedRDDs.contains(Time(4000)))
+        assert(!windowedStream1.generatedRDDs.contains(Time(3000)))
+
+        // MappedStream
+        assert(mappedStream.generatedRDDs.contains(Time(10000)))
+        assert(mappedStream.generatedRDDs.contains(Time(2000)))
+        assert(!mappedStream.generatedRDDs.contains(Time(1000)))
+      }
+    }
   }
 
   test("rdd cleanup - updateStateByKey") {
     val updateFunc = (values: Seq[Int], state: Option[Int]) => {
       Some(values.sum + state.getOrElse(0))
     }
-    val stateStream = runCleanupTest(
-      conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3)))
-
-    assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2)
-    assert(stateStream.generatedRDDs.contains(Time(10000)))
-    assert(!stateStream.generatedRDDs.contains(Time(4000)))
+    runCleanupTest(
+      conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3))) { stateStream =>
+      eventually(eventuallyTimeout) {
+        assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2)
+        assert(stateStream.generatedRDDs.contains(Time(10000)))
+        assert(!stateStream.generatedRDDs.contains(Time(4000)))
+      }
+    }
   }
 
   test("rdd cleanup - input blocks and persisted RDDs") {
@@ -779,13 +788,16 @@ class BasicOperationsSuite extends TestSuiteBase {
     }
   }
 
-  /** Test cleanup of RDDs in DStream metadata */
+  /**
+   * Test cleanup of RDDs in DStream metadata. `assertCleanup` is the function that asserts the
+   * cleanup of RDDs is successful.
+   */
   def runCleanupTest[T: ClassTag](
       conf2: SparkConf,
       operation: DStream[Int] => DStream[T],
       numExpectedOutput: Int = cleanupTestInput.size,
       rememberDuration: Duration = null
-    ): DStream[T] = {
+    )(assertCleanup: (DStream[T]) => Unit): DStream[T] = {
 
     // Setup the stream computation
     assert(batchDuration === Seconds(1),
@@ -794,7 +806,11 @@ class BasicOperationsSuite extends TestSuiteBase {
       val operatedStream =
         ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
       if (rememberDuration != null) ssc.remember(rememberDuration)
-      val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
+      val output = runStreams[(Int, Int)](
+        ssc,
+        cleanupTestInput.size,
+        numExpectedOutput,
+        () => assertCleanup(operatedStream))
       val clock = ssc.scheduler.clock.asInstanceOf[Clock]
       assert(clock.getTimeMillis() === Seconds(10).milliseconds)
       assert(output.size === numExpectedOutput)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index fa975a146216d60973ebecf61f42d0564e2c080b..dbab70886102d84cd44730ab17e6836d5e19430b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -359,14 +359,20 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
    * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
    *
    * Returns a sequence of items for each RDD.
+   *
+   * @param ssc The StreamingContext
+   * @param numBatches The number of batches should be run
+   * @param numExpectedOutput The number of expected output
+   * @param preStop The function to run before stopping StreamingContext
    */
   def runStreams[V: ClassTag](
       ssc: StreamingContext,
       numBatches: Int,
-      numExpectedOutput: Int
+      numExpectedOutput: Int,
+      preStop: () => Unit = () => {}
     ): Seq[Seq[V]] = {
     // Flatten each RDD into a single Seq
-    runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+    runStreamsWithPartitions(ssc, numBatches, numExpectedOutput, preStop).map(_.flatten.toSeq)
   }
 
   /**
@@ -376,11 +382,17 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
    *
    * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each
    * representing one partition.
+   *
+   * @param ssc The StreamingContext
+   * @param numBatches The number of batches should be run
+   * @param numExpectedOutput The number of expected output
+   * @param preStop The function to run before stopping StreamingContext
    */
   def runStreamsWithPartitions[V: ClassTag](
       ssc: StreamingContext,
       numBatches: Int,
-      numExpectedOutput: Int
+      numExpectedOutput: Int,
+      preStop: () => Unit = () => {}
     ): Seq[Seq[Seq[V]]] = {
     assert(numBatches > 0, "Number of batches to run stream computation is zero")
     assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
@@ -424,6 +436,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
       assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
 
       Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
+      preStop()
     } finally {
       ssc.stop(stopSparkContext = true)
     }