From 1b6fe9b1a70aa3f81448c2705ea3a4b501cbda9d Mon Sep 17 00:00:00 2001
From: cody koeninger <cody@koeninger.org>
Date: Fri, 19 Jun 2015 18:54:07 -0700
Subject: [PATCH] [SPARK-8127] [STREAMING] [KAFKA] KafkaRDD optimize count()
 take() isEmpty()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…ed KafkaRDD methods.  Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless.

Author: cody koeninger <cody@koeninger.org>

Closes #6632 from koeninger/kafka-rdd-count and squashes the following commits:

321340d [cody koeninger] [SPARK-8127][Streaming][Kafka] additional test of ordering of take()
5a05d0f [cody koeninger] [SPARK-8127][Streaming][Kafka] additional test of isEmpty
f68bd32 [cody koeninger] [Streaming][Kafka][SPARK-8127] code cleanup
9555b73 [cody koeninger] Merge branch 'master' into kafka-rdd-count
253031d [cody koeninger] [Streaming][Kafka][SPARK-8127] mima exclusion for change to private method
8974b9e [cody koeninger] [Streaming][Kafka][SPARK-8127] check offset ranges before constructing KafkaRDD
c3768c5 [cody koeninger] [Streaming][Kafka] Take advantage of offset range info for size-related KafkaRDD methods.  Possible fix for [SPARK-7122], but probably a worthwhile optimization regardless.
---
 .../kafka/DirectKafkaInputDStream.scala       |  8 +---
 .../spark/streaming/kafka/KafkaCluster.scala  |  8 ++++
 .../spark/streaming/kafka/KafkaRDD.scala      | 44 ++++++++++++++++++
 .../streaming/kafka/KafkaRDDPartition.scala   |  5 +-
 .../spark/streaming/kafka/KafkaUtils.scala    | 46 +++++++++++++------
 .../spark/streaming/kafka/OffsetRange.scala   |  6 +++
 .../spark/streaming/kafka/KafkaRDDSuite.scala | 26 +++++++++--
 project/MimaExcludes.scala                    |  3 ++
 8 files changed, 122 insertions(+), 24 deletions(-)

diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 060c2f23ed..876456c964 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -120,8 +120,7 @@ class DirectKafkaInputDStream[
       context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)
 
     // Report the record number of this batch interval to InputInfoTracker.
-    val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
-    val inputInfo = InputInfo(id, numRecords)
+    val inputInfo = InputInfo(id, rdd.count)
     ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
 
     currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)
@@ -153,10 +152,7 @@ class DirectKafkaInputDStream[
     override def restore() {
       // this is assuming that the topics don't change during execution, which is true currently
       val topics = fromOffsets.keySet
-      val leaders = kc.findLeaders(topics).fold(
-        errs => throw new SparkException(errs.mkString("\n")),
-        ok => ok
-      )
+      val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics))
 
       batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) =>
           logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index 65d51d87f8..3e6b937af5 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -360,6 +360,14 @@ private[spark]
 object KafkaCluster {
   type Err = ArrayBuffer[Throwable]
 
+  /** If the result is right, return it, otherwise throw SparkException */
+  def checkErrors[T](result: Either[Err, T]): T = {
+    result.fold(
+      errs => throw new SparkException(errs.mkString("\n")),
+      ok => ok
+    )
+  }
+
   private[spark]
   case class LeaderOffset(host: String, port: Int, offset: Long)
 
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index a1b4a12e5d..c5cd215477 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.streaming.kafka
 
+import scala.collection.mutable.ArrayBuffer
 import scala.reflect.{classTag, ClassTag}
 
 import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext}
+import org.apache.spark.partial.{PartialResult, BoundedDouble}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.NextIterator
 
@@ -60,6 +62,48 @@ class KafkaRDD[
     }.toArray
   }
 
