diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 54d8c8b03f2060d4cf0fec395a1c7f235f37d2dd..0eaaf408c0112962b34735a9551935b8d6e4f0d3 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -89,23 +89,32 @@ class DirectKafkaInputDStream[
 
   private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
       "spark.streaming.kafka.maxRatePerPartition", 0)
-  protected def maxMessagesPerPartition: Option[Long] = {
+
+  protected[streaming] def maxMessagesPerPartition(
+      offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = {
     val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
-    val numPartitions = currentOffsets.keys.size
-
-    val effectiveRateLimitPerPartition = estimatedRateLimit
-      .filter(_ > 0)
-      .map { limit =>
-        if (maxRateLimitPerPartition > 0) {
-          Math.min(maxRateLimitPerPartition, (limit / numPartitions))
-        } else {
-          limit / numPartitions
+
+    // calculate a per-partition rate limit based on current lag
+    val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
+      case Some(rate) =>
+        val lagPerPartition = offsets.map { case (tp, offset) =>
+          tp -> Math.max(offset - currentOffsets(tp), 0)
+        }
+        val totalLag = lagPerPartition.values.sum
+
+        lagPerPartition.map { case (tp, lag) =>
+          val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+          tp -> (if (maxRateLimitPerPartition > 0) {
+            Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
         }
-      }.getOrElse(maxRateLimitPerPartition)
+      case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+    }
 
-    if (effectiveRateLimitPerPartition > 0) {
+    if (effectiveRateLimitPerPartition.values.sum > 0) {
       val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
-      Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
+      Some(effectiveRateLimitPerPartition.map {
+        case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+      })
     } else {
       None
     }
@@ -134,9 +143,12 @@ class DirectKafkaInputDStream[
   // limits the maximum number of messages per partition
   protected def clamp(
     leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = {
-    maxMessagesPerPartition.map { mmp =>
-      leaderOffsets.map { case (tp, lo) =>
-        tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset))
+    val offsets = leaderOffsets.mapValues(lo => lo.offset)
+
+    maxMessagesPerPartition(offsets).map { mmp =>
+      mmp.map { case (tp, messages) =>
+        val lo = leaderOffsets(tp)
+        tp -> lo.copy(offset = Math.min(currentOffsets(tp) + messages, lo.offset))
       }
     }.getOrElse(leaderOffsets)
   }
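// ---------------------------------------------------------------------------
// Aside (not part of the patch): a minimal standalone sketch of the
// lag-proportional allocation the reworked maxMessagesPerPartition above
// performs. Names here (allocateRate, the (String, Int) partition stand-in)
// are hypothetical; the real code uses TopicAndPartition and the estimator's
// latest rate. Each partition gets a share of the estimated rate proportional
// to its lag, optionally capped by spark.streaming.kafka.maxRatePerPartition.
// ---------------------------------------------------------------------------
object BackpressureAllocationSketch {
  type Partition = (String, Int) // stand-in for kafka.common.TopicAndPartition

  def allocateRate(
      latestOffsets: Map[Partition, Long],  // leader offsets for the next batch
      currentOffsets: Map[Partition, Long], // offsets already consumed
      estimatedRate: Long,                  // total records/sec from the estimator
      maxRatePerPartition: Int): Map[Partition, Long] = { // 0 means no static cap
    val lag = latestOffsets.map { case (tp, offset) =>
      tp -> math.max(offset - currentOffsets(tp), 0L)
    }
    val totalLag = lag.values.sum
    lag.map { case (tp, l) =>
      // More lag means a proportionally larger slice of the estimated rate.
      val share = if (totalLag == 0) 0L else math.round(l.toDouble / totalLag * estimatedRate)
      tp -> (if (maxRatePerPartition > 0) math.min(share, maxRatePerPartition.toLong) else share)
    }
  }

  def main(args: Array[String]): Unit = {
    val latest = Map(("t", 0) -> 1000L, ("t", 1) -> 100L)
    val current = Map(("t", 0) -> 0L, ("t", 1) -> 0L)
    // With an estimated rate of 110 records/sec and no cap, the partition with
    // lag 1000 gets 100 records/sec and the one with lag 100 gets 10.
    println(allocateRate(latest, current, 110L, 0))
  }
}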
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index a76fa6671a4b08e293a052f55859c4f7895eb923..a5ea1d6d2848d3ce6b5f04eb2654e8af5aa43502 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -152,12 +152,15 @@ private[kafka] class KafkaTestUtils extends Logging {
   }
 
   /** Create a Kafka topic and wait until it is propagated to the whole cluster */
-  def createTopic(topic: String): Unit = {
-    AdminUtils.createTopic(zkClient, topic, 1, 1)
+  def createTopic(topic: String, partitions: Int): Unit = {
+    AdminUtils.createTopic(zkClient, topic, partitions, 1)
     // wait until metadata is propagated
-    waitUntilMetadataIsPropagated(topic, 0)
+    (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) }
   }
 
+  /** Single-argument version for backwards compatibility */
+  def createTopic(topic: String): Unit = createTopic(topic, 1)
+
   /** Java-friendly function for sending messages to the Kafka broker */
   def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
     sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*))
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index 4891e4f4a17bc5d4dd3fd38efb6104c7b51d6599..fa6b0dbc8c2197b4087dd717246966a366913571 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -168,7 +168,7 @@ public class JavaDirectKafkaStreamSuite implements Serializable {
 
   private String[] createTopicAndSendData(String topic) {
     String[] data = { topic + "-1", topic + "-2", topic + "-3"};
-    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.createTopic(topic, 1);
     kafkaTestUtils.sendMessages(topic, data);
     return data;
   }
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index afcc6cfccd39a8ade454e22cf226d15c8d7bc728..c41b6297b0481a78d80a234fa239cfe090b81e58 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -149,7 +149,7 @@ public class JavaKafkaRDDSuite implements Serializable {
 
   private String[] createTopicAndSendData(String topic) {
     String[] data = { topic + "-1", topic + "-2", topic + "-3"};
-    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.createTopic(topic, 1);
     kafkaTestUtils.sendMessages(topic, data);
     return data;
   }
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 617c92a008fc54c2ac2a3f421201dfca424dab19..868df64e8c94449ad5cd39d1c85f9ba68b636fd7 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -76,7 +76,7 @@ public class JavaKafkaStreamSuite implements Serializable {
     sent.put("b", 3);
     sent.put("c", 10);
 
-    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.createTopic(topic, 1);
     kafkaTestUtils.sendMessages(topic, sent);
 
     Map<String, String> kafkaParams = new HashMap<>();
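// ---------------------------------------------------------------------------
// Aside (not part of the patch): how a test might call the new two-argument
// createTopic from the KafkaTestUtils change above. setup(), teardown() and
// sendMessages(topic, Map[String, Int]) exist on KafkaTestUtils, but this
// snippet itself is an illustrative sketch only.
// ---------------------------------------------------------------------------
val kafkaTestUtils = new KafkaTestUtils()
kafkaTestUtils.setup()
try {
  kafkaTestUtils.createTopic("multi", 2) // new overload: two partitions
  kafkaTestUtils.createTopic("single")   // old signature, still one partition
  kafkaTestUtils.sendMessages("multi", Map("a" -> 5, "b" -> 5))
} finally {
  kafkaTestUtils.teardown()
}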
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 8398178e9b79b8f3e7fc5752ee070425b6dd6330..b2c81d1534ee687b005c83098d42b227dbdca1bb 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -353,10 +353,38 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }
 
+  test("maxMessagesPerPartition with backpressure disabled") {
+    val topic = "maxMessagesPerPartition"
+    val kafkaStream = getDirectKafkaStream(topic, None)
+
+    val input = Map(TopicAndPartition(topic, 0) -> 50L, TopicAndPartition(topic, 1) -> 50L)
+    assert(kafkaStream.maxMessagesPerPartition(input).get ==
+      Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L))
+  }
+
+  test("maxMessagesPerPartition with no lag") {
+    val topic = "maxMessagesPerPartition"
+    val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100))
+    val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+    val input = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L)
+    assert(kafkaStream.maxMessagesPerPartition(input).isEmpty)
+  }
+
+  test("maxMessagesPerPartition respects max rate") {
+    val topic = "maxMessagesPerPartition"
+    val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000))
+    val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+    val input = Map(TopicAndPartition(topic, 0) -> 1000L, TopicAndPartition(topic, 1) -> 1000L)
+    assert(kafkaStream.maxMessagesPerPartition(input).get ==
+      Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L))
+  }
+
   test("using rate controller") {
     val topic = "backpressure"
-    val topicPartition = TopicAndPartition(topic, 0)
-    kafkaTestUtils.createTopic(topic)
+    val topicPartitions = Set(TopicAndPartition(topic, 0), TopicAndPartition(topic, 1))
+    kafkaTestUtils.createTopic(topic, 2)
     val kafkaParams = Map(
       "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "smallest"
@@ -364,8 +392,8 @@ class DirectKafkaStreamSuite
 
     val batchIntervalMilliseconds = 100
     val estimator = new ConstantEstimator(100)
-    val messageKeys = (1 to 200).map(_.toString)
-    val messages = messageKeys.map((_, 1)).toMap
+    val messages = Map("foo" -> 200)
+    kafkaTestUtils.sendMessages(topic, messages)
 
     val sparkConf = new SparkConf()
       // Safe, even with streaming, because we're using the direct API.
@@ -380,11 +408,11 @@ class DirectKafkaStreamSuite
     val kafkaStream = withClue("Error creating direct stream") {
       val kc = new KafkaCluster(kafkaParams)
       val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
-      val m = kc.getEarliestLeaderOffsets(Set(topicPartition))
+      val m = kc.getEarliestLeaderOffsets(topicPartitions)
         .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset))
 
       new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
-        ssc, kafkaParams, m, messageHandler) {
+          ssc, kafkaParams, m, messageHandler) {
         override protected[streaming] val rateController =
           Some(new DirectKafkaRateController(id, estimator))
       }
@@ -405,13 +433,12 @@ class DirectKafkaStreamSuite
     ssc.start()
 
     // Try different rate limits.
-    // Send data to Kafka and wait for arrays of data to appear matching the rate.
+    // Wait for arrays of data to appear matching the rate.
     Seq(100, 50, 20).foreach { rate =>
       collectedData.clear()      // Empty this buffer on each pass.
      estimator.updateRate(rate) // Set a new rate.
       // Expect blocks of data equal to "rate", scaled by the interval length in secs.
       val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
-      kafkaTestUtils.sendMessages(topic, messages)
       eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
         // Assert that rate estimator values are used to determine maxMessagesPerPartition.
         // Funky "-" in message makes the complete assertion message read better.
@@ -430,6 +457,25 @@ class DirectKafkaStreamSuite
         rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges
     }.toSeq.sortBy { _._1 }
   }
+
+  private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = {
+    val batchIntervalMilliseconds = 100
+
+    val sparkConf = new SparkConf()
+      .setMaster("local[1]")
+      .setAppName(this.getClass.getSimpleName)
+      .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+    // Setup the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+    val earliestOffsets = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L)
+    val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+    new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+      ssc, Map[String, String](), earliestOffsets, messageHandler) {
+      override protected[streaming] val rateController = mockRateController
+    }
+  }
 }
 
 object DirectKafkaStreamSuite {
@@ -468,3 +514,9 @@ private[streaming] class ConstantEstimator(@volatile private var rate: Long)
     processingDelay: Long,
     schedulingDelay: Long): Option[Double] = Some(rate)
 }
+
+private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long)
+  extends RateController(id, estimator) {
+  override def publish(rate: Long): Unit = ()
+  override def getLatestRate(): Long = rate
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9ce37fc753c46a20f34db4b1dd6ca75a55fa41b9..983f71684c38b46aa6d1de5ce3a9db9eb2ad18c1 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -288,6 +288,10 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
+      ) ++ Seq(
+        // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
       )
     case v if v.startsWith("1.6") =>
       Seq(
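// ---------------------------------------------------------------------------
// Aside (not part of the patch): the arithmetic the reworked clamp() applies
// for one 100 ms batch, with hypothetical numbers. Per-partition caps come
// from the lag-proportional allocation; each leader offset is then clamped so
// no more than that many messages are read in the batch.
// ---------------------------------------------------------------------------
object ClampArithmeticSketch {
  def main(args: Array[String]): Unit = {
    val secsPerBatch = 0.1                                // 100 ms batch interval
    val current = Map("p0" -> 0L, "p1" -> 0L)             // offsets already consumed
    val leader = Map("p0" -> 1000L, "p1" -> 100L)         // latest leader offsets
    val ratePerPartition = Map("p0" -> 100L, "p1" -> 10L) // from the allocator

    val clamped = leader.map { case (tp, leaderOffset) =>
      val maxMessages = (secsPerBatch * ratePerPartition(tp)).toLong
      tp -> math.min(current(tp) + maxMessages, leaderOffset)
    }
    // p0 advances 10 offsets this batch and p1 advances 1, so the partition
    // with more lag drains proportionally faster.
    println(clamped) // Map(p0 -> 10, p1 -> 1)
  }
}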