diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 1b1fc8051d052dde6494dcea6a8055f7b72b8894..6715aede7928a57917969bde0640e889b63dc880 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.streaming.kafka
 
-
 import scala.annotation.tailrec
 import scala.collection.mutable
 import scala.reflect.{classTag, ClassTag}
@@ -27,10 +26,10 @@ import kafka.message.MessageAndMetadata
 import kafka.serializer.Decoder
 
 import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
 import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
+import org.apache.spark.streaming.scheduler.InputInfo
 
 /**
  * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where
@@ -117,6 +116,11 @@ class DirectKafkaInputDStream[
     val rdd = KafkaRDD[K, V, U, T, R](
       context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)
 
+    // Report the record number of this batch interval to InputInfoTracker.
+    val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
+    val inputInfo = InputInfo(id, numRecords)
+    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
     currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)
     Some(rdd)
   }
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 415730f5559c59eb819954cc1246d3d842762529..b6d314dfc77838f63da844cc814019e7ef2301e7 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.kafka
 
 import java.io.File
+import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
@@ -34,6 +35,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
 import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.Utils
 
 class DirectKafkaStreamSuite
@@ -290,7 +292,6 @@ class DirectKafkaStreamSuite
       },
       "Recovered ranges are not the same as the ones generated"
     )
-
     // Restart context, give more data and verify the total at the end
     // If the total is write that means each records has been received only once
    ssc.start()
@@ -301,6 +302,44 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }
 
+  test("Direct Kafka stream report input information") {
+    val topic = "report-test"
+    val data = Map("a" -> 7, "b" -> 9)
+    kafkaTestUtils.createTopic(topic)
+    kafkaTestUtils.sendMessages(topic, data)
+
+    val totalSent = data.values.sum
+    val kafkaParams = Map(
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+      "auto.offset.reset" -> "smallest"
+    )
+
+    import DirectKafkaStreamSuite._
+    ssc = new StreamingContext(sparkConf, Milliseconds(200))
+    val collector = new InputInfoCollector
+    ssc.addStreamingListener(collector)
+
+    val stream = withClue("Error creating direct stream") {
+      KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+        ssc, kafkaParams, Set(topic))
+    }
+
+    val allReceived = new ArrayBuffer[(String, String)]
+
+    stream.foreachRDD { rdd => allReceived ++= rdd.collect() }
+    ssc.start()
+    eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+      assert(allReceived.size === totalSent,
+        "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n"))
+
+      // Calculate all the record number collected in the StreamingListener.
+      assert(collector.numRecordsSubmitted.get() === totalSent)
+      assert(collector.numRecordsStarted.get() === totalSent)
+      assert(collector.numRecordsCompleted.get() === totalSent)
+    }
+    ssc.stop()
+  }
+
   /** Get the generated offset ranges from the DirectKafkaStream */
   private def getOffsetRanges[K, V](
       kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
@@ -313,4 +352,22 @@ class DirectKafkaStreamSuite
 object DirectKafkaStreamSuite {
   val collectedData = new mutable.ArrayBuffer[String]()
   var total = -1L
+
+  class InputInfoCollector extends StreamingListener {
+    val numRecordsSubmitted = new AtomicLong(0L)
+    val numRecordsStarted = new AtomicLong(0L)
+    val numRecordsCompleted = new AtomicLong(0L)
+
+    override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+      numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
+    }
+
+    override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {
+      numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords)
+    }
+
+    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+      numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
+    }
+  }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
index d2729fa70d6d21c694340a853c71b486eaebfbef..24cbb2bf9d8fe96a2480475e76207e8d46e0e9ab 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala
@@ -192,8 +192,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
     val latestReceiverNumRecords = latestBatchInfos.map(_.receiverNumRecords)
     val streamIds = ssc.graph.getInputStreams().map(_.id)
     streamIds.map { id =>
-      val recordsOfParticularReceiver =
-        latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
+      val recordsOfParticularReceiver =
+        latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
       val distribution = Distribution(recordsOfParticularReceiver)
       (id, distribution)
     }.toMap
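
A minimal, self-contained sketch (not part of the patch) of the record-count arithmetic the change introduces: the reported numRecords for a batch is the per-partition sum of untilOffset - fromOffset, which is also the total the test's InputInfoCollector expects to see on submitted, started, and completed batches. The OffsetRange case class below is a hypothetical stand-in for org.apache.spark.streaming.kafka.OffsetRange, and the sample values mirror the test data ("a" -> 7, "b" -> 9).

object InputInfoRecordCountSketch {
  // Hypothetical stand-in for org.apache.spark.streaming.kafka.OffsetRange.
  case class OffsetRange(topic: String, partition: Int, fromOffset: Long, untilOffset: Long)

  // Same arithmetic as the line added to DirectKafkaInputDStream.compute():
  // sum of (untilOffset - fromOffset) across all partitions of the batch's RDD.
  def numRecords(ranges: Seq[OffsetRange]): Long =
    ranges.map(r => r.untilOffset - r.fromOffset).sum

  def main(args: Array[String]): Unit = {
    // Two partitions with 7 and 9 records, matching the test's totalSent of 16.
    val ranges = Seq(
      OffsetRange("report-test", 0, 0L, 7L),
      OffsetRange("report-test", 1, 0L, 9L))
    println(numRecords(ranges)) // prints 16
  }
}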