diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml
index f695cff410a18e9bf2fa71c0970f26863d9ce244..243ce6eaca6584c16c97e99f98d6592de872e5ce 100644
--- a/external/kafka/pom.xml
+++ b/external/kafka/pom.xml
@@ -44,7 +44,7 @@
     <dependency>
       <groupId>org.apache.kafka</groupId>
       <artifactId>kafka_${scala.binary.version}</artifactId>
-      <version>0.8.1.1</version>
+      <version>0.8.2.1</version>
       <exclusions>
         <exclusion>
           <groupId>com.sun.jmx</groupId>
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index bd767031c18493e35d29a3daf7a861430f28d5c1..6cf254a7b69cbd9f6acb4fa4465c1df340193d49 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -20,9 +20,10 @@ package org.apache.spark.streaming.kafka
 import scala.util.control.NonFatal
 import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.JavaConverters._
 import java.util.Properties
 import kafka.api._
-import kafka.common.{ErrorMapping, OffsetMetadataAndError, TopicAndPartition}
+import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition}
 import kafka.consumer.{ConsumerConfig, SimpleConsumer}
 import org.apache.spark.SparkException
 
@@ -220,12 +221,22 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
   // https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-OffsetCommit/FetchAPI
   // scalastyle:on
 
+  // This 0 indicates the API version; in this case the original ZooKeeper-backed API.
+  private def defaultConsumerApiVersion: Short = 0
+
   /** Requires Kafka >= 0.8.1.1 */
   def getConsumerOffsets(
       groupId: String,
       topicAndPartitions: Set[TopicAndPartition]
+    ): Either[Err, Map[TopicAndPartition, Long]] =
+    getConsumerOffsets(groupId, topicAndPartitions, defaultConsumerApiVersion)
+
+  def getConsumerOffsets(
+      groupId: String,
+      topicAndPartitions: Set[TopicAndPartition],
+      consumerApiVersion: Short
     ): Either[Err, Map[TopicAndPartition, Long]] = {
-    getConsumerOffsetMetadata(groupId, topicAndPartitions).right.map { r =>
+    getConsumerOffsetMetadata(groupId, topicAndPartitions, consumerApiVersion).right.map { r =>
       r.map { kv =>
         kv._1 -> kv._2.offset
       }
@@ -236,9 +247,16 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
   def getConsumerOffsetMetadata(
       groupId: String,
       topicAndPartitions: Set[TopicAndPartition]
+    ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] =
+    getConsumerOffsetMetadata(groupId, topicAndPartitions, defaultConsumerApiVersion)
+
+  def getConsumerOffsetMetadata(
+      groupId: String,
+      topicAndPartitions: Set[TopicAndPartition],
+      consumerApiVersion: Short
     ): Either[Err, Map[TopicAndPartition, OffsetMetadataAndError]] = {
     var result = Map[TopicAndPartition, OffsetMetadataAndError]()
-    val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq)
+    val req = OffsetFetchRequest(groupId, topicAndPartitions.toSeq, consumerApiVersion)
     val errs = new Err
     withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer =>
       val resp = consumer.fetchOffsets(req)
@@ -266,24 +284,39 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
   def setConsumerOffsets(
       groupId: String,
       offsets: Map[TopicAndPartition, Long]
+    ): Either[Err, Map[TopicAndPartition, Short]] =
+    setConsumerOffsets(groupId, offsets, defaultConsumerApiVersion)
+
+  def setConsumerOffsets(
+      groupId: String,
+      offsets: Map[TopicAndPartition, Long],
+      consumerApiVersion: Short
     ): Either[Err, Map[TopicAndPartition, Short]] = {
-    setConsumerOffsetMetadata(groupId, offsets.map { kv =>
-      kv._1 -> OffsetMetadataAndError(kv._2)
-    })
+    val meta = offsets.map { kv =>
+      kv._1 -> OffsetAndMetadata(kv._2)
+    }
+    setConsumerOffsetMetadata(groupId, meta, consumerApiVersion)
   }
 
   /** Requires Kafka >= 0.8.1.1 */
   def setConsumerOffsetMetadata(
       groupId: String,
-      metadata: Map[TopicAndPartition, OffsetMetadataAndError]
+      metadata: Map[TopicAndPartition, OffsetAndMetadata]
+    ): Either[Err, Map[TopicAndPartition, Short]] =
+    setConsumerOffsetMetadata(groupId, metadata, defaultConsumerApiVersion)
+
+  def setConsumerOffsetMetadata(
+      groupId: String,
+      metadata: Map[TopicAndPartition, OffsetAndMetadata],
+      consumerApiVersion: Short
     ): Either[Err, Map[TopicAndPartition, Short]] = {
     var result = Map[TopicAndPartition, Short]()
-    val req = OffsetCommitRequest(groupId, metadata)
+    val req = OffsetCommitRequest(groupId, metadata, consumerApiVersion)
     val errs = new Err
     val topicAndPartitions = metadata.keySet
     withBrokers(Random.shuffle(config.seedBrokers), errs) { consumer =>
       val resp = consumer.commitOffsets(req)
-      val respMap = resp.requestInfo
+      val respMap = resp.commitStatus
       val needed = topicAndPartitions.diff(result.keySet)
       needed.foreach { tp: TopicAndPartition =>
         respMap.get(tp).foreach { err: Short =>
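For reviewers, a minimal usage sketch of the new overloads (not part of the patch): the default API version 0 keeps the original ZooKeeper-backed offset storage, while passing 1 targets the Kafka-backed offset storage available on 0.8.2 brokers. The broker address, topic, and group id below are placeholders, and the snippet assumes access to KafkaCluster, which may be package-private to Spark depending on the release.

```scala
import kafka.common.TopicAndPartition
import org.apache.spark.streaming.kafka.KafkaCluster

// Placeholder broker list; in the tests this comes from KafkaTestUtils.brokerAddress.
val kc = new KafkaCluster(Map("metadata.broker.list" -> "localhost:9092"))
val tp = TopicAndPartition("sometopic", 0)

// Default overload: commits via OffsetCommitRequest version 0 (ZooKeeper-backed).
kc.setConsumerOffsets("some-group", Map(tp -> 42L))

// Explicit version 1: commit and fetch against Kafka-backed offset storage (0.8.2+ brokers).
kc.setConsumerOffsets("some-group", Map(tp -> 42L), 1.toShort)
kc.getConsumerOffsets("some-group", Set(tp), 1.toShort).fold(
  errs => println(s"fetch failed: ${errs.mkString(", ")}"),
  offsets => println(s"fetched offsets: $offsets"))
```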
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index 13e947506597987d72c02f09f1bffa116469b6e7..6dc4e9517d5a4aaca3bac716e88b7e1c989cd316 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -29,10 +29,12 @@ import scala.language.postfixOps
 import scala.util.control.NonFatal
 
 import kafka.admin.AdminUtils
+import kafka.api.Request
+import kafka.common.TopicAndPartition
 import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
 import kafka.serializer.StringEncoder
 import kafka.server.{KafkaConfig, KafkaServer}
-import kafka.utils.ZKStringSerializer
+import kafka.utils.{ZKStringSerializer, ZkUtils}
 import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
 import org.I0Itec.zkclient.ZkClient
 
@@ -227,12 +229,35 @@ private class KafkaTestUtils extends Logging {
     tryAgain(1)
   }
 
-  private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
+  /** Wait until the leader offset for the given topic/partition equals the specified offset */
+  def waitUntilLeaderOffset(
+      topic: String,
+      partition: Int,
+      offset: Long): Unit = {
     eventually(Time(10000), Time(100)) {
+      val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress))
+      val tp = TopicAndPartition(topic, partition)
+      val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset
       assert(
-        server.apis.metadataCache.containsTopicAndPartition(topic, partition),
-        s"Partition [$topic, $partition] metadata not propagated after timeout"
-      )
+        llo == offset,
+        s"$topic $partition $offset not reached after timeout")
+    }
+  }
+
+  private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
+    def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match {
+      case Some(partitionState) =>
+        val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr
+
+        ZkUtils.getLeaderForPartition(zkClient, topic, partition).isDefined &&
+          Request.isValidBrokerId(leaderAndInSyncReplicas.leader) &&
+          leaderAndInSyncReplicas.isr.size >= 1
+
+      case _ =>
+        false
+    }
+    eventually(Time(10000), Time(100)) {
+      assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout")
     }
   }
 
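A short sketch of how the new helper is meant to be sequenced in tests (it mirrors the suite changes below; the topic name and messages are illustrative, and the code assumes it lives in the org.apache.spark.streaming.kafka package, since KafkaTestUtils is package-private):

```scala
// Start an embedded ZooKeeper + Kafka broker for the test.
val kafkaTestUtils = new KafkaTestUtils()
kafkaTestUtils.setup()

val topic = "example-topic"
val messages = Array("a", "b", "c")
kafkaTestUtils.createTopic(topic)
kafkaTestUtils.sendMessages(topic, messages)

// Block until the partition leader reports all sent messages as appended, so an
// OffsetRange(topic, 0, 0, messages.length) built afterwards cannot over-read.
kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.length)

kafkaTestUtils.teardown()
```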
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index a9dc6e50613ca58ac6de67bae7f51d556ed1308b..5cf379635354f96f15b51e1e1692c6842b0d86e7 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -72,6 +72,9 @@ public class JavaKafkaRDDSuite implements Serializable {
     HashMap<String, String> kafkaParams = new HashMap<String, String>();
     kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());
 
+    kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length);
+    kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length);
+
     OffsetRange[] offsetRanges = {
       OffsetRange.create(topic1, 0, 0, 1),
       OffsetRange.create(topic2, 0, 0, 1)
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 7d26ce50875b363810c3e69e5d14c0bcf29db430..39c3fb448ff5771d53089c7063b184ae52808e20 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -53,14 +53,15 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
   }
 
   test("basic usage") {
-    val topic = "topicbasic"
+    val topic = s"topicbasic-${Random.nextInt}"
     kafkaTestUtils.createTopic(topic)
     val messages = Set("the", "quick", "brown", "fox")
     kafkaTestUtils.sendMessages(topic, messages.toArray)
 
-
     val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
-      "group.id" -> s"test-consumer-${Random.nextInt(10000)}")
+      "group.id" -> s"test-consumer-${Random.nextInt}")
+
+    kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size)
 
     val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
 
@@ -73,27 +74,38 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
 
   test("iterator boundary conditions") {
     // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
-    val topic = "topic1"
+    val topic = s"topicboundary-${Random.nextInt}"
     val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
     kafkaTestUtils.createTopic(topic)
 
     val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
-      "group.id" -> s"test-consumer-${Random.nextInt(10000)}")
+      "group.id" -> s"test-consumer-${Random.nextInt}")
 
     val kc = new KafkaCluster(kafkaParams)
 
     // this is the "lots of messages" case
     kafkaTestUtils.sendMessages(topic, sent)
+    val sentCount = sent.values.sum
+    kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount)
+
     // rdd defined from leaders after sending messages, should get the number sent
     val rdd = getRdd(kc, Set(topic))
 
     assert(rdd.isDefined)
-    assert(rdd.get.count === sent.values.sum, "didn't get all sent messages")
 
-    val ranges = rdd.get.asInstanceOf[HasOffsetRanges]
-      .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap
+    val ranges = rdd.get.asInstanceOf[HasOffsetRanges].offsetRanges
+    val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum
 
-    kc.setConsumerOffsets(kafkaParams("group.id"), ranges)
+    assert(rangeCount === sentCount, "offset range didn't include all sent messages")
+    assert(rdd.get.count === sentCount, "didn't get all sent messages")
+
+    val rangesMap = ranges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap
+
+    // make sure consumer offsets are committed before the next getRdd call
+    kc.setConsumerOffsets(kafkaParams("group.id"), rangesMap).fold(
+      err => throw new Exception(err.mkString("\n")),
+      _ => ()
+    )
 
     // this is the "0 messages" case
     val rdd2 = getRdd(kc, Set(topic))
@@ -101,6 +113,8 @@ class KafkaRDDSuite extends FunSuite with BeforeAndAfterAll {
     val sentOnlyOne = Map("d" -> 1)
 
     kafkaTestUtils.sendMessages(topic, sentOnlyOne)
+    kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1)
+
     assert(rdd2.isDefined)
     assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 7c06c203455d938a116d7b15a7e87419a0960668..33ea8c9293d74a09ee9a87a33010df2a28b14efa 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -606,7 +606,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         result = {}
         for i in rdd.map(lambda x: x[1]).collect():
             result[i] = result.get(i, 0) + 1
-
         self.assertEqual(sendData, result)
 
     def test_kafka_stream(self):
@@ -616,6 +615,7 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
+        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
 
         stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
                                          "test-streaming-consumer", {topic: 1},
@@ -631,6 +631,7 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
+        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
 
         stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
         self._validateStreamResult(sendData, stream)
@@ -645,6 +646,7 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
+        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
 
         stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
         self._validateStreamResult(sendData, stream)
@@ -659,7 +661,7 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-
+        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
         rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
         self._validateRddResult(sendData, rdd)
 
@@ -675,7 +677,7 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-
+        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
         rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
         self._validateRddResult(sendData, rdd)