diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index 6dc4e9517d5a4aaca3bac716e88b7e1c989cd316..b608b75952721b5fce13bd3f04e839f6d4f6308b 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -195,6 +195,8 @@ private class KafkaTestUtils extends Logging {
     val props = new Properties()
     props.put("metadata.broker.list", brokerAddress)
     props.put("serializer.class", classOf[StringEncoder].getName)
+    // wait for all in-sync replicas to ack sends
+    props.put("request.required.acks", "-1")
     props
   }
 
@@ -229,21 +231,6 @@ private class KafkaTestUtils extends Logging {
     tryAgain(1)
   }
 
-  /** Wait until the leader offset for the given topic/partition equals the specified offset */
-  def waitUntilLeaderOffset(
-      topic: String,
-      partition: Int,
-      offset: Long): Unit = {
-    eventually(Time(10000), Time(100)) {
-      val kc = new KafkaCluster(Map("metadata.broker.list" -> brokerAddress))
-      val tp = TopicAndPartition(topic, partition)
-      val llo = kc.getLatestLeaderOffsets(Set(tp)).right.get.apply(tp).offset
-      assert(
-        llo == offset,
-        s"$topic $partition $offset not reached after timeout")
-    }
-  }
-
   private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
     def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match {
       case Some(partitionState) =>
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index 5cf379635354f96f15b51e1e1692c6842b0d86e7..a9dc6e50613ca58ac6de67bae7f51d556ed1308b 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -72,9 +72,6 @@ public class JavaKafkaRDDSuite implements Serializable {
     HashMap<String, String> kafkaParams = new HashMap<String, String>();
     kafkaParams.put("metadata.broker.list", kafkaTestUtils.brokerAddress());
 
-    kafkaTestUtils.waitUntilLeaderOffset(topic1, 0, topic1data.length);
-    kafkaTestUtils.waitUntilLeaderOffset(topic2, 0, topic2data.length);
-
     OffsetRange[] offsetRanges = {
       OffsetRange.create(topic1, 0, 0, 1),
       OffsetRange.create(topic2, 0, 0, 1)
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 054487269a9359fae9736b727490ab32a1475d76..d5baf5fd899947a7b34193455881f92298cb13bd 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -61,8 +61,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "group.id" -> s"test-consumer-${Random.nextInt}")
 
-    kafkaTestUtils.waitUntilLeaderOffset(topic, 0, messages.size)
-
     val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
 
     val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
@@ -86,7 +84,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     // this is the "lots of messages" case
     kafkaTestUtils.sendMessages(topic, sent)
     val sentCount = sent.values.sum
-    kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount)
 
     // rdd defined from leaders after sending messages, should get the number sent
     val rdd = getRdd(kc, Set(topic))
@@ -113,7 +110,6 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
     val sentOnlyOne = Map("d" -> 1)
 
     kafkaTestUtils.sendMessages(topic, sentOnlyOne)
-    kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sentCount + 1)
 
     assert(rdd2.isDefined)
     assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 73e4bfd78e577f0f4436afac7b17f7399ffbf2da..8a93ca29995101f61ac28e35557454a2a1cad713 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -47,6 +47,9 @@ object MimaExcludes {
             // Mima false positive (was a private[spark] class)
             ProblemFilters.exclude[MissingClassProblem](
               "org.apache.spark.util.collection.PairIterator"),
+            // Removing a testing method from a private class
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
             // SQL execution is considered private.
             excludePackage("org.apache.spark.sql.execution")
           )
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 46cb18b2e8ef98b1f2d650b0492a1da4be87d898..57049beea4dbacffb91475ef273db5bc90fb9f20 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -615,7 +615,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
 
         stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
                                          "test-streaming-consumer", {topic: 1},
@@ -631,7 +630,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
 
         stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams)
         self._validateStreamResult(sendData, stream)
@@ -646,7 +644,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
 
         stream = KafkaUtils.createDirectStream(self.ssc, [topic], kafkaParams, fromOffsets)
         self._validateStreamResult(sendData, stream)
@@ -661,7 +658,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
         rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges)
         self._validateRddResult(sendData, rdd)
 
@@ -677,7 +673,6 @@ class KafkaStreamTests(PySparkStreamingTestCase):
 
         self._kafkaTestUtils.createTopic(topic)
         self._kafkaTestUtils.sendMessages(topic, sendData)
-        self._kafkaTestUtils.waitUntilLeaderOffset(topic, 0, sum(sendData.values()))
         rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders)
         self._validateRddResult(sendData, rdd)