From e87075df977a539e4a1684045a7bd66c36285174 Mon Sep 17 00:00:00 2001
From: jerryshao <saisai.shao@intel.com>
Date: Tue, 5 Aug 2014 10:40:28 -0700
Subject: [PATCH] [SPARK-1022][Streaming] Add Kafka real unit test

This PR is an updated version of https://github.com/apache/spark/pull/557 that actually tests sending and receiving data through Kafka and fixes the previously flaky behavior.
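
At a high level, the new KafkaStreamSuite spins up an embedded ZooKeeper and Kafka broker, publishes a known set of messages, consumes them through KafkaUtils.createStream, and asserts on the per-key counts. A condensed sketch of that flow, taken from the Scala test below (createTopic, produceAndSendMessage, zkConnect, and random are the helpers introduced by this patch):

    val topic = "topic1"
    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
    createTopic(topic)                          // create the topic on the embedded broker
    produceAndSendMessage(topic, sent)          // publish each key the given number of times

    val kafkaParams = Map(
      "zookeeper.connect" -> zkConnect,
      "group.id" -> s"test-consumer-${random.nextInt(10000)}",
      "auto.offset.reset" -> "smallest")        // start from the earliest offset

    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY)

    stream.map { case (_, v) => v }.countByValue().foreachRDD { rdd =>
      // accumulate counts across batches; after awaitTermination, compare against `sent`
    }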

@tdas, would you mind reviewing this PR? Thanks a lot.

Author: jerryshao <saisai.shao@intel.com>

Closes #1751 from jerryshao/kafka-unit-test and squashes the following commits:

b6a505f [jerryshao] code refactor according to comments
5222330 [jerryshao] Change JavaKafkaStreamSuite to better test it
5525f10 [jerryshao] Fix flaky issue of Kafka real unit test
4559310 [jerryshao] Minor changes for Kafka unit test
860f649 [jerryshao] Minor style changes, and tests ignored due to flakiness
796d4ca [jerryshao] Add real Kafka streaming test
---
 external/kafka/pom.xml                        |   6 +
 .../streaming/kafka/JavaKafkaStreamSuite.java | 125 +++++++++--
 .../streaming/kafka/KafkaStreamSuite.scala    | 197 ++++++++++++++++--
 3 files changed, 293 insertions(+), 35 deletions(-)

diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index daf03360bc..2aee999492 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -70,6 +70,12 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <dependency>
+      <groupId>net.sf.jopt-simple</groupId>
+      <artifactId>jopt-simple</artifactId>
+      <version>3.2</version>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.scalatest</groupId>
       <artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 9f8046bf00..0571454c01 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -17,31 +17,118 @@
 
 package org.apache.spark.streaming.kafka;
 
+import java.io.Serializable;
 import java.util.HashMap;
+import java.util.List;
+
+import scala.Predef;
+import scala.Tuple2;
+import scala.collection.JavaConverters;
+
+import junit.framework.Assert;
 
-import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream;
-import org.junit.Test;
-import com.google.common.collect.Maps;
 import kafka.serializer.StringDecoder;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.storage.StorageLevel;
+import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.LocalJavaStreamingContext;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+import org.junit.Test;
+import org.junit.After;
+import org.junit.Before;
+
+public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable {
+  private transient KafkaStreamSuite testSuite = new KafkaStreamSuite();
+
+  @Before
+  @Override
+  public void setUp() {
+    testSuite.beforeFunction();
+    System.clearProperty("spark.driver.port");
+    //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock");
+    ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000));
+  }
+
+  @After
+  @Override
+  public void tearDown() {
+    ssc.stop();
+    ssc = null;
+    System.clearProperty("spark.driver.port");
+    testSuite.afterFunction();
+  }
 
-public class JavaKafkaStreamSuite extends LocalJavaStreamingContext {
   @Test
-  public void testKafkaStream() {
-    HashMap<String, Integer> topics = Maps.newHashMap();
-
-    // tests the API, does not actually test data receiving
-    JavaPairReceiverInputDStream<String, String> test1 =
-            KafkaUtils.createStream(ssc, "localhost:12345", "group", topics);
-    JavaPairReceiverInputDStream<String, String> test2 = KafkaUtils.createStream(ssc, "localhost:12345", "group", topics,
-      StorageLevel.MEMORY_AND_DISK_SER_2());
-
-    HashMap<String, String> kafkaParams = Maps.newHashMap();
-    kafkaParams.put("zookeeper.connect", "localhost:12345");
-    kafkaParams.put("group.id","consumer-group");
-      JavaPairReceiverInputDStream<String, String> test3 = KafkaUtils.createStream(ssc,
-      String.class, String.class, StringDecoder.class, StringDecoder.class,
-      kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2());
+  public void testKafkaStream() throws InterruptedException {
+    String topic = "topic1";
+    HashMap<String, Integer> topics = new HashMap<String, Integer>();
+    topics.put(topic, 1);
+
+    HashMap<String, Integer> sent = new HashMap<String, Integer>();
+    sent.put("a", 5);
+    sent.put("b", 3);
+    sent.put("c", 10);
+
+    testSuite.createTopic(topic);
+    HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
+    testSuite.produceAndSendMessage(topic,
+      JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
+        Predef.<Tuple2<String, Object>>conforms()));
+
+    HashMap<String, String> kafkaParams = new HashMap<String, String>();
+    kafkaParams.put("zookeeper.connect", testSuite.zkConnect());
+    kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000));
+    kafkaParams.put("auto.offset.reset", "smallest");
+
+    JavaPairDStream<String, String> stream = KafkaUtils.createStream(ssc,
+      String.class,
+      String.class,
+      StringDecoder.class,
+      StringDecoder.class,
+      kafkaParams,
+      topics,
+      StorageLevel.MEMORY_ONLY_SER());
+
+    final HashMap<String, Long> result = new HashMap<String, Long>();
+
+    JavaDStream<String> words = stream.map(
+      new Function<Tuple2<String, String>, String>() {
+        @Override
+        public String call(Tuple2<String, String> tuple2) throws Exception {
+          return tuple2._2();
+        }
+      }
+    );
+
+    words.countByValue().foreachRDD(
+      new Function<JavaPairRDD<String, Long>, Void>() {
+        @Override
+        public Void call(JavaPairRDD<String, Long> rdd) throws Exception {
+          List<Tuple2<String, Long>> ret = rdd.collect();
+          for (Tuple2<String, Long> r : ret) {
+            if (result.containsKey(r._1())) {
+              result.put(r._1(), result.get(r._1()) + r._2());
+            } else {
+              result.put(r._1(), r._2());
+            }
+          }
+
+          return null;
+        }
+      }
+    );
+
+    ssc.start();
+    ssc.awaitTermination(3000);
+
+    Assert.assertEquals(sent.size(), result.size());
+    for (String k : sent.keySet()) {
+      Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue());
+    }
   }
 }
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index e6f2c4a5cf..c0b55e9340 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -17,28 +17,193 @@
 
 package org.apache.spark.streaming.kafka
 
-import kafka.serializer.StringDecoder
+import java.io.File
+import java.net.InetSocketAddress
+import java.util.{Properties, Random}
+
+import scala.collection.mutable
+
+import kafka.admin.CreateTopicCommand
+import kafka.common.TopicAndPartition
+import kafka.producer.{KeyedMessage, ProducerConfig, Producer}
+import kafka.utils.ZKStringSerializer
+import kafka.serializer.{StringDecoder, StringEncoder}
+import kafka.server.{KafkaConfig, KafkaServer}
+
+import org.I0Itec.zkclient.ZkClient
+
+import org.apache.zookeeper.server.ZooKeeperServer
+import org.apache.zookeeper.server.NIOServerCnxnFactory
+
 import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.util.Utils
 
 class KafkaStreamSuite extends TestSuiteBase {
+  import KafkaTestUtils._
+
+  val zkConnect = "localhost:2181"
+  val zkConnectionTimeout = 6000
+  val zkSessionTimeout = 6000
+
+  val brokerPort = 9092
+  val brokerProps = getBrokerConfig(brokerPort, zkConnect)
+  val brokerConf = new KafkaConfig(brokerProps)
+
+  protected var zookeeper: EmbeddedZookeeper = _
+  protected var zkClient: ZkClient = _
+  protected var server: KafkaServer = _
+  protected var producer: Producer[String, String] = _
+
+  override def useManualClock = false
+
+  override def beforeFunction() {
+    // Zookeeper server startup
+    zookeeper = new EmbeddedZookeeper(zkConnect)
+    logInfo("==================== 0 ====================")
+    zkClient = new ZkClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
+    logInfo("==================== 1 ====================")
 
-  test("kafka input stream") {
+    // Kafka broker startup
+    server = new KafkaServer(brokerConf)
+    logInfo("==================== 2 ====================")
+    server.startup()
+    logInfo("==================== 3 ====================")
+    Thread.sleep(2000)
+    logInfo("==================== 4 ====================")
+    super.beforeFunction()
+  }
+
+  override def afterFunction() {
+    producer.close()
+    server.shutdown()
+    brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }
+
+    zkClient.close()
+    zookeeper.shutdown()
+
+    super.afterFunction()
+  }
+
+  test("Kafka input stream") {
     val ssc = new StreamingContext(master, framework, batchDuration)
-    val topics = Map("my-topic" -> 1)
-
-    // tests the API, does not actually test data receiving
-    val test1: ReceiverInputDStream[(String, String)] =
-      KafkaUtils.createStream(ssc, "localhost:1234", "group", topics)
-    val test2: ReceiverInputDStream[(String, String)] =
-      KafkaUtils.createStream(ssc, "localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK_SER_2)
-    val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group")
-    val test3: ReceiverInputDStream[(String, String)] =
-      KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
-      ssc, kafkaParams, topics, StorageLevel.MEMORY_AND_DISK_SER_2)
-
-    // TODO: Actually test receiving data
+    val topic = "topic1"
+    val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
+    createTopic(topic)
+    produceAndSendMessage(topic, sent)
+
+    val kafkaParams = Map("zookeeper.connect" -> zkConnect,
+      "group.id" -> s"test-consumer-${random.nextInt(10000)}",
+      "auto.offset.reset" -> "smallest")
+
+    val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder](
+      ssc,
+      kafkaParams,
+      Map(topic -> 1),
+      StorageLevel.MEMORY_ONLY)
+    val result = new mutable.HashMap[String, Long]()
+    stream.map { case (k, v) => v }
+      .countByValue()
+      .foreachRDD { r =>
+        val ret = r.collect()
+        ret.toMap.foreach { kv =>
+          val count = result.getOrElseUpdate(kv._1, 0) + kv._2
+          result.put(kv._1, count)
+        }
+      }
+    ssc.start()
+    ssc.awaitTermination(3000)
+
+    assert(sent.size === result.size)
+    sent.keys.foreach { k => assert(sent(k) === result(k).toInt) }
+
     ssc.stop()
   }
+
+  private def createTestMessage(topic: String, sent: Map[String, Int])
+    : Seq[KeyedMessage[String, String]] = {
+    val messages = for ((s, freq) <- sent; i <- 0 until freq) yield {
+      new KeyedMessage[String, String](topic, s)
+    }
+    messages.toSeq
+  }
+
+  def createTopic(topic: String) {
+    CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0")
+    logInfo("==================== 5 ====================")
+    // wait until metadata is propagated
+    waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000)
+  }
+
+  def produceAndSendMessage(topic: String, sent: Map[String, Int]) {
+    val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
+    producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr)))
+    producer.send(createTestMessage(topic, sent): _*)
+    logInfo("==================== 6 ====================")
+  }
+}
+
+object KafkaTestUtils {
+  val random = new Random()
+
+  def getBrokerConfig(port: Int, zkConnect: String): Properties = {
+    val props = new Properties()
+    props.put("broker.id", "0")
+    props.put("host.name", "localhost")
+    props.put("port", port.toString)
+    props.put("log.dir", Utils.createTempDir().getAbsolutePath)
+    props.put("zookeeper.connect", zkConnect)
+    props.put("log.flush.interval.messages", "1")
+    props.put("replica.socket.timeout.ms", "1500")
+    props
+  }
+
+  def getProducerConfig(brokerList: String): Properties = {
+    val props = new Properties()
+    props.put("metadata.broker.list", brokerList)
+    props.put("serializer.class", classOf[StringEncoder].getName)
+    props
+  }
+
+  def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = {
+    val startTime = System.currentTimeMillis()
+    while (true) {
+      if (condition())
+        return true
+      if (System.currentTimeMillis() > startTime + waitTime)
+        return false
+      Thread.sleep(waitTime.min(100L))
+    }
+    // Should never reach this point
+    throw new RuntimeException("unexpected error")
+  }
+
+  def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int,
+      timeout: Long) {
+    assert(waitUntilTrue(() =>
+      servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains(
+        TopicAndPartition(topic, partition))), timeout),
+      s"Partition [$topic, $partition] metadata not propagated after timeout")
+  }
+
+  class EmbeddedZookeeper(val zkConnect: String) {
+    val random = new Random()
+    val snapshotDir = Utils.createTempDir()
+    val logDir = Utils.createTempDir()
+
+    val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500)
+    val (ip, port) = {
+      val splits = zkConnect.split(":")
+      (splits(0), splits(1).toInt)
+    }
+    val factory = new NIOServerCnxnFactory()
+    factory.configure(new InetSocketAddress(ip, port), 16)
+    factory.startup(zookeeper)
+
+    def shutdown() {
+      factory.shutdown()
+      Utils.deleteRecursively(snapshotDir)
+      Utils.deleteRecursively(logDir)
+    }
+  }
 }
-- 
GitLab