diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 4e702bbb92061f76d8cae0edcae5edb1ccabd867..a3062ac94614ba5e9719dcd0fb32526a8c712202 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag +import org.scalatest.concurrent.Eventually.eventually + import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.{DStream, WindowedDStream} @@ -657,48 +657,57 @@ class BasicOperationsSuite extends TestSuiteBase { .window(Seconds(4), Seconds(2)) } - val operatedStream = runCleanupTest(conf, operation _, - numExpectedOutput = cleanupTestInput.size / 2, rememberDuration = Seconds(3)) - val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]] - val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]] - val mappedStream = windowedStream1.dependencies.head - - // Checkpoint remember durations - assert(windowedStream2.rememberDuration === rememberDuration) - assert(windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration) - assert(mappedStream.rememberDuration === - rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration) - - // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7 - // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 - // MappedStream should remember till 2 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 - - // WindowedStream2 - assert(windowedStream2.generatedRDDs.contains(Time(10000))) - assert(windowedStream2.generatedRDDs.contains(Time(8000))) - assert(!windowedStream2.generatedRDDs.contains(Time(6000))) - - // WindowedStream1 - assert(windowedStream1.generatedRDDs.contains(Time(10000))) - assert(windowedStream1.generatedRDDs.contains(Time(4000))) - assert(!windowedStream1.generatedRDDs.contains(Time(3000))) - - // MappedStream - assert(mappedStream.generatedRDDs.contains(Time(10000))) - assert(mappedStream.generatedRDDs.contains(Time(2000))) - assert(!mappedStream.generatedRDDs.contains(Time(1000))) + runCleanupTest( + conf, + operation _, + numExpectedOutput = cleanupTestInput.size / 2, + rememberDuration = Seconds(3)) { operatedStream => + eventually(eventuallyTimeout) { + val windowedStream2 = operatedStream.asInstanceOf[WindowedDStream[_]] + val windowedStream1 = windowedStream2.dependencies.head.asInstanceOf[WindowedDStream[_]] + val mappedStream = windowedStream1.dependencies.head + + // Checkpoint remember durations + assert(windowedStream2.rememberDuration === rememberDuration) + assert( + windowedStream1.rememberDuration === rememberDuration + windowedStream2.windowDuration) + assert(mappedStream.rememberDuration === + rememberDuration + windowedStream2.windowDuration + windowedStream1.windowDuration) + + // WindowedStream2 should remember till 7 seconds: 10, 9, 8, 7 + // WindowedStream1 should remember till 4 seconds: 10, 9, 8, 7, 6, 5, 4 + // MappedStream should remember till 2 seconds: 10, 9, 8, 7, 6, 5, 4, 3, 2 + + // WindowedStream2 + assert(windowedStream2.generatedRDDs.contains(Time(10000))) + assert(windowedStream2.generatedRDDs.contains(Time(8000))) + assert(!windowedStream2.generatedRDDs.contains(Time(6000))) + + // WindowedStream1 + assert(windowedStream1.generatedRDDs.contains(Time(10000))) + assert(windowedStream1.generatedRDDs.contains(Time(4000))) + assert(!windowedStream1.generatedRDDs.contains(Time(3000))) + + // MappedStream + assert(mappedStream.generatedRDDs.contains(Time(10000))) + assert(mappedStream.generatedRDDs.contains(Time(2000))) + assert(!mappedStream.generatedRDDs.contains(Time(1000))) + } + } } test("rdd cleanup - updateStateByKey") { val updateFunc = (values: Seq[Int], state: Option[Int]) => { Some(values.sum + state.getOrElse(0)) } - val stateStream = runCleanupTest( - conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3))) - - assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2) - assert(stateStream.generatedRDDs.contains(Time(10000))) - assert(!stateStream.generatedRDDs.contains(Time(4000))) + runCleanupTest( + conf, _.map(_ -> 1).updateStateByKey(updateFunc).checkpoint(Seconds(3))) { stateStream => + eventually(eventuallyTimeout) { + assert(stateStream.rememberDuration === stateStream.checkpointDuration * 2) + assert(stateStream.generatedRDDs.contains(Time(10000))) + assert(!stateStream.generatedRDDs.contains(Time(4000))) + } + } } test("rdd cleanup - input blocks and persisted RDDs") { @@ -779,13 +788,16 @@ class BasicOperationsSuite extends TestSuiteBase { } } - /** Test cleanup of RDDs in DStream metadata */ + /** + * Test cleanup of RDDs in DStream metadata. `assertCleanup` is the function that asserts the + * cleanup of RDDs is successful. + */ def runCleanupTest[T: ClassTag]( conf2: SparkConf, operation: DStream[Int] => DStream[T], numExpectedOutput: Int = cleanupTestInput.size, rememberDuration: Duration = null - ): DStream[T] = { + )(assertCleanup: (DStream[T]) => Unit): DStream[T] = { // Setup the stream computation assert(batchDuration === Seconds(1), @@ -794,7 +806,11 @@ class BasicOperationsSuite extends TestSuiteBase { val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] if (rememberDuration != null) ssc.remember(rememberDuration) - val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) + val output = runStreams[(Int, Int)]( + ssc, + cleanupTestInput.size, + numExpectedOutput, + () => assertCleanup(operatedStream)) val clock = ssc.scheduler.clock.asInstanceOf[Clock] assert(clock.getTimeMillis() === Seconds(10).milliseconds) assert(output.size === numExpectedOutput) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index fa975a146216d60973ebecf61f42d0564e2c080b..dbab70886102d84cd44730ab17e6836d5e19430b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -359,14 +359,20 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. * * Returns a sequence of items for each RDD. + * + * @param ssc The StreamingContext + * @param numBatches The number of batches should be run + * @param numExpectedOutput The number of expected output + * @param preStop The function to run before stopping StreamingContext */ def runStreams[V: ClassTag]( ssc: StreamingContext, numBatches: Int, - numExpectedOutput: Int + numExpectedOutput: Int, + preStop: () => Unit = () => {} ): Seq[Seq[V]] = { // Flatten each RDD into a single Seq - runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq) + runStreamsWithPartitions(ssc, numBatches, numExpectedOutput, preStop).map(_.flatten.toSeq) } /** @@ -376,11 +382,17 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { * * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each * representing one partition. + * + * @param ssc The StreamingContext + * @param numBatches The number of batches should be run + * @param numExpectedOutput The number of expected output + * @param preStop The function to run before stopping StreamingContext */ def runStreamsWithPartitions[V: ClassTag]( ssc: StreamingContext, numBatches: Int, - numExpectedOutput: Int + numExpectedOutput: Int, + preStop: () => Unit = () => {} ): Seq[Seq[Seq[V]]] = { assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") @@ -424,6 +436,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") Thread.sleep(100) // Give some time for the forgetting old RDDs to complete + preStop() } finally { ssc.stop(stopSparkContext = true) }