diff --git a/dev/run-tests b/dev/run-tests
index 1b6cf78b5da01b41d120486973a3b1814dd25b75..bb21ab6c9aa0401a9d3a9c882f3ec16995223a68 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -173,7 +173,7 @@ CURRENT_BLOCK=$BLOCK_BUILD
     build/mvn $HIVE_BUILD_ARGS clean package -DskipTests
   else
     echo -e "q\n" \
-      | build/sbt $HIVE_BUILD_ARGS package assembly/assembly  \
+      | build/sbt $HIVE_BUILD_ARGS package assembly/assembly streaming-kafka-assembly/assembly \
       | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
   fi
 }
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
new file mode 100644
index 0000000000000000000000000000000000000000..13e947506597987d72c02f09f1bffa116469b6e7
--- /dev/null
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+import java.io.File
+import java.lang.{Integer => JInt}
+import java.net.InetSocketAddress
+import java.util.{Map => JMap}
+import java.util.Properties
+import java.util.concurrent.TimeoutException
+
+import scala.annotation.tailrec
+import scala.language.postfixOps
+import scala.util.control.NonFatal
+
+import kafka.admin.AdminUtils
+import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
+import kafka.serializer.StringEncoder
+import kafka.server.{KafkaConfig, KafkaServer}
+import kafka.utils.ZKStringSerializer
+import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
+import org.I0Itec.zkclient.ZkClient
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.streaming.Time
+import org.apache.spark.util.Utils
+
+/**
+ * This is a helper class for Kafka test suites. It can set up and tear down local
+ * Kafka servers, and push data into Kafka using producers.
+ *
+ * This class lives in src (rather than test) so that the Python Kafka API tests can use it.
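+ *
+ * A typical use looks roughly like the following sketch (the suites in this patch wire
+ * these calls into their before/after hooks):
+ * {{{
+ *   val kafkaTestUtils = new KafkaTestUtils
+ *   kafkaTestUtils.setup()
+ *   kafkaTestUtils.createTopic("topic")
+ *   kafkaTestUtils.sendMessages("topic", Map("a" -> 1))
+ *   // run the test against kafkaTestUtils.brokerAddress / kafkaTestUtils.zkAddress ...
+ *   kafkaTestUtils.teardown()
+ * }}}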
+ */
+private class KafkaTestUtils extends Logging {
+
+  // Zookeeper related configurations
+  private val zkHost = "localhost"
+  private var zkPort: Int = 0
+  private val zkConnectionTimeout = 6000
+  private val zkSessionTimeout = 6000
+
+  private var zookeeper: EmbeddedZookeeper = _
+
+  private var zkClient: ZkClient = _
+
+  // Kafka broker related configurations
+  private val brokerHost = "localhost"
+  private var brokerPort = 9092
+  private var brokerConf: KafkaConfig = _
+
+  // Kafka broker server
+  private var server: KafkaServer = _
+
+  // Kafka producer
+  private var producer: Producer[String, String] = _
+
+  // Flags to check whether Zookeeper and the Kafka broker have started correctly
+  private var zkReady = false
+  private var brokerReady = false
+
+  def zkAddress: String = {
+    assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address")
+    s"$zkHost:$zkPort"
+  }
+
+  def brokerAddress: String = {
+    assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address")
+    s"$brokerHost:$brokerPort"
+  }
+
+  def zookeeperClient: ZkClient = {
+    assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client")
+    Option(zkClient).getOrElse(
+      throw new IllegalStateException("Zookeeper client is not yet initialized"))
+  }
+
+  // Set up the embedded Zookeeper server and get the actual binding port
+  private def setupEmbeddedZookeeper(): Unit = {
+    // Zookeeper server startup
+    zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
+    // Get the actual zookeeper binding port
+    zkPort = zookeeper.actualPort
+    zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout,
+      ZKStringSerializer)
+    zkReady = true
+  }
+
+  // Set up the embedded Kafka server
+  private def setupEmbeddedKafkaServer(): Unit = {
+    assert(zkReady, "Zookeeper should be set up beforehand")
+
+    // Kafka broker startup
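+    // Utils.startServiceOnPort retries on successive ports if the initial one cannot be
+    // bound; the closure receives the candidate port and returns the started service.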
+    Utils.startServiceOnPort(brokerPort, port => {
+      brokerPort = port
+      brokerConf = new KafkaConfig(brokerConfiguration)
+      server = new KafkaServer(brokerConf)
+      server.startup()
+      (server, port)
+    }, new SparkConf(), "KafkaBroker")
+
+    brokerReady = true
+  }
+
+  /** Set up the embedded Zookeeper server and the Kafka broker */
+  def setup(): Unit = {
+    setupEmbeddedZookeeper()
+    setupEmbeddedKafkaServer()
+  }
+
+  /** Tear down all servers, including the Kafka broker and Zookeeper */
+  def teardown(): Unit = {
+    brokerReady = false
+    zkReady = false
+
+    if (producer != null) {
+      producer.close()
+      producer = null
+    }
+
+    if (server != null) {
+      server.shutdown()
+      server = null
+    }
+
+    brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }
+
+    if (zkClient != null) {
+      zkClient.close()
+      zkClient = null
+    }
+
+    if (zookeeper != null) {
+      zookeeper.shutdown()
+      zookeeper = null
+    }
+  }
+
+  /** Create a Kafka topic and wait until it is propagated to the whole cluster */
+  def createTopic(topic: String): Unit = {
+    AdminUtils.createTopic(zkClient, topic, 1, 1)
+    // wait until metadata is propagated
+    waitUntilMetadataIsPropagated(topic, 0)
+  }
+
+  /** Java-friendly function for sending messages to the Kafka broker */
+  def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
+    import scala.collection.JavaConversions._
+    sendMessages(topic, Map(messageToFreq.mapValues(_.intValue()).toSeq: _*))
+  }
+
+  /** Send the messages to the Kafka broker */
+  def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = {
+    val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray
+    sendMessages(topic, messages)
+  }
+
+  /** Send the array of messages to the Kafka broker */
+  def sendMessages(topic: String, messages: Array[String]): Unit = {
+    producer = new Producer[String, String](new ProducerConfig(producerConfiguration))
+    producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*)
+    producer.close()
+    producer = null
+  }
+
+  private def brokerConfiguration: Properties = {
+    val props = new Properties()
+    props.put("broker.id", "0")
+    props.put("host.name", "localhost")
+    props.put("port", brokerPort.toString)
+    props.put("log.dir", Utils.createTempDir().getAbsolutePath)
+    props.put("zookeeper.connect", zkAddress)
+    props.put("log.flush.interval.messages", "1")
+    props.put("replica.socket.timeout.ms", "1500")
+    props
+  }
+
+  private def producerConfiguration: Properties = {
+    val props = new Properties()
+    props.put("metadata.broker.list", brokerAddress)
+    props.put("serializer.class", classOf[StringEncoder].getName)
+    props
+  }
+
+  // A simplified version of ScalaTest's eventually, rewritten here to avoid adding an extra
+  // test dependency
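+  // `timeout` bounds the total wait and `interval` is the sleep between retries of `func`.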
+  def eventually[T](timeout: Time, interval: Time)(func: => T): T = {
+    def makeAttempt(): Either[Throwable, T] = {
+      try {
+        Right(func)
+      } catch {
+        case e if NonFatal(e) => Left(e)
+      }
+    }
+
+    val startTime = System.currentTimeMillis()
+    @tailrec
+    def tryAgain(attempt: Int): T = {
+      makeAttempt() match {
+        case Right(result) => result
+        case Left(e) =>
+          val duration = System.currentTimeMillis() - startTime
+          if (duration < timeout.milliseconds) {
+            Thread.sleep(interval.milliseconds)
+          } else {
+            throw new TimeoutException(e.getMessage)
+          }
+
+          tryAgain(attempt + 1)
+      }
+    }
+
+    tryAgain(1)
+  }
+
+  private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
+    eventually(Time(10000), Time(100)) {
+      assert(
+        server.apis.metadataCache.containsTopicAndPartition(topic, partition),
+        s"Partition [$topic, $partition] metadata not propagated after timeout"
+      )
+    }
+  }
+
+  private class EmbeddedZookeeper(val zkConnect: String) {
+    val snapshotDir = Utils.createTempDir()
+    val logDir = Utils.createTempDir()
+
+    val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500)
+    val (ip, port) = {
+      val splits = zkConnect.split(":")
+      (splits(0), splits(1).toInt)
+    }
+    val factory = new NIOServerCnxnFactory()
+    factory.configure(new InetSocketAddress(ip, port), 16)
+    factory.startup(zookeeper)
+
+    val actualPort = factory.getLocalPort
+
+    def shutdown() {
+      factory.shutdown()
+      Utils.deleteRecursively(snapshotDir)
+      Utils.deleteRecursively(logDir)
+    }
+  }
+}
+
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index d6ca6d58b56657fad27dc2d24842c3e17573f55e..4c1d6a03eb2b8419cc82d8db1a0a2b505274e8dd 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -41,24 +41,28 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext;
 
 public class JavaDirectKafkaStreamSuite implements Serializable {
   private transient JavaStreamingContext ssc = null;
-  private transient KafkaStreamSuiteBase suiteBase = null;
+  private transient KafkaTestUtils kafkaTestUtils = null;
 
   @Before
   public void setUp() {
-      suiteBase = new KafkaStreamSuiteBase() { };
-      suiteBase.setupKafka();
-      System.clearProperty("spark.driver.port");
-      SparkConf sparkConf = new SparkConf()
-              .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
-      ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200));
+    kafkaTestUtils = new KafkaTestUtils();
+    kafkaTestUtils.setup();
+    SparkConf sparkConf = new SparkConf()
+      .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+    ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200));
   }
 
   @After
   public void tearDown() {
+    if (ssc != null) {
       ssc.stop();
       ssc = null;
-      System.clearProperty("spark.driver.port");
-      suiteBase.tearDownKafka();
+    }
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown();
+      kafkaTestUtils = null;
+    }
   }
 
   @Test
@@ -74,7 +78,7 @@ public class JavaDirectKafkaStreamSuite implements Serializable {
     sent.addAll(Arrays.asList(topic2data));
 
     HashMap<String, String> kafkaParams = new HashMap<String, String>();
-    kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress());
+    kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());
     kafkaParams.put("auto.offset.reset", "smallest");
 
     JavaDStream<String> stream1 = KafkaUtils.createDirectStream(
@@ -147,8 +151,8 @@ public class JavaDirectKafkaStreamSuite implements Serializable {
 
   private  String[] createTopicAndSendData(String topic) {
     String[] data = { topic + "-1", topic + "-2", topic + "-3"};
-    suiteBase.createTopic(topic);
-    suiteBase.sendMessages(topic, data);
+    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.sendMessages(topic, data);
     return data;
   }
 }
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index 4477b81827c7024b4c2edcd66ba8ccd1b948b142..a9dc6e50613ca58ac6de67bae7f51d556ed1308b 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -37,13 +37,12 @@ import org.apache.spark.api.java.function.Function;
 
 public class JavaKafkaRDDSuite implements Serializable {
   private transient JavaSparkContext sc = null;
-  private transient KafkaStreamSuiteBase suiteBase = null;
+  private transient KafkaTestUtils kafkaTestUtils = null;
 
   @Before
   public void setUp() {
-    suiteBase = new KafkaStreamSuiteBase() { };
-    suiteBase.setupKafka();
-    System.clearProperty("spark.driver.port");
+    kafkaTestUtils = new KafkaTestUtils();
+    kafkaTestUtils.setup();
     SparkConf sparkConf = new SparkConf()
       .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
     sc = new JavaSparkContext(sparkConf);
@@ -51,10 +50,15 @@ public class JavaKafkaRDDSuite implements Serializable {
 
   @After
   public void tearDown() {
-    sc.stop();
-    sc = null;
-    System.clearProperty("spark.driver.port");
-    suiteBase.tearDownKafka();
+    if (sc != null) {
+      sc.stop();
+      sc = null;
+    }
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown();
+      kafkaTestUtils = null;
+    }
   }
 
   @Test
@@ -66,7 +70,7 @@ public class JavaKafkaRDDSuite implements Serializable {
     String[] topic2data = createTopicAndSendData(topic2);
 
     HashMap<String, String> kafkaParams = new HashMap<String, String>();
-    kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress());
+    kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());
 
     OffsetRange[] offsetRanges = {
       OffsetRange.create(topic1, 0, 0, 1),
@@ -75,7 +79,7 @@ public class JavaKafkaRDDSuite implements Serializable {
 
     HashMap<TopicAndPartition, Broker> emptyLeaders = new HashMap<TopicAndPartition, Broker>();
     HashMap<TopicAndPartition, Broker> leaders = new HashMap<TopicAndPartition, Broker>();
-    String[] hostAndPort = suiteBase.brokerAddress().split(":");
+    String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":");
     Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
     leaders.put(new TopicAndPartition(topic1, 0), broker);
     leaders.put(new TopicAndPartition(topic2, 0), broker);
@@ -144,8 +148,8 @@ public class JavaKafkaRDDSuite implements Serializable {
 
   private  String[] createTopicAndSendData(String topic) {
     String[] data = { topic + "-1", topic + "-2", topic + "-3"};
-    suiteBase.createTopic(topic);
-    suiteBase.sendMessages(topic, data);
+    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.sendMessages(topic, data);
     return data;
   }
 }
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index bad0a93eb2e84021e1fb9b6ab5646d260d1db322..540f4ceabab47259ac8dc4727e74f3c0dd5070be 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -22,9 +22,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Random;
 
-import scala.Predef;
 import scala.Tuple2;
-import scala.collection.JavaConverters;
 
 import kafka.serializer.StringDecoder;
 import org.junit.After;
@@ -44,13 +42,12 @@ import org.apache.spark.streaming.api.java.JavaStreamingContext;
 public class JavaKafkaStreamSuite implements Serializable {
   private transient JavaStreamingContext ssc = null;
   private transient Random random = new Random();
-  private transient KafkaStreamSuiteBase suiteBase = null;
+  private transient KafkaTestUtils kafkaTestUtils = null;
 
   @Before
   public void setUp() {
-    suiteBase = new KafkaStreamSuiteBase() { };
-    suiteBase.setupKafka();
-    System.clearProperty("spark.driver.port");
+    kafkaTestUtils = new KafkaTestUtils();
+    kafkaTestUtils.setup();
     SparkConf sparkConf = new SparkConf()
       .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
     ssc = new JavaStreamingContext(sparkConf, new Duration(500));
@@ -58,10 +55,15 @@ public class JavaKafkaStreamSuite implements Serializable {
 
   @After
   public void tearDown() {
-    ssc.stop();
-    ssc = null;
-    System.clearProperty("spark.driver.port");
-    suiteBase.tearDownKafka();
+    if (ssc != null) {
+      ssc.stop();
+      ssc = null;
+    }
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown();
+      kafkaTestUtils = null;
+    }
   }
 
   @Test
@@ -75,15 +77,11 @@ public class JavaKafkaStreamSuite implements Serializable {
     sent.put("b", 3);
     sent.put("c", 10);
 
-    suiteBase.createTopic(topic);
-    HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
-    suiteBase.sendMessages(topic,
-        JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
-            Predef.<Tuple2<String, Object>>conforms())
-    );
+    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.sendMessages(topic, sent);
 
     HashMap<String, String> kafkaParams = new HashMap<String, String>();
-    kafkaParams.put("zookeeper.connect", suiteBase.zkAddress());
+    kafkaParams.put("zookeeper.connect", kafkaTestUtils.zkAddress());
     kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000));
     kafkaParams.put("auto.offset.reset", "smallest");
 
@@ -126,6 +124,7 @@ public class JavaKafkaStreamSuite implements Serializable {
     );
 
     ssc.start();
+
     long startTime = System.currentTimeMillis();
     boolean sizeMatches = false;
     while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) {
@@ -136,6 +135,5 @@ public class JavaKafkaStreamSuite implements Serializable {
     for (String k : sent.keySet()) {
       Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue());
     }
-    ssc.stop();
   }
 }
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 17ca9d145d665434d8451a68a8df27d1623b39fa..415730f5559c59eb819954cc1246d3d842762529 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -27,31 +27,41 @@ import scala.language.postfixOps
 import kafka.common.TopicAndPartition
 import kafka.message.MessageAndMetadata
 import kafka.serializer.StringDecoder
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SparkContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
 import org.apache.spark.streaming.dstream.DStream
 import org.apache.spark.util.Utils
 
-class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
-  with BeforeAndAfter with BeforeAndAfterAll with Eventually {
+class DirectKafkaStreamSuite
+  extends FunSuite
+  with BeforeAndAfter
+  with BeforeAndAfterAll
+  with Eventually
+  with Logging {
   val sparkConf = new SparkConf()
     .setMaster("local[4]")
     .setAppName(this.getClass.getSimpleName)
 
-  var sc: SparkContext = _
-  var ssc: StreamingContext = _
-  var testDir: File = _
+  private var sc: SparkContext = _
+  private var ssc: StreamingContext = _
+  private var testDir: File = _
+
+  private var kafkaTestUtils: KafkaTestUtils = _
 
   override def beforeAll {
-    setupKafka()
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
   }
 
   override def afterAll {
-    tearDownKafka()
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
   }
 
   after {
@@ -72,12 +82,12 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     val topics = Set("basic1", "basic2", "basic3")
     val data = Map("a" -> 7, "b" -> 9)
     topics.foreach { t =>
-      createTopic(t)
-      sendMessages(t, data)
+      kafkaTestUtils.createTopic(t)
+      kafkaTestUtils.sendMessages(t, data)
     }
     val totalSent = data.values.sum * topics.size
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"$brokerAddress",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "smallest"
     )
 
@@ -121,9 +131,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     val topic = "largest"
     val topicPartition = TopicAndPartition(topic, 0)
     val data = Map("a" -> 10)
-    createTopic(topic)
+    kafkaTestUtils.createTopic(topic)
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"$brokerAddress",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "largest"
     )
     val kc = new KafkaCluster(kafkaParams)
@@ -132,7 +142,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     }
 
     // Send some initial messages before starting context
-    sendMessages(topic, data)
+    kafkaTestUtils.sendMessages(topic, data)
     eventually(timeout(10 seconds), interval(20 milliseconds)) {
       assert(getLatestOffset() > 3)
     }
@@ -154,7 +164,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() }
     ssc.start()
     val newData = Map("b" -> 10)
-    sendMessages(topic, newData)
+    kafkaTestUtils.sendMessages(topic, newData)
     eventually(timeout(10 seconds), interval(50 milliseconds)) {
       collectedData.contains("b")
     }
@@ -166,9 +176,9 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     val topic = "offset"
     val topicPartition = TopicAndPartition(topic, 0)
     val data = Map("a" -> 10)
-    createTopic(topic)
+    kafkaTestUtils.createTopic(topic)
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"$brokerAddress",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "largest"
     )
     val kc = new KafkaCluster(kafkaParams)
@@ -177,7 +187,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     }
 
     // Send some initial messages before starting context
-    sendMessages(topic, data)
+    kafkaTestUtils.sendMessages(topic, data)
     eventually(timeout(10 seconds), interval(20 milliseconds)) {
       assert(getLatestOffset() >= 10)
     }
@@ -200,7 +210,7 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
     stream.foreachRDD { rdd => collectedData ++= rdd.collect() }
     ssc.start()
     val newData = Map("b" -> 10)
-    sendMessages(topic, newData)
+    kafkaTestUtils.sendMessages(topic, newData)
     eventually(timeout(10 seconds), interval(50 milliseconds)) {
       collectedData.contains("b")
     }
@@ -210,18 +220,18 @@ class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
   // Test to verify the offset ranges can be recovered from the checkpoints
   test("offset recovery") {
     val topic = "recovery"
-    createTopic(topic)
+    kafkaTestUtils.createTopic(topic)
     testDir = Utils.createTempDir()
 
     val kafkaParams = Map(
-      "metadata.broker.list" -> s"$brokerAddress",
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "smallest"
     )
 
     // Send data to Kafka and wait for it to be received
     def sendDataAndWaitForReceive(data: Seq[Int]) {
       val strings = data.map { _.toString}
-      sendMessages(topic, strings.map { _ -> 1}.toMap)
+      kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
       eventually(timeout(10 seconds), interval(50 milliseconds)) {
         assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains })
       }
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
index fc9275b7207befea723a6233e9807f52a7140aa1..2b33d2a220b2b10e1cacb1d3484a645a0c830708 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
@@ -20,28 +20,35 @@ package org.apache.spark.streaming.kafka
 import scala.util.Random
 
 import kafka.common.TopicAndPartition
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
-class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
-  val topic = "kcsuitetopic" + Random.nextInt(10000)
-  val topicAndPartition = TopicAndPartition(topic, 0)
-  var kc: KafkaCluster = null
+class KafkaClusterSuite extends FunSuite with BeforeAndAfterAll {
+  private val topic = "kcsuitetopic" + Random.nextInt(10000)
+  private val topicAndPartition = TopicAndPartition(topic, 0)
+  private var kc: KafkaCluster = null
+
+  private var kafkaTestUtils: KafkaTestUtils = _
 
   override def beforeAll() {
-    setupKafka()
-    createTopic(topic)
-    sendMessages(topic, Map("a" -> 1))
-    kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress"))
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
+
+    kafkaTestUtils.createTopic(topic)
+    kafkaTestUtils.sendMessages(topic, Map("a" -> 1))
+    kc = new KafkaCluster(Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress))
   }
 
   override def afterAll() {
-    tearDownKafka()
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
   }
 
   test("metadata apis") {
     val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition)
     val leaderAddress = s"${leader._1}:${leader._2}"
-    assert(leaderAddress === brokerAddress, "didn't get leader")
+    assert(leaderAddress === kafkaTestUtils.brokerAddress, "didn't get leader")
 
     val parts = kc.getPartitions(Set(topic)).right.get
     assert(parts(topicAndPartition), "didn't get partitions")
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index a223da70b043fc9015971b912bae94dfe04e0389..7d26ce50875b363810c3e69e5d14c0bcf29db430 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -22,18 +22,22 @@ import scala.util.Random
 import kafka.serializer.StringDecoder
 import kafka.common.TopicAndPartition
 import kafka.message.MessageAndMetadata
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark._
-import org.apache.spark.SparkContext._
 
-class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
-  val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
-  var sc: SparkContext = _
+class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
+
+  private var kafkaTestUtils: KafkaTestUtils = _
+
+  private val sparkConf = new SparkConf().setMaster("local[4]")
+    .setAppName(this.getClass.getSimpleName)
+  private var sc: SparkContext = _
+
   override def beforeAll {
     sc = new SparkContext(sparkConf)
-
-    setupKafka()
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
   }
 
   override def afterAll {
@@ -41,17 +45,21 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
       sc.stop
       sc = null
     }
-    tearDownKafka()
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
   }
 
   test("basic usage") {
     val topic = "topicbasic"
-    createTopic(topic)
+    kafkaTestUtils.createTopic(topic)
     val messages = Set("the", "quick", "brown", "fox")
-    sendMessages(topic, messages.toArray)
+    kafkaTestUtils.sendMessages(topic, messages.toArray)
 
 
-    val kafkaParams = Map("metadata.broker.list" -> brokerAddress,
+    val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "group.id" -> s"test-consumer-${Random.nextInt(10000)}")
 
     val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
@@ -67,15 +75,15 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
     // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
     val topic = "topic1"
     val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
-    createTopic(topic)
+    kafkaTestUtils.createTopic(topic)
 
-    val kafkaParams = Map("metadata.broker.list" -> brokerAddress,
+    val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "group.id" -> s"test-consumer-${Random.nextInt(10000)}")
 
     val kc = new KafkaCluster(kafkaParams)
 
     // this is the "lots of messages" case
-    sendMessages(topic, sent)
+    kafkaTestUtils.sendMessages(topic, sent)
     // rdd defined from leaders after sending messages, should get the number sent
     val rdd = getRdd(kc, Set(topic))
 
@@ -92,14 +100,14 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
     // shouldn't get anything, since message is sent after rdd was defined
     val sentOnlyOne = Map("d" -> 1)
 
-    sendMessages(topic, sentOnlyOne)
+    kafkaTestUtils.sendMessages(topic, sentOnlyOne)
     assert(rdd2.isDefined)
     assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
 
     // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above
     val rdd3 = getRdd(kc, Set(topic))
     // send lots of messages after rdd was defined, they shouldn't show up
-    sendMessages(topic, Map("extra" -> 22))
+    kafkaTestUtils.sendMessages(topic, Map("extra" -> 22))
 
     assert(rdd3.isDefined)
     assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message")
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index e4966eebb9b34cdbb405b33d497f61a801a4f045..24699dfc33adb819e29b153b3a099a5dc554b72b 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -17,209 +17,38 @@
 
 package org.apache.spark.streaming.kafka
 
-import java.io.File
-import java.net.InetSocketAddress
-import java.util.Properties
-
 import scala.collection.mutable
 import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.util.Random
 
-import kafka.admin.AdminUtils
-import kafka.common.{KafkaException, TopicAndPartition}
-import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
-import kafka.serializer.{StringDecoder, StringEncoder}
-import kafka.server.{KafkaConfig, KafkaServer}
-import kafka.utils.ZKStringSerializer
-import org.I0Itec.zkclient.ZkClient
-import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import kafka.serializer.StringDecoder
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.SparkConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.util.Utils
-
-/**
- * This is an abstract base class for Kafka testsuites. This has the functionality to set up
- * and tear down local Kafka servers, and to push data using Kafka producers.
- */
-abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging {
-
-  private val zkHost = "localhost"
-  private var zkPort: Int = 0
-  private val zkConnectionTimeout = 6000
-  private val zkSessionTimeout = 6000
-  private var zookeeper: EmbeddedZookeeper = _
-  private val brokerHost = "localhost"
-  private var brokerPort = 9092
-  private var brokerConf: KafkaConfig = _
-  private var server: KafkaServer = _
-  private var producer: Producer[String, String] = _
-  private var zkReady = false
-  private var brokerReady = false
-
-  protected var zkClient: ZkClient = _
-
-  def zkAddress: String = {
-    assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address")
-    s"$zkHost:$zkPort"
-  }
 
-  def brokerAddress: String = {
-    assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address")
-    s"$brokerHost:$brokerPort"
-  }
-
-  def setupKafka() {
-    // Zookeeper server startup
-    zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
-    // Get the actual zookeeper binding port
-    zkPort = zookeeper.actualPort
-    zkReady = true
-    logInfo("==================== Zookeeper Started ====================")
+class KafkaStreamSuite extends FunSuite with Eventually with BeforeAndAfterAll {
+  private var ssc: StreamingContext = _
+  private var kafkaTestUtils: KafkaTestUtils = _
 
-    zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
-    logInfo("==================== Zookeeper Client Created ====================")
-
-    // Kafka broker startup
-    var bindSuccess: Boolean = false
-    while(!bindSuccess) {
-      try {
-        val brokerProps = getBrokerConfig()
-        brokerConf = new KafkaConfig(brokerProps)
-        server = new KafkaServer(brokerConf)
-        server.startup()
-        logInfo("==================== Kafka Broker Started ====================")
-        bindSuccess = true
-      } catch {
-        case e: KafkaException =>
-          if (e.getMessage != null && e.getMessage.contains("Socket server failed to bind to")) {
-            brokerPort += 1
-          }
-        case e: Exception => throw new Exception("Kafka server create failed", e)
-      }
-    }
-
-    Thread.sleep(2000)
-    logInfo("==================== Kafka + Zookeeper Ready ====================")
-    brokerReady = true
+  override def beforeAll(): Unit = {
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
   }
 
-  def tearDownKafka() {
-    brokerReady = false
-    zkReady = false
-    if (producer != null) {
-      producer.close()
-      producer = null
-    }
-
-    if (server != null) {
-      server.shutdown()
-      server = null
-    }
-
-    brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }
-
-    if (zkClient != null) {
-      zkClient.close()
-      zkClient = null
-    }
-
-    if (zookeeper != null) {
-      zookeeper.shutdown()
-      zookeeper = null
-    }
-  }
-
-  def createTopic(topic: String) {
-    AdminUtils.createTopic(zkClient, topic, 1, 1)
-    // wait until metadata is propagated
-    waitUntilMetadataIsPropagated(topic, 0)
-    logInfo(s"==================== Topic $topic Created ====================")
-  }
-
-  def sendMessages(topic: String, messageToFreq: Map[String, Int]) {
-    val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray
-    sendMessages(topic, messages)
-  }
-  
-  def sendMessages(topic: String, messages: Array[String]) {
-    producer = new Producer[String, String](new ProducerConfig(getProducerConfig()))
-    producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*)
-    producer.close()
-    logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================")
-  }
-
-  private def getBrokerConfig(): Properties = {
-    val props = new Properties()
-    props.put("broker.id", "0")
-    props.put("host.name", "localhost")
-    props.put("port", brokerPort.toString)
-    props.put("log.dir", Utils.createTempDir().getAbsolutePath)
-    props.put("zookeeper.connect", zkAddress)
-    props.put("log.flush.interval.messages", "1")
-    props.put("replica.socket.timeout.ms", "1500")
-    props
-  }
-
-  private def getProducerConfig(): Properties = {
-    val brokerAddr = brokerConf.hostName + ":" + brokerConf.port
-    val props = new Properties()
-    props.put("metadata.broker.list", brokerAddr)
-    props.put("serializer.class", classOf[StringEncoder].getName)
-    props
-  }
-
-  private def waitUntilMetadataIsPropagated(topic: String, partition: Int) {
-    eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
-      assert(
-        server.apis.metadataCache.containsTopicAndPartition(topic, partition),
-        s"Partition [$topic, $partition] metadata not propagated after timeout"
-      )
-    }
-  }
-
-  class EmbeddedZookeeper(val zkConnect: String) {
-    val random = new Random()
-    val snapshotDir = Utils.createTempDir()
-    val logDir = Utils.createTempDir()
-
-    val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500)
-    val (ip, port) = {
-      val splits = zkConnect.split(":")
-      (splits(0), splits(1).toInt)
-    }
-    val factory = new NIOServerCnxnFactory()
-    factory.configure(new InetSocketAddress(ip, port), 16)
-    factory.startup(zookeeper)
-
-    val actualPort = factory.getLocalPort
-
-    def shutdown() {
-      factory.shutdown()
-      Utils.deleteRecursively(snapshotDir)
-      Utils.deleteRecursively(logDir)
-    }
-  }
-}
-
-
-class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
-  var ssc: StreamingContext = _
-
-  before {
-    setupKafka()
-  }
-
-  after {
+  override def afterAll(): Unit = {
     if (ssc != null) {
       ssc.stop()
       ssc = null
     }
-    tearDownKafka()
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
   }
 
   test("Kafka input stream") {
@@ -227,10 +56,10 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
     ssc = new StreamingContext(sparkConf, Milliseconds(500))
     val topic = "topic1"
     val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
-    createTopic(topic)
-    sendMessages(topic, sent)
+    kafkaTestUtils.createTopic(topic)
+    kafkaTestUtils.sendMessages(topic, sent)
 
-    val kafkaParams = Map("zookeeper.connect" -> zkAddress,
+    val kafkaParams = Map("zookeeper.connect" -> kafkaTestUtils.zkAddress,
       "group.id" -> s"test-consumer-${Random.nextInt(10000)}",
       "auto.offset.reset" -> "smallest")
 
@@ -244,14 +73,14 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
         result.put(kv._1, count)
       }
     }
+
     ssc.start()
+
     eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
       assert(sent.size === result.size)
       sent.keys.foreach { k =>
         assert(sent(k) === result(k).toInt)
       }
     }
-    ssc.stop()
   }
 }
-
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index 3cd960d1fd1d41bc867175dc864d4ff0c2e0d9a1..38548dd73b82c9a94f4fbf19c221ec26ac81b3c4 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.streaming.kafka
 
-
 import java.io.File
 
 import scala.collection.mutable
@@ -27,7 +26,7 @@ import scala.util.Random
 
 import kafka.serializer.StringDecoder
 import kafka.utils.{ZKGroupTopicDirs, ZkUtils}
-import org.scalatest.BeforeAndAfter
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.SparkConf
@@ -35,47 +34,61 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.{Milliseconds, StreamingContext}
 import org.apache.spark.util.Utils
 
-class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
+class ReliableKafkaStreamSuite extends FunSuite
+    with BeforeAndAfterAll with BeforeAndAfter with Eventually {
 
-  val sparkConf = new SparkConf()
+  private val sparkConf = new SparkConf()
     .setMaster("local[4]")
     .setAppName(this.getClass.getSimpleName)
     .set("spark.streaming.receiver.writeAheadLog.enable", "true")
-  val data = Map("a" -> 10, "b" -> 10, "c" -> 10)
+  private val data = Map("a" -> 10, "b" -> 10, "c" -> 10)
 
+  private var kafkaTestUtils: KafkaTestUtils = _
 
-  var groupId: String = _
-  var kafkaParams: Map[String, String] = _
-  var ssc: StreamingContext = _
-  var tempDirectory: File = null
+  private var groupId: String = _
+  private var kafkaParams: Map[String, String] = _
+  private var ssc: StreamingContext = _
+  private var tempDirectory: File = null
+
+  override def beforeAll(): Unit = {
+    kafkaTestUtils = new KafkaTestUtils
+    kafkaTestUtils.setup()
 
-  before {
-    setupKafka()
     groupId = s"test-consumer-${Random.nextInt(10000)}"
     kafkaParams = Map(
-      "zookeeper.connect" -> zkAddress,
+      "zookeeper.connect" -> kafkaTestUtils.zkAddress,
       "group.id" -> groupId,
       "auto.offset.reset" -> "smallest"
     )
 
-    ssc = new StreamingContext(sparkConf, Milliseconds(500))
     tempDirectory = Utils.createTempDir()
+  }
+
+  override def afterAll(): Unit = {
+    Utils.deleteRecursively(tempDirectory)
+
+    if (kafkaTestUtils != null) {
+      kafkaTestUtils.teardown()
+      kafkaTestUtils = null
+    }
+  }
+
+  before {
+    ssc = new StreamingContext(sparkConf, Milliseconds(500))
     ssc.checkpoint(tempDirectory.getAbsolutePath)
   }
 
   after {
     if (ssc != null) {
       ssc.stop()
+      ssc = null
     }
-    Utils.deleteRecursively(tempDirectory)
-    tearDownKafka()
   }
 
-
   test("Reliable Kafka input stream with single topic") {
-    var topic = "test-topic"
-    createTopic(topic)
-    sendMessages(topic, data)
+    val topic = "test-topic"
+    kafkaTestUtils.createTopic(topic)
+    kafkaTestUtils.sendMessages(topic, data)
 
     // Verify whether the offset of this group/topic/partition is 0 before starting.
     assert(getCommitOffset(groupId, topic, 0) === None)
@@ -91,6 +104,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
         }
       }
     ssc.start()
+
     eventually(timeout(20000 milliseconds), interval(200 milliseconds)) {
       // A basic process verification for ReliableKafkaReceiver.
       // Verify whether received message number is equal to the sent message number.
@@ -100,14 +114,13 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
       // Verify the offset number whether it is equal to the total message number.
       assert(getCommitOffset(groupId, topic, 0) === Some(29L))
     }
-    ssc.stop()
   }
 
   test("Reliable Kafka input stream with multiple topics") {
     val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1)
     topics.foreach { case (t, _) =>
-      createTopic(t)
-      sendMessages(t, data)
+      kafkaTestUtils.createTopic(t)
+      kafkaTestUtils.sendMessages(t, data)
     }
 
     // Before started, verify all the group/topic/partition offsets are 0.
@@ -118,19 +131,18 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
       ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY)
     stream.foreachRDD(_ => Unit)
     ssc.start()
+
     eventually(timeout(20000 milliseconds), interval(100 milliseconds)) {
       // Verify the offset for each group/topic to see whether they are equal to the expected one.
       topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) }
     }
-    ssc.stop()
   }
 
 
   /** Getting partition offset from Zookeeper. */
   private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = {
-    assert(zkClient != null, "Zookeeper client is not initialized")
     val topicDirs = new ZKGroupTopicDirs(groupId, topic)
     val zkPath = s"${topicDirs.consumerOffsetDir}/$partition"
-    ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong)
+    ZkUtils.readDataMaybeNull(kafkaTestUtils.zookeeperClient, zkPath)._1.map(_.toLong)
   }
 }
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 608f8e26473a6a2677267b0454d36efda16a6e89..9b4635e49020b1b504192366a5d5efb4ad9cdd26 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -23,13 +23,16 @@ import unittest
 import tempfile
 import struct
 
+from py4j.java_collections import MapConverter
+
 from pyspark.context import SparkConf, SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
+from pyspark.streaming.kafka import KafkaUtils
 
 
 class PySparkStreamingTestCase(unittest.TestCase):
 
-    timeout = 10  # seconds
+    timeout = 20  # seconds
     duration = 1
 
     def setUp(self):
@@ -556,5 +559,43 @@ class CheckpointTests(PySparkStreamingTestCase):
         check_output(3)
 
 
+class KafkaStreamTests(PySparkStreamingTestCase):
+
+    def setUp(self):
+        super(KafkaStreamTests, self).setUp()
+
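+        # KafkaTestUtils lives in the streaming-kafka assembly jar that run-tests passes
+        # via --jars, so it is looked up through the current thread's context class loader
+        # rather than referenced directly on the JVM gateway.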
+        kafkaTestUtilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
+            .loadClass("org.apache.spark.streaming.kafka.KafkaTestUtils")
+        self._kafkaTestUtils = kafkaTestUtilsClz.newInstance()
+        self._kafkaTestUtils.setup()
+
+    def tearDown(self):
+        if self._kafkaTestUtils is not None:
+            self._kafkaTestUtils.teardown()
+            self._kafkaTestUtils = None
+
+        super(KafkaStreamTests, self).tearDown()
+
+    def test_kafka_stream(self):
+        """Test the Python Kafka stream API."""
+        topic = "topic1"
+        sendData = {"a": 3, "b": 5, "c": 10}
+        jSendData = MapConverter().convert(sendData,
+                                           self.ssc.sparkContext._gateway._gateway_client)
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, jSendData)
+
+        stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
+                                         "test-streaming-consumer", {topic: 1},
+                                         {"auto.offset.reset": "smallest"})
+
+        result = {}
+        for i in chain.from_iterable(self._collect(stream.map(lambda x: x[1]),
+                                                   sum(sendData.values()))):
+            result[i] = result.get(i, 0) + 1
+
+        self.assertEqual(sendData, result)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/run-tests b/python/run-tests
index f569a56fb7a9ab32610d90d04ee8aaab39aad7a8..f3a07d8aba56257c181c6c6726777bf32be3f173 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -21,6 +21,8 @@
 # Figure out where the Spark framework is installed
 FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)"
 
+. "$FWDIR"/bin/load-spark-env.sh
+
 # CD into the python directory to find things on the right path
 cd "$FWDIR/python"
 
@@ -57,7 +59,7 @@ function run_core_tests() {
     PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
     run_test "pyspark/serializers.py"
-    run_test "pyspark/profiler.py" 
+    run_test "pyspark/profiler.py"
     run_test "pyspark/shuffle.py"
     run_test "pyspark/tests.py"
 }
@@ -97,6 +99,21 @@ function run_ml_tests() {
 
 function run_streaming_tests() {
     echo "Run streaming tests ..."
+
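+    # Locate the spark-streaming-kafka assembly jar; it is handed to PySpark via the
+    # --jars flag in PYSPARK_SUBMIT_ARGS below so that the JVM-side KafkaTestUtils and
+    # Kafka classes are available to the Python Kafka stream tests.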
+    KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly
+    JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}"
+    for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do
+      if [[ ! -e "$f" ]]; then
+        echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2
+        echo "You need to build Spark with " \
+             "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \
+             "'build/mvn package' before running this program" 1>&2
+        exit 1
+      fi
+      KAFKA_ASSEMBLY_JAR="$f"
+    done
+
+    export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell"
     run_test "pyspark/streaming/util.py"
     run_test "pyspark/streaming/tests.py"
 }