+  override def count(): Long = offsetRanges.map(_.count).sum
+
+  override def countApprox(
+      timeout: Long,
+      confidence: Double = 0.95
+  ): PartialResult[BoundedDouble] = {
+    val c = count
+    new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
+  }
+
+  override def isEmpty(): Boolean = count == 0L
+
+  override def take(num: Int): Array[R] = {
+    val nonEmptyPartitions = this.partitions
+      .map(_.asInstanceOf[KafkaRDDPartition])
+      .filter(_.count > 0)
+
+    if (num < 1 || nonEmptyPartitions.size < 1) {
+      return new Array[R](0)
+    }
+
+    // Determine in advance how many messages need to be taken from each partition
+    val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
+      val remain = num - result.values.sum
+      if (remain > 0) {
+        val taken = Math.min(remain, part.count)
+        result + (part.index -> taken.toInt)
+      } else {
+        result
+      }
+    }
+
+    val buf = new ArrayBuffer[R]
+    val res = context.runJob(
+      this,
+      (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray,
+      parts.keys.toArray,
+      allowLocal = true)
+    res.foreach(buf ++= _)
+    buf.toArray
+  }
+
   override def getPreferredLocations(thePart: Partition): Seq[String] = {
     val part = thePart.asInstanceOf[KafkaRDDPartition]
     // TODO is additional hostname resolution necessary here
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
index a842a6f177..a660d2a00c 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
@@ -35,4 +35,7 @@ class KafkaRDDPartition(
   val untilOffset: Long,
   val host: String,
   val port: Int
-) extends Partition
+) extends Partition {
+  /** Number of messages this partition refers to */
+  def count(): Long = untilOffset - fromOffset
+}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 0b8a391a2c..0e33362d34 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -158,15 +158,31 @@ object KafkaUtils {
 
   /** get leaders for the given offset ranges, or throw an exception */
   private def leadersForRanges(
-      kafkaParams: Map[String, String],
+      kc: KafkaCluster,
       offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = {
-    val kc = new KafkaCluster(kafkaParams)
     val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet
-    val leaders = kc.findLeaders(topics).fold(
-      errs => throw new SparkException(errs.mkString("\n")),
-      ok => ok
-    )
-    leaders
+    val leaders = kc.findLeaders(topics)
+    KafkaCluster.checkErrors(leaders)
+  }
+
+  /** Make sure offsets are available in kafka, or throw an exception */
+  private def checkOffsets(
+      kc: KafkaCluster,
+      offsetRanges: Array[OffsetRange]): Unit = {
+    val topics = offsetRanges.map(_.topicAndPartition).toSet
+    val result = for {
+      low <- kc.getEarliestLeaderOffsets(topics).right
+      high <- kc.getLatestLeaderOffsets(topics).right
+    } yield {
+      offsetRanges.filterNot { o =>
+        low(o.topicAndPartition).offset <= o.fromOffset &&
+        o.untilOffset <= high(o.topicAndPartition).offset
+      }
+    }
+    val badRanges = KafkaCluster.checkErrors(result)
+    if (!badRanges.isEmpty) {
+      throw new SparkException("Offsets not available on leader: " + badRanges.mkString(","))
+    }
   }
 
   /**
@@ -191,7 +207,9 @@ object KafkaUtils {
       offsetRanges: Array[OffsetRange]
     ): RDD[(K, V)] = sc.withScope {
     val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
-    val leaders = leadersForRanges(kafkaParams, offsetRanges)
+    val kc = new KafkaCluster(kafkaParams)
+    val leaders = leadersForRanges(kc, offsetRanges)
+    checkOffsets(kc, offsetRanges)
     new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler)
   }
 
@@ -225,8 +243,9 @@ object KafkaUtils {
       leaders: Map[TopicAndPartition, Broker],
       messageHandler: MessageAndMetadata[K, V] => R
     ): RDD[R] = sc.withScope {
+    val kc = new KafkaCluster(kafkaParams)
     val leaderMap = if (leaders.isEmpty) {
-      leadersForRanges(kafkaParams, offsetRanges)
+      leadersForRanges(kc, offsetRanges)
     } else {
       // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
       leaders.map {
@@ -234,6 +253,7 @@ object KafkaUtils {
       }.toMap
     }
     val cleanedHandler = sc.clean(messageHandler)
+    checkOffsets(kc, offsetRanges)
     new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, cleanedHandler)
   }
 
@@ -399,7 +419,7 @@ object KafkaUtils {
     val kc = new KafkaCluster(kafkaParams)
     val reset = kafkaParams.get("auto.offset.reset").map(_.toLowerCase)
 
-    (for {
+    val result = for {
       topicPartitions <- kc.getPartitions(topics).right
       leaderOffsets <- (if (reset == Some("smallest")) {
         kc.getEarliestLeaderOffsets(topicPartitions)
@@ -412,10 +432,8 @@ object KafkaUtils {
       }
       new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
         ssc, kafkaParams, fromOffsets, messageHandler)
-    }).fold(
-      errs => throw new SparkException(errs.mkString("\n")),
-      ok => ok
-    )
+    }
+    KafkaCluster.checkErrors(result)
   }
 
   /**
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
index 9c3dfeb8f5..2675042666 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
@@ -55,6 +55,12 @@ final class OffsetRange private(
     val untilOffset: Long) extends Serializable {
   import OffsetRange.OffsetRangeTuple
 
+  /** Kafka TopicAndPartition object, for convenience */
+  def topicAndPartition(): TopicAndPartition = TopicAndPartition(topic, partition)
+
+  /** Number of messages this OffsetRange refers to */
+  def count(): Long = untilOffset - fromOffset
+
   override def equals(obj: Any): Boolean = obj match {
     case that: OffsetRange =>
       this.topic == that.topic &&
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index d5baf5fd89..f52a738afd 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -55,8 +55,8 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
   test("basic usage") {
     val topic = s"topicbasic-${Random.nextInt}"
     kafkaTestUtils.createTopic(topic)
-    val messages = Set("the", "quick", "brown", "fox")
-    kafkaTestUtils.sendMessages(topic, messages.toArray)
+    val messages = Array("the", "quick", "brown", "fox")
+    kafkaTestUtils.sendMessages(topic, messages)
 
     val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "group.id" -> s"test-consumer-${Random.nextInt}")
@@ -67,7 +67,27 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
       sc, kafkaParams, offsetRanges)
 
     val received = rdd.map(_._2).collect.toSet
-    assert(received === messages)
+    assert(received === messages.toSet)
+
+    // size-related method optimizations return sane results
+    assert(rdd.count === messages.size)
+    assert(rdd.countApprox(0).getFinalValue.mean === messages.size)
+    assert(!rdd.isEmpty)
+    assert(rdd.take(1).size === 1)
+    assert(rdd.take(1).head._2 === messages.head)
+    assert(rdd.take(messages.size + 10).size === messages.size)
+
+    val emptyRdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
+      sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)))
+
+    assert(emptyRdd.isEmpty)
+
+    // invalid offset ranges throw exceptions
+    val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1))
+    intercept[SparkException] {
+      KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
+        sc, kafkaParams, badRanges)
+    }
   }
 
   test("iterator boundary conditions") {
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 8a93ca2999..015d0296dd 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -44,6 +44,9 @@ object MimaExcludes {
             // JavaRDDLike is not meant to be extended by user programs
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaRDDLike.partitioner"),
+            // Modification of private static method
+            ProblemFilters.exclude[IncompatibleMethTypeProblem](
+              "org.apache.spark.streaming.kafka.KafkaUtils.org$apache$spark$streaming$kafka$KafkaUtils$$leadersForRanges"),
             // Mima false positive (was a private[spark] class)
             ProblemFilters.exclude[MissingClassProblem](
               "org.apache.spark.util.collection.PairIterator"),
-- 
GitLab