diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 54d736ee5101b97e52d02cba9c2ca707b0efc7f0..dce2028b4887811af2762ba19d99cedc4bdbed36 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -31,12 +31,15 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
   private val inputStreams = new ArrayBuffer[InputDStream[_]]()
   private val outputStreams = new ArrayBuffer[DStream[_]]()
 
+  @volatile private var inputStreamNameAndID: Seq[(String, Int)] = Nil
+
   var rememberDuration: Duration = null
   var checkpointInProgress = false
 
   var zeroTime: Time = null
   var startTime: Time = null
   var batchDuration: Duration = null
+  @volatile private var numReceivers: Int = 0
 
   def start(time: Time) {
     this.synchronized {
@@ -45,7 +48,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
       startTime = time
       outputStreams.foreach(_.initialize(zeroTime))
       outputStreams.foreach(_.remember(rememberDuration))
-      outputStreams.foreach(_.validateAtStart)
+      outputStreams.foreach(_.validateAtStart())
+      numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
+      inputStreamNameAndID = inputStreams.map(is => (is.name, is.id))
       inputStreams.par.foreach(_.start())
     }
   }
@@ -106,9 +111,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
       .toArray
   }
 
-  def getInputStreamName(streamId: Int): Option[String] = synchronized {
-    inputStreams.find(_.id == streamId).map(_.name)
-  }
+  def getNumReceivers: Int = numReceivers
+
+  def getInputStreamNameAndID: Seq[(String, Int)] = inputStreamNameAndID
 
   def generateJobs(time: Time): Seq[Job] = {
     logDebug("Generating jobs for time " + time)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index 95f582106c71f9d5863154b7fda99c18bab9cad4..ed4c1e484efd2ec725b676c92d35dc045b335b8f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -169,7 +169,7 @@ private[spark] class StreamingJobProgressListener(ssc: StreamingContext)
   }
 
   def numInactiveReceivers: Int = {
-    ssc.graph.getReceiverInputStreams().length - numActiveReceivers
+    ssc.graph.getNumReceivers - numActiveReceivers
   }
 
   def numTotalCompletedBatches: Long = synchronized {
@@ -197,17 +197,17 @@ private[spark] class StreamingJobProgressListener(ssc: StreamingContext)
   }
 
   def retainedCompletedBatches: Seq[BatchUIData] = synchronized {
-    completedBatchUIData.toSeq
+    completedBatchUIData.toIndexedSeq
   }
 
   def streamName(streamId: Int): Option[String] = {
-    ssc.graph.getInputStreamName(streamId)
+    ssc.graph.getInputStreamNameAndID.find(_._2 == streamId).map(_._1)
   }
 
   /**
    * Return all InputDStream Ids
    */
-  def streamIds: Seq[Int] = ssc.graph.getInputStreams().map(_.id)
+  def streamIds: Seq[Int] = ssc.graph.getInputStreamNameAndID.map(_._2)
 
   /**
    * Return all of the record rates for each InputDStream in each batch. The key of the return value
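
For context: the core of this patch is replacing the lock-taking lookup (the old getInputStreamName was synchronized on the DStreamGraph) with @volatile snapshots (inputStreamNameAndID, numReceivers) captured once under the lock in start(), so the UI listener can read them without contending on the graph lock. A minimal standalone sketch of that publication pattern, using hypothetical names rather than Spark's API:

import scala.collection.mutable.ArrayBuffer

// Sketch of the @volatile-snapshot pattern applied in the patch above.
final class Graph {
  private val inputs = new ArrayBuffer[(String, Int)]() // guarded by `this`

  // Immutable snapshot, safely published to reader threads via @volatile.
  @volatile private var nameAndId: Seq[(String, Int)] = Nil

  def addInput(name: String, id: Int): Unit = this.synchronized {
    inputs += ((name, id))
  }

  def start(): Unit = this.synchronized {
    // Capture an immutable copy under the lock; readers never lock again.
    nameAndId = inputs.toList
  }

  // Lock-free reads, mirroring getInputStreamNameAndID / streamName above.
  def inputStreamNameAndID: Seq[(String, Int)] = nameAndId
  def streamName(id: Int): Option[String] = nameAndId.find(_._2 == id).map(_._1)
}

This is safe because the snapshot is only assigned while holding the lock and the collection it points to is immutable, so readers either see the old snapshot or the new one, never a partially built list.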