diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 3d1e90a3522a4cce1d566e95b2689f9a22167493..ac07a55cb9101a1d867af68e97eac25451c7ca79 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -74,7 +74,7 @@ object MetadataCleanerType extends Enumeration { // initialization of StreamingContext. It's okay for users trying to configure stuff themselves. object MetadataCleaner { def getDelaySeconds(conf: SparkConf) = { - conf.getInt("spark.cleaner.ttl", 3500) + conf.getInt("spark.cleaner.ttl", -1) } def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index fcc159e85a85bf918f587dc0c626bca6bebe8231..73e7ce6e968c6fdc1575562beb969a3a5b616b79 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.storage.StorageLevel class MQTTStreamSuite extends TestSuiteBase { - test("MQTT input stream") { + test("mqtt input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val brokerUrl = "abc" val topic = "def" diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index a0a8fe617b134c7c39ae8df77690f12a1fa42797..ccc38784ef671185cc564ea9354e9c11af2f8407 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -23,7 +23,7 @@ import twitter4j.auth.{NullAuthorization, Authorization} class TwitterStreamSuite extends TestSuiteBase { - test("kafka input stream") { + test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) val filters = Seq("filter1", "filter2") val authorization: Authorization = NullAuthorization.getInstance() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 1249ef4c3d5bb969283df6a0825414ac6e3c2abc..108bc2de3e3e28f653ec7b8d023e7e883312f511 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val graph = ssc.graph val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration - val pendingTimes = ssc.scheduler.getPendingTimes() + val pendingTimes = ssc.scheduler.getPendingTimes().toArray val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConf = ssc.conf diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala new file mode 100644 index 0000000000000000000000000000000000000000..1f5dacb543db863ce7d00629cafc1ae651c52b80 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala @@ -0,0 +1,28 @@ 
+package org.apache.spark.streaming + +private[streaming] class ContextWaiter { + private var error: Throwable = null + private var stopped: Boolean = false + + def notifyError(e: Throwable) = synchronized { + error = e + notifyAll() + } + + def notifyStop() = synchronized { + stopped = true; notifyAll() + } + + def waitForStopOrError(timeout: Long = -1) = synchronized { + // If already had error, then throw it + if (error != null) { + throw error + } + + // If not already stopped, then wait + if (!stopped) { + if (timeout < 0) wait() else wait(timeout) + if (error != null) throw error + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 9432a709d0ca964dc37918c2abe771fe5c75ce29..f760093579ebff31c22fd8a62f1f4d33997d6674 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -54,7 +54,7 @@ import org.apache.spark.util.MetadataCleaner */ abstract class DStream[T: ClassTag] ( - @transient protected[streaming] var ssc: StreamingContext + @transient private[streaming] var ssc: StreamingContext ) extends Serializable with Logging { // ======================================================================= @@ -74,31 +74,31 @@ abstract class DStream[T: ClassTag] ( // Methods and fields available on all DStreams // ======================================================================= - // RDDs generated, marked as protected[streaming] so that testsuites can access it + // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected[streaming] var zeroTime: Time = null + private[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected[streaming] var rememberDuration: Duration = null + private[streaming] var rememberDuration: Duration = null // Storage level of the RDDs in the stream - protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE + private[streaming] var storageLevel: StorageLevel = StorageLevel.NONE // Checkpoint details - protected[streaming] val mustCheckpoint = false - protected[streaming] var checkpointDuration: Duration = null - protected[streaming] val checkpointData = new DStreamCheckpointData(this) + private[streaming] val mustCheckpoint = false + private[streaming] var checkpointDuration: Duration = null + private[streaming] val checkpointData = new DStreamCheckpointData(this) // Reference to whole DStream graph - protected[streaming] var graph: DStreamGraph = null + private[streaming] var graph: DStreamGraph = null - protected[streaming] def isInitialized = (zeroTime != null) + private[streaming] def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created - protected[streaming] def parentRememberDuration = rememberDuration + private[streaming] def parentRememberDuration = rememberDuration /** Return the StreamingContext associated with this DStream */ def context = ssc @@ -138,7 +138,7 @@ abstract class DStream[T: ClassTag] ( * the validity of future times is calculated. This method also recursively initializes * its parent DStreams.
*/ - protected[streaming] def initialize(time: Time) { + private[streaming] def initialize(time: Time) { if (zeroTime != null && zeroTime != time) { throw new Exception("ZeroTime is already initialized to " + zeroTime + ", cannot initialize it again to " + time) @@ -164,7 +164,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.initialize(zeroTime)) } - protected[streaming] def validate() { + private[streaming] def validate() { assert(rememberDuration != null, "Remember duration is set to null") assert( @@ -228,7 +228,7 @@ abstract class DStream[T: ClassTag] ( logInfo("Initialized and validated " + this) } - protected[streaming] def setContext(s: StreamingContext) { + private[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { throw new Exception("Context is already set in " + this + ", cannot set it again") } @@ -237,7 +237,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.setContext(ssc)) } - protected[streaming] def setGraph(g: DStreamGraph) { + private[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { throw new Exception("Graph is already set in " + this + ", cannot set it again") } @@ -245,7 +245,7 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def remember(duration: Duration) { + private[streaming] def remember(duration: Duration) { if (duration != null && duration > rememberDuration) { rememberDuration = duration logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) @@ -254,14 +254,14 @@ abstract class DStream[T: ClassTag] ( } /** Checks whether the 'time' is valid wrt slideDuration for generating RDD */ - protected def isTimeValid(time: Time): Boolean = { + private[streaming] def isTimeValid(time: Time): Boolean = { if (!isInitialized) { throw new Exception (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) false } else { - logInfo("Time " + time + " is valid") + logDebug("Time " + time + " is valid") true } } @@ -270,7 +270,7 @@ abstract class DStream[T: ClassTag] ( * Retrieve a precomputed RDD of this DStream, or computes the RDD. This is an internal * method that should not be called directly. */ - protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { + private[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { // If this DStream was not initialized (i.e., zeroTime not set), then do it // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { @@ -311,7 +311,7 @@ abstract class DStream[T: ClassTag] ( * that materializes the corresponding RDD. Subclasses of DStream may override this * to generate their own jobs. */ - protected[streaming] def generateJob(time: Time): Option[Job] = { + private[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { @@ -330,7 +330,7 @@ abstract class DStream[T: ClassTag] ( * implementation clears the old generated RDDs. Subclasses of DStream may override * this to clear their own metadata along with the generated RDDs. 
*/ - protected[streaming] def clearMetadata(time: Time) { + private[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) generatedRDDs --= oldRDDs.keys logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + @@ -339,9 +339,9 @@ abstract class DStream[T: ClassTag] ( } /* Adds metadata to the Stream while it is running. - * This methd should be overwritten by sublcasses of InputDStream. + * This method should be overridden by subclasses of InputDStream. */ - protected[streaming] def addMetadata(metadata: Any) { + private[streaming] def addMetadata(metadata: Any) { if (metadata != null) { logInfo("Dropping Metadata: " + metadata.toString) } @@ -354,18 +354,18 @@ abstract class DStream[T: ClassTag] ( * checkpointData. Subclasses of DStream (especially those of InputDStream) may override * this method to save custom checkpoint data. */ - protected[streaming] def updateCheckpointData(currentTime: Time) { - logInfo("Updating checkpoint data for time " + currentTime) + private[streaming] def updateCheckpointData(currentTime: Time) { + logDebug("Updating checkpoint data for time " + currentTime) checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) } - protected[streaming] def clearCheckpointData(time: Time) { - logInfo("Clearing checkpoint data") + private[streaming] def clearCheckpointData(time: Time) { + logDebug("Clearing checkpoint data") checkpointData.cleanup(time) dependencies.foreach(_.clearCheckpointData(time)) - logInfo("Cleared checkpoint data") + logDebug("Cleared checkpoint data") } /** @@ -374,7 +374,7 @@ abstract class DStream[T: ClassTag] ( * from the checkpoint file names stored in checkpointData. Subclasses of DStream that * override the updateCheckpointData() method would also need to override this method.
*/ - protected[streaming] def restoreCheckpointData() { + private[streaming] def restoreCheckpointData() { // Create RDDs from the checkpoint data logInfo("Restoring checkpoint data") checkpointData.restore() @@ -699,7 +699,7 @@ abstract class DStream[T: ClassTag] ( /** * Return all the RDDs defined by the Interval object (both end times included) */ - protected[streaming] def slice(interval: Interval): Seq[RDD[T]] = { + def slice(interval: Interval): Seq[RDD[T]] = { slice(interval.beginTime, interval.endTime) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index eee9591ffc383e808978c35a7dce2e356fa2eca4..668e5324e6b64c8a279adefb67ae30d5b8263a7f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import dstream.InputDStream +import org.apache.spark.streaming.dstream.{NetworkInputDStream, InputDStream} import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import org.apache.spark.Logging @@ -103,6 +103,12 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { def getOutputStreams() = this.synchronized { outputStreams.toArray } + def getNetworkInputStreams() = this.synchronized { + inputStreams.filter(_.isInstanceOf[NetworkInputDStream[_]]) + .map(_.asInstanceOf[NetworkInputDStream[_]]) + .toArray + } + def generateJobs(time: Time): Seq[Job] = { logDebug("Generating jobs for time " + time) val jobs = this.synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index dd34f6f4f2b46c6301593d9246cdd7363a5b92e7..ee83ae902be11adff1aef4999e85f29877415615 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -46,7 +46,7 @@ import org.apache.hadoop.conf.Configuration * information (such as, cluster URL and job name) to internally create a SparkContext, it provides * methods used to create DStream from various input sources. 
*/ -class StreamingContext private ( +class StreamingContext private[streaming] ( sc_ : SparkContext, cp_ : Checkpoint, batchDur_ : Duration @@ -101,20 +101,9 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } - private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf) + private[streaming] val isCheckpointPresent = (cp_ != null) - if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) { - MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds) - } - - if (MetadataCleaner.getDelaySeconds(conf_) < 0) { - throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " - + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") - } - - protected[streaming] val isCheckpointPresent = (cp_ != null) - - protected[streaming] val sc: SparkContext = { + private[streaming] val sc: SparkContext = { if (isCheckpointPresent) { new SparkContext(cp_.sparkConf) } else { @@ -122,11 +111,16 @@ class StreamingContext private ( } } - protected[streaming] val conf = sc.conf + if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { + throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " + + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") + } + + private[streaming] val conf = sc.conf - protected[streaming] val env = SparkEnv.get + private[streaming] val env = SparkEnv.get - protected[streaming] val graph: DStreamGraph = { + private[streaming] val graph: DStreamGraph = { if (isCheckpointPresent) { cp_.graph.setContext(this) cp_.graph.restoreCheckpointData() @@ -139,10 +133,9 @@ class StreamingContext private ( } } - protected[streaming] val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var networkInputTracker: NetworkInputTracker = null + private val nextNetworkInputStreamId = new AtomicInteger(0) - protected[streaming] var checkpointDir: String = { + private[streaming] var checkpointDir: String = { if (isCheckpointPresent) { sc.setCheckpointDir(cp_.checkpointDir) cp_.checkpointDir @@ -151,11 +144,13 @@ class StreamingContext private ( } } - protected[streaming] val checkpointDuration: Duration = { + private[streaming] val checkpointDuration: Duration = { if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration } - protected[streaming] val scheduler = new JobScheduler(this) + private[streaming] val scheduler = new JobScheduler(this) + + private[streaming] val waiter = new ContextWaiter /** * Return the associated Spark context */ @@ -191,11 +186,11 @@ class StreamingContext private ( } } - protected[streaming] def initialCheckpoint: Checkpoint = { + private[streaming] def initialCheckpoint: Checkpoint = { if (isCheckpointPresent) cp_ else null } - protected[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() /** * Create an input stream with any arbitrary user implemented network receiver. @@ -416,7 +411,7 @@ class StreamingContext private ( scheduler.listenerBus.addListener(streamingListener) } - protected def validate() { + private def validate() { assert(graph != null, "Graph is null") graph.validate() @@ -430,38 +425,37 @@ class StreamingContext private ( /** * Start the execution of the streams. 
*/ - def start() { + def start() = synchronized { validate() + scheduler.start() + } - // Get the network input streams - val networkInputStreams = graph.getInputStreams().filter(s => s match { - case n: NetworkInputDStream[_] => true - case _ => false - }).map(_.asInstanceOf[NetworkInputDStream[_]]).toArray - - // Start the network input tracker (must start before receivers) - if (networkInputStreams.length > 0) { - networkInputTracker = new NetworkInputTracker(this, networkInputStreams) - networkInputTracker.start() - } - Thread.sleep(1000) + /** + * Wait for the execution to stop. Any exceptions that occur during the execution + * will be thrown in this thread. + */ + def awaitTermination() { + waiter.waitForStopOrError() + } - // Start the scheduler - scheduler.start() + /** + * Wait for the execution to stop. Any exceptions that occur during the execution + * will be thrown in this thread. + * @param timeout time to wait in milliseconds + */ + def awaitTermination(timeout: Long) { + waiter.waitForStopOrError(timeout) } /** * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not */ - def stop() { - try { - if (scheduler != null) scheduler.stop() - if (networkInputTracker != null) networkInputTracker.stop() - sc.stop() - logInfo("StreamingContext stopped successfully") - } catch { - case e: Exception => logWarning("Error while stopping", e) - } + def stop(stopSparkContext: Boolean = true) = synchronized { + scheduler.stop() + logInfo("StreamingContext stopped successfully") + waiter.notifyStop() + if (stopSparkContext) sc.stop() } } @@ -472,6 +466,8 @@ class StreamingContext private ( object StreamingContext extends Logging { + private[streaming] val DEFAULT_CLEANER_TTL = 3600 + implicit def toPairDStreamFunctions[K: ClassTag, V: ClassTag](stream: DStream[(K,V)]) = { new PairDStreamFunctions[K, V](stream) } @@ -515,37 +511,29 @@ object StreamingContext extends Logging { */ def jarOfClass(cls: Class[_]) = SparkContext.jarOfClass(cls) - - protected[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { + private[streaming] def createNewSparkContext(conf: SparkConf): SparkContext = { // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second batch intervals. if (MetadataCleaner.getDelaySeconds(conf) < 0) { - MetadataCleaner.setDelaySeconds(conf, 3600) + MetadataCleaner.setDelaySeconds(conf, DEFAULT_CLEANER_TTL) } val sc = new SparkContext(conf) sc } - protected[streaming] def createNewSparkContext( + private[streaming] def createNewSparkContext( master: String, appName: String, sparkHome: String, jars: Seq[String], environment: Map[String, String] ): SparkContext = { - val conf = SparkContext.updatedConf( new SparkConf(), master, appName, sparkHome, jars, environment) - // Set the default cleaner delay to an hour if not already set. - // This should be sufficient for even 1 second batch intervals.
- if (MetadataCleaner.getDelaySeconds(conf) < 0) { - MetadataCleaner.setDelaySeconds(conf, 3600) - } - val sc = new SparkContext(master, appName, sparkHome, jars, environment) - sc + createNewSparkContext(conf) } - protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { + private[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { if (prefix == null) { time.milliseconds.toString } else if (suffix == null || suffix.length ==0) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 523173d45a19ce8d0eda49d82fd135d14064152d..b4c46f5e506ac1403efc0efe93eb30988cd1b40d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -483,9 +483,28 @@ class JavaStreamingContext(val ssc: StreamingContext) { def start() = ssc.start() /** - * Stop the execution of the streams. + * Wait for the execution to stop. Any exceptions that occur during the execution + * will be thrown in this thread. + */ + def awaitTermination() = ssc.awaitTermination() + + /** + * Wait for the execution to stop. Any exceptions that occur during the execution + * will be thrown in this thread. + * @param timeout time to wait in milliseconds + */ + def awaitTermination(timeout: Long) = ssc.awaitTermination(timeout) + + /** + * Stop the execution of the streams. Will stop the associated JavaSparkContext as well. */ def stop() = ssc.stop() + + /** + * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not + */ + def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index f01e67fe13096ca3b5db44e2b74c52fc573ec0a1..8f84232cab3485f4c8d78200312f034214dcb9d8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -43,7 +43,7 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) * This ensures that InputDStream.compute() is called strictly on increasing * times. */ - override protected def isTimeValid(time: Time): Boolean = { + override private[streaming] def isTimeValid(time: Time): Boolean = { if (!super.isTimeValid(time)) { false // Time not valid } else { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index d41f726f8322c9275fdebbfefc0d925702cf344f..0f1f6fc2cec4a846404dc90e3fc3e9a56f1dd312 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -68,7 +68,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte // then this returns an empty RDD.
This may happen when recovering from a // master failure if (validTime >= graph.startTime) { - val blockIds = ssc.networkInputTracker.getBlockIds(id, validTime) + val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index c8ee93bf5bde706f90cb23c6cd11f64635da9023..7e0f6b2cdfc084446eb238c8d5d79e8617beea46 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.Time +import scala.util.Try /** * Class representing a Spark computation. It may contain multiple Spark jobs. @@ -25,12 +26,10 @@ import org.apache.spark.streaming.Time private[streaming] class Job(val time: Time, func: () => _) { var id: String = _ + var result: Try[_] = null - def run(): Long = { - val startTime = System.currentTimeMillis - func() - val stopTime = System.currentTimeMillis - (stopTime - startTime) + def run() { + result = Try(func()) } def setId(number: Int) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 2fa6853ae0613968b35806032438d3388714116b..b5f11d344068d828330377bfa37c251e2f208736 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.scheduler -import akka.actor.{Props, Actor} -import org.apache.spark.SparkEnv -import org.apache.spark.Logging +import akka.actor.{ActorRef, ActorSystem, Props, Actor} +import org.apache.spark.{SparkException, SparkEnv, Logging} import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter} import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock} +import scala.util.{Failure, Success, Try} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent @@ -37,29 +37,38 @@ private[scheduler] case class ClearCheckpointData(time: Time) extends JobGenerat private[streaming] class JobGenerator(jobScheduler: JobScheduler) extends Logging { - val ssc = jobScheduler.ssc - val graph = ssc.graph - val eventProcessorActor = ssc.env.actorSystem.actorOf(Props(new Actor { - def receive = { - case event: JobGeneratorEvent => - logDebug("Got event of type " + event.getClass.getName) - processEvent(event) - } - })) + private val ssc = jobScheduler.ssc + private val graph = ssc.graph val clock = { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") Class.forName(clockClass).newInstance().asInstanceOf[Clock] } - val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventProcessorActor ! GenerateJobs(new Time(longTime))) - lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { + private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, + longTime => eventActor ! 
GenerateJobs(new Time(longTime))) + private lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } + // eventActor is created when generator starts. + // This not being null means the generator has been started and not stopped + private var eventActor: ActorRef = null + + /** Start generation of jobs */ def start() = synchronized { + if (eventActor != null) { + throw new SparkException("JobGenerator already started") + } + + eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { + def receive = { + case event: JobGeneratorEvent => + logDebug("Got event of type " + event.getClass.getName) + processEvent(event) + } + }), "JobGenerator") if (ssc.isCheckpointPresent) { restart() } else { @@ -67,22 +76,26 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } - def stop() { - timer.stop() - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("JobGenerator stopped") + /** Stop generation of jobs */ + def stop() = synchronized { + if (eventActor != null) { + timer.stop() + ssc.env.actorSystem.stop(eventActor) + if (checkpointWriter != null) checkpointWriter.stop() + ssc.graph.stop() + logInfo("JobGenerator stopped") + } } /** * On batch completion, clear old metadata and checkpoint computation. */ - private[scheduler] def onBatchCompletion(time: Time) { - eventProcessorActor ! ClearMetadata(time) + def onBatchCompletion(time: Time) { + eventActor ! ClearMetadata(time) } - private[streaming] def onCheckpointCompletion(time: Time) { - eventProcessorActor ! ClearCheckpointData(time) + def onCheckpointCompletion(time: Time) { + eventActor ! ClearCheckpointData(time) } /** Processes all events */ @@ -121,14 +134,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val checkpointTime = ssc.initialCheckpoint.checkpointTime val restartTime = new Time(timer.getRestartTime(graph.zeroTime.milliseconds)) val downTimes = checkpointTime.until(restartTime, batchDuration) - logInfo("Batches during down time (" + downTimes.size + " batches): " + downTimes.mkString(", ")) + logInfo("Batches during down time (" + downTimes.size + " batches): " + + downTimes.mkString(", ")) // Batches that were unprocessed before failure val pendingTimes = ssc.initialCheckpoint.pendingTimes.sorted(Time.ordering) - logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + pendingTimes.mkString(", ")) + logInfo("Batches pending processing (" + pendingTimes.size + " batches): " + + pendingTimes.mkString(", ")) // Reschedule jobs for these times val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering) - logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) + logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => jobScheduler.runJobs(time, graph.generateJobs(time)) ) @@ -141,15 +157,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { SparkEnv.set(ssc.env) - logInfo("\n-----------------------------------------------------\n") - jobScheduler.runJobs(time, graph.generateJobs(time)) - eventProcessorActor !
DoCheckpoint(time) + Try(graph.generateJobs(time)) match { + case Success(jobs) => jobScheduler.runJobs(time, jobs) + case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + } + eventActor ! DoCheckpoint(time) } /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - eventProcessorActor ! DoCheckpoint(time) + eventActor ! DoCheckpoint(time) } /** Clear DStream checkpoint data for the given `time`. */ @@ -166,4 +184,3 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } } - diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 30c070c274d85e0f487447fb467828c68f7d8d99..de675d3c7fb94b4983e1583ea782c4a29b8dce73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,36 +17,68 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import scala.util.{Failure, Success, Try} +import scala.collection.JavaConversions._ import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors} -import scala.collection.mutable.HashSet +import akka.actor.{ActorRef, Actor, Props} +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.streaming._ + +private[scheduler] sealed trait JobSchedulerEvent +private[scheduler] case class JobStarted(job: Job) extends JobSchedulerEvent +private[scheduler] case class JobCompleted(job: Job) extends JobSchedulerEvent +private[scheduler] case class ErrorReported(msg: String, e: Throwable) extends JobSchedulerEvent + /** * This class schedules jobs to be run on Spark. It uses the JobGenerator to generate - * the jobs and runs them using a thread pool. Number of threads + * the jobs and runs them using a thread pool. */ private[streaming] class JobScheduler(val ssc: StreamingContext) extends Logging { - val jobSets = new ConcurrentHashMap[Time, JobSet] - val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - val executor = Executors.newFixedThreadPool(numConcurrentJobs) - val generator = new JobGenerator(this) + private val jobSets = new ConcurrentHashMap[Time, JobSet] + private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) + private val executor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobGenerator = new JobGenerator(this) + val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() - def clock = generator.clock + // These two are created only when scheduler starts. 
+ // eventActor not being null means the scheduler has been started and not stopped + var networkInputTracker: NetworkInputTracker = null + private var eventActor: ActorRef = null + + + def start() = synchronized { + if (eventActor != null) { + throw new SparkException("JobScheduler already started") + } - def start() { - generator.start() + eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { + def receive = { + case event: JobSchedulerEvent => processEvent(event) + } + }), "JobScheduler") + listenerBus.start() + networkInputTracker = new NetworkInputTracker(ssc) + networkInputTracker.start() + Thread.sleep(1000) + jobGenerator.start() + logInfo("JobScheduler started") } - def stop() { - generator.stop() - executor.shutdown() - if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { - executor.shutdownNow() + def stop() = synchronized { + if (eventActor != null) { + jobGenerator.stop() + networkInputTracker.stop() + executor.shutdown() + if (!executor.awaitTermination(2, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + listenerBus.stop() + ssc.env.actorSystem.stop(eventActor) + logInfo("JobScheduler stopped") } } @@ -61,46 +93,67 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } } - def getPendingTimes(): Array[Time] = { - jobSets.keySet.toArray(new Array[Time](0)) + def getPendingTimes(): Seq[Time] = { + jobSets.keySet.toSeq + } + + def reportError(msg: String, e: Throwable) { + eventActor ! ErrorReported(msg, e) } - private def beforeJobStart(job: Job) { + private def processEvent(event: JobSchedulerEvent) { + try { + event match { + case JobStarted(job) => handleJobStart(job) + case JobCompleted(job) => handleJobCompletion(job) + case ErrorReported(m, e) => handleError(m, e) + } + } catch { + case e: Throwable => + reportError("Error in job scheduler", e) + } + } + + private def handleJobStart(job: Job) { val jobSet = jobSets.get(job.time) if (!jobSet.hasStarted) { - listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo())) + listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo)) } - jobSet.beforeJobStart(job) + jobSet.handleJobStart(job) logInfo("Starting job " + job.id + " from job set of time " + jobSet.time) - SparkEnv.set(generator.ssc.env) + SparkEnv.set(ssc.env) } - private def afterJobEnd(job: Job) { - val jobSet = jobSets.get(job.time) - jobSet.afterJobStop(job) - logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) - if (jobSet.hasCompleted) { - jobSets.remove(jobSet.time) - generator.onBatchCompletion(jobSet.time) - logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( - jobSet.totalDelay / 1000.0, jobSet.time.toString, - jobSet.processingDelay / 1000.0 - )) - listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo())) + private def handleJobCompletion(job: Job) { + job.result match { + case Success(_) => + val jobSet = jobSets.get(job.time) + jobSet.handleJobCompletion(job) + logInfo("Finished job " + job.id + " from job set of time " + jobSet.time) + if (jobSet.hasCompleted) { + jobSets.remove(jobSet.time) + jobGenerator.onBatchCompletion(jobSet.time) + logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format( + jobSet.totalDelay / 1000.0, jobSet.time.toString, + jobSet.processingDelay / 1000.0 + )) + listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo)) + } + case Failure(e) => + reportError("Error running job " + job, e) } } - private[streaming] - class JobHandler(job: Job) extends Runnable { + private def handleError(msg: String, e: 
Throwable) { + logError(msg, e) + ssc.waiter.notifyError(e) + } + + private class JobHandler(job: Job) extends Runnable { def run() { - beforeJobStart(job) - try { - job.run() - } catch { - case e: Exception => - logError("Running " + job + " failed", e) - } - afterJobEnd(job) + eventActor ! JobStarted(job) + job.run() + eventActor ! JobCompleted(job) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 57268674ead9dd22a9c77a941f0195544400999c..fcf303aee6cd73e34b3af85162c65cd36b63e349 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.HashSet +import scala.collection.mutable.{ArrayBuffer, HashSet} import org.apache.spark.streaming.Time /** Class representing a set of Jobs @@ -27,25 +27,25 @@ private[streaming] case class JobSet(time: Time, jobs: Seq[Job]) { private val incompleteJobs = new HashSet[Job]() - var submissionTime = System.currentTimeMillis() // when this jobset was submitted - var processingStartTime = -1L // when the first job of this jobset started processing - var processingEndTime = -1L // when the last job of this jobset finished processing + private val submissionTime = System.currentTimeMillis() // when this jobset was submitted + private var processingStartTime = -1L // when the first job of this jobset started processing + private var processingEndTime = -1L // when the last job of this jobset finished processing jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) } incompleteJobs ++= jobs - def beforeJobStart(job: Job) { + def handleJobStart(job: Job) { if (processingStartTime < 0) processingStartTime = System.currentTimeMillis() } - def afterJobStop(job: Job) { + def handleJobCompletion(job: Job) { incompleteJobs -= job if (hasCompleted) processingEndTime = System.currentTimeMillis() } - def hasStarted() = (processingStartTime > 0) + def hasStarted = processingStartTime > 0 - def hasCompleted() = incompleteJobs.isEmpty + def hasCompleted = incompleteJobs.isEmpty // Time taken to process all the jobs from the time they started processing // (i.e. 
not including the time they wait in the streaming scheduler queue) @@ -57,7 +57,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) { processingEndTime - time.milliseconds } - def toBatchInfo(): BatchInfo = { + def toBatchInfo: BatchInfo = { new BatchInfo( time, submissionTime, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index 75f724464348c107f32498d1457b78c9b8f97c61..0d9733fa69a12a67f386fa736f379aebfe06619c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -19,8 +19,7 @@ package org.apache.spark.streaming.scheduler import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkException, Logging, SparkEnv} import org.apache.spark.SparkContext._ import scala.collection.mutable.HashMap @@ -32,6 +31,7 @@ import akka.pattern.ask import akka.dispatch._ import org.apache.spark.storage.BlockId import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.util.AkkaUtils private[streaming] sealed trait NetworkInputTrackerMessage private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage @@ -39,33 +39,47 @@ private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], m private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage /** - * This class manages the execution of the receivers of NetworkInputDStreams. + * This class manages the execution of the receivers of NetworkInputDStreams. An instance of + * this class must be created after all input streams have been added and StreamingContext.start() + * has been called because it needs the final set of input streams at the time of instantiation. */ private[streaming] -class NetworkInputTracker( - @transient ssc: StreamingContext, - @transient networkInputStreams: Array[NetworkInputDStream[_]]) - extends Logging { +class NetworkInputTracker(ssc: StreamingContext) extends Logging { + val networkInputStreams = ssc.graph.getNetworkInputStreams() val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() val receiverInfo = new HashMap[Int, ActorRef] val receivedBlockIds = new HashMap[Int, Queue[BlockId]] - val timeout = 5000.milliseconds + val timeout = AkkaUtils.askTimeout(ssc.conf) + + // actor is created when the tracker starts. + // This not being null means the tracker has been started and not stopped + var actor: ActorRef = null var currentTime: Time = null /** Start the actor and receiver execution thread. */ def start() { - ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") - receiverExecutor.start() + if (actor != null) { + throw new SparkException("NetworkInputTracker already started") + } + + if (!networkInputStreams.isEmpty) { + actor = ssc.env.actorSystem.actorOf(Props(new NetworkInputTrackerActor), "NetworkInputTracker") + receiverExecutor.start() + logInfo("NetworkInputTracker started") + } } /** Stop the receiver execution thread.
*/ def stop() { - // TODO: stop the actor as well - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() + if (!networkInputStreams.isEmpty && actor != null) { + receiverExecutor.interrupt() + receiverExecutor.stopReceivers() + ssc.env.actorSystem.stop(actor) + logInfo("NetworkInputTracker stopped") + } } /** Return all the blocks received from a receiver. */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 36225e190cd7917502f23debf6a7b9c77b14743e..461ea3506477f8fcfcb6e6cadb1141bf3bf15473 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -24,9 +24,10 @@ import org.apache.spark.util.Distribution sealed trait StreamingListenerEvent case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent - case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent +/** An event used in the listener to shutdown the listener daemon thread. */ +private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent /** * A listener interface for receiving information about an ongoing streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 110a20f282f110879ad7836399f2f9e3784a1ac1..6e6e22e1aff48e2aa1d9efd29c214acbf5612143 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -31,7 +31,7 @@ private[spark] class StreamingListenerBus() extends Logging { private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) private var queueFullErrorMessageLogged = false - new Thread("StreamingListenerBus") { + val listenerThread = new Thread("StreamingListenerBus") { setDaemon(true) override def run() { while (true) { @@ -41,11 +41,18 @@ private[spark] class StreamingListenerBus() extends Logging { listeners.foreach(_.onBatchStarted(batchStarted)) case batchCompleted: StreamingListenerBatchCompleted => listeners.foreach(_.onBatchCompleted(batchCompleted)) + case StreamingListenerShutdown => + // Get out of the while loop and shutdown the daemon thread + return case _ => } } } - }.start() + } + + def start() { + listenerThread.start() + } def addListener(listener: StreamingListener) { listeners += listener @@ -54,9 +61,9 @@ private[spark] class StreamingListenerBus() extends Logging { def post(event: StreamingListenerEvent) { val eventAdded = eventQueue.offer(event) if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") + logError("Dropping StreamingListenerEvent because no remaining room in event queue. 
" + + "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + + "rate at which events are being started by the scheduler.") queueFullErrorMessageLogged = true } } @@ -68,7 +75,7 @@ private[spark] class StreamingListenerBus() extends Logging { */ def waitUntilEmpty(timeoutMillis: Int): Boolean = { val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty()) { + while (!eventQueue.isEmpty) { if (System.currentTimeMillis > finishTime) { return false } @@ -78,4 +85,6 @@ private[spark] class StreamingListenerBus() extends Logging { } return true } + + def stop(): Unit = post(StreamingListenerShutdown) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index d644240405caa478f6b838473a8d7f7475615942..559c2473851b30bf73f9d376126f846edf826c6d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -20,17 +20,7 @@ package org.apache.spark.streaming.util private[streaming] class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - private val minPollTime = 25L - - private val pollTime = { - if (period / 10.0 > minPollTime) { - (period / 10.0).toLong - } else { - minPollTime - } - } - - private val thread = new Thread() { + private val thread = new Thread("RecurringTimer") { override def run() { loop } } @@ -66,7 +56,6 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => callback(nextTime) nextTime += period } - } catch { case e: InterruptedException => } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 9a187ce031f091475cfc961000fd6ac6a1beeba9..9406e0e20a403e83cf7689e63b8faa3421556104 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -375,11 +375,7 @@ class BasicOperationsSuite extends TestSuiteBase { } test("slice") { - val conf2 = new SparkConf() - .setMaster("local[2]") - .setAppName("BasicOperationsSuite") - .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - val ssc = new StreamingContext(new SparkContext(conf2), Seconds(1)) + val ssc = new StreamingContext(conf, Seconds(1)) val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) val stream = new TestInputStream[Int](ssc, input, 2) ssc.registerInputStream(stream) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 6499de98c925e3426ed942114565aef7c161b209..9590bca9892fe6899be0171450915b4de9394f68 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -28,6 +28,8 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.dstream.FileInputDStream import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf /** * This test suites tests the checkpointing functionality of DStreams - @@ -142,6 +144,26 @@ class CheckpointSuite extends TestSuiteBase { ssc 
= null } + // This tests whether the SparkConf persists through checkpoints, and that certain + // configs get scrubbed + test("persistence of conf through checkpoints") { + val key = "spark.mykey" + val value = "myvalue" + System.setProperty(key, value) + ssc = new StreamingContext(master, framework, batchDuration) + val cp = new Checkpoint(ssc, Time(1000)) + assert(!cp.sparkConf.contains("spark.driver.host")) + assert(!cp.sparkConf.contains("spark.driver.port")) + assert(!cp.sparkConf.contains("spark.hostPort")) + assert(cp.sparkConf.get(key) === value) + ssc.stop() + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + assert(!newCp.sparkConf.contains("spark.driver.host")) + assert(!newCp.sparkConf.contains("spark.driver.port")) + assert(!newCp.sparkConf.contains("spark.hostPort")) + assert(newCp.sparkConf.get(key) === value) + } + // This tests whether the systm can recover from a master failure with simple // non-stateful operations. This assumes as reliable, replayable input diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..a477d200c91e37ce97ce0e6fe731df2f6899e9b9 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.streaming + +import org.scalatest.{FunSuite, BeforeAndAfter} +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkException, SparkConf, SparkContext} +import org.apache.spark.util.{Utils, MetadataCleaner} + +class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { + + val master = "local[2]" + val appName = this.getClass.getSimpleName + val batchDuration = Seconds(1) + val sparkHome = "someDir" + val envPair = "key" -> "value" + val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100 + + var sc: SparkContext = null + var ssc: StreamingContext = null + + before { + System.clearProperty("spark.cleaner.ttl") + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + if (sc != null) { + sc.stop() + sc = null + } + } + + test("from no conf constructor") { + ssc = new StreamingContext(master, appName, batchDuration) + assert(ssc.sparkContext.conf.get("spark.master") === master) + assert(ssc.sparkContext.conf.get("spark.app.name") === appName) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from no conf + spark home") { + ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil) + assert(ssc.conf.get("spark.home") === sparkHome) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from no conf + spark home + env") { + ssc = new StreamingContext(master, appName, batchDuration, + sparkHome, Nil, Map(envPair)) + assert(ssc.conf.getExecutorEnv.exists(_ == envPair)) + assert(MetadataCleaner.getDelaySeconds(ssc.sparkContext.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from conf without ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + ssc = new StreamingContext(myConf, batchDuration) + assert(MetadataCleaner.getDelaySeconds(ssc.conf) === + StreamingContext.DEFAULT_CLEANER_TTL) + } + + test("from conf with ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl) + } + + test("from existing SparkContext without ttl set") { + sc = new SparkContext(master, appName) + val exception = intercept[SparkException] { + ssc = new StreamingContext(sc, batchDuration) + } + assert(exception.getMessage.contains("ttl")) + } + + test("from existing SparkContext with ttl set") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + ssc = new StreamingContext(myConf, batchDuration) + assert(ssc.conf.getInt("spark.cleaner.ttl", -1) === ttl) + } + + test("from checkpoint") { + val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) + myConf.set("spark.cleaner.ttl", ttl.toString) + val ssc1 = new StreamingContext(myConf, batchDuration) + val cp = new Checkpoint(ssc1, Time(1000)) + assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl) + ssc1.stop() + val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) + assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl) + ssc = new StreamingContext(null, cp, null) + assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl) + } + + test("start multiple times") { + ssc = new 
StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + + ssc.start() + intercept[SparkException] { + ssc.start() + } + } + + test("stop multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + ssc.start() + ssc.stop() + ssc.stop() + ssc = null + } + + test("stop only streaming context") { + ssc = new StreamingContext(master, appName, batchDuration) + sc = ssc.sparkContext + addInputStream(ssc).register + ssc.start() + ssc.stop(false) + ssc = null + assert(sc.makeRDD(1 to 100).collect().size === 100) + ssc = new StreamingContext(sc, batchDuration) + } + + test("awaitTermination") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => x).register + + // test whether start() blocks indefinitely or not + failAfter(2000 millis) { + ssc.start() + } + + // test whether awaitTermination() exits after the given amount of time + failAfter(1000 millis) { + ssc.awaitTermination(500) + } + + // test whether awaitTermination() does not exit if no timeout is given + val exception = intercept[Exception] { + failAfter(1000 millis) { + ssc.awaitTermination() + throw new Exception("Did not wait for stop") + } + } + assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + + // test whether wait exits if context is stopped + failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown + new Thread() { + override def run { + Thread.sleep(500) + ssc.stop() + } + }.start() + ssc.awaitTermination() + } + } + + test("awaitTermination with error in task") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + inputStream.map(x => { throw new TestException("error in map task"); x}) + .foreach(_.count) + + val exception = intercept[Exception] { + ssc.start() + ssc.awaitTermination(5000) + } + assert(exception.getMessage.contains("map task"), "Expected exception not thrown") + } + + test("awaitTermination with error in job generation") { + ssc = new StreamingContext(master, appName, batchDuration) + val inputStream = addInputStream(ssc) + + inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register + val exception = intercept[TestException] { + ssc.start() + ssc.awaitTermination(5000) + } + assert(exception.getMessage.contains("transform"), "Expected exception not thrown") + } + + def addInputStream(s: StreamingContext): DStream[Int] = { + val input = (1 to 100).map(i => (1 to i)) + val inputStream = new TestInputStream(s, input, 1) + s.registerInputStream(inputStream) + inputStream + } +} + +class TestException(msg: String) extends Exception(msg) \ No newline at end of file diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index b20d02f99681e87879c31a59778a9fc24388b6ab..63a07cfbdfa5aaf66f75ee84bbbb9808e380f500 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -137,7 +137,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val conf = new SparkConf() .setMaster(master) .setAppName(framework) - .set("spark.cleaner.ttl", "3600") + .set("spark.cleaner.ttl", StreamingContext.DEFAULT_CLEANER_TTL.toString) // Default before function for any streaming test suite.
Override this // if you want to add your stuff to "before" (i.e., don't call before { } ) @@ -273,10 +273,11 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val startTime = System.currentTimeMillis() while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput) - Thread.sleep(10) + ssc.awaitTermination(50) } val timeTaken = System.currentTimeMillis() - startTime - + logInfo("Output generated in " + timeTaken + " milliseconds") + output.foreach(x => logInfo("[" + x.mkString(",") + "]")) assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
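
For context, here is a minimal sketch (not part of the patch) of how the lifecycle API introduced above is intended to be used from an application. The master URL, app name, and socket source are placeholder assumptions chosen for illustration.

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

object StreamingLifecycleExample {
  def main(args: Array[String]) {
    // With this patch, spark.cleaner.ttl defaults to DEFAULT_CLEANER_TTL (3600s)
    // when the StreamingContext creates its own SparkContext, so it no longer
    // needs to be set by hand for this constructor.
    val ssc = new StreamingContext("local[2]", "StreamingLifecycleExample", Seconds(1))

    // Placeholder pipeline; print() registers an output stream, which start() requires.
    val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
    words.map(x => (x, 1)).reduceByKey(_ + _).print()

    ssc.start()

    // Block until stop() is called from another thread. Any error reported to
    // ContextWaiter.notifyError by the JobScheduler is rethrown here, which is
    // what the "awaitTermination with error in ..." tests above exercise.
    ssc.awaitTermination()
  }
}

To tear down only the streaming machinery while reusing the underlying SparkContext, call the new stop(stopSparkContext = false) overload instead, as the "stop only streaming context" test demonstrates.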