diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 54d8c8b03f2060d4cf0fec395a1c7f235f37d2dd..0eaaf408c0112962b34735a9551935b8d6e4f0d3 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -89,23 +89,32 @@ class DirectKafkaInputDStream[
 
   private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
       "spark.streaming.kafka.maxRatePerPartition", 0)
-  protected def maxMessagesPerPartition: Option[Long] = {
+
+  protected[streaming] def maxMessagesPerPartition(
+      offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = {
     val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
-    val numPartitions = currentOffsets.keys.size
-
-    val effectiveRateLimitPerPartition = estimatedRateLimit
-      .filter(_ > 0)
-      .map { limit =>
-        if (maxRateLimitPerPartition > 0) {
-          Math.min(maxRateLimitPerPartition, (limit / numPartitions))
-        } else {
-          limit / numPartitions
+
+    // calculate a per-partition rate limit based on current lag
+    val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
+      case Some(rate) =>
+        val lagPerPartition = offsets.map { case (tp, offset) =>
+          tp -> Math.max(offset - currentOffsets(tp), 0)
+        }
+        val totalLag = lagPerPartition.values.sum
+
+        lagPerPartition.map { case (tp, lag) =>
+          val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+          val cappedRate = Math.min(backpressureRate, maxRateLimitPerPartition)
+          tp -> (if (maxRateLimitPerPartition > 0) cappedRate else backpressureRate)
         }
-      }.getOrElse(maxRateLimitPerPartition)
+      case None => offsets.map { case (tp, _) => tp -> maxRateLimitPerPartition }
+    }
 
-    if (effectiveRateLimitPerPartition > 0) {
+    if (effectiveRateLimitPerPartition.values.sum > 0) {
       val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
-      Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
+      Some(effectiveRateLimitPerPartition.map {
+        case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+      })
     } else {
       None
     }
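
The rewritten maxMessagesPerPartition splits the controller's global rate
across partitions in proportion to each partition's share of the total lag,
then applies the spark.streaming.kafka.maxRatePerPartition cap. A standalone
sketch of that arithmetic, with illustrative names (RateAllocationSketch,
allocate, latestOffsets) that are not part of the patch:

    import kafka.common.TopicAndPartition

    object RateAllocationSketch {
      // Split a global rate (messages/sec) across partitions in proportion
      // to their lag, capped per partition when the cap is positive.
      def allocate(
          currentOffsets: Map[TopicAndPartition, Long],
          latestOffsets: Map[TopicAndPartition, Long],
          rate: Int,
          maxRatePerPartition: Int): Map[TopicAndPartition, Int] = {
        val lag = latestOffsets.map { case (tp, offset) =>
          tp -> math.max(offset - currentOffsets(tp), 0L)
        }
        val totalLag = lag.values.sum
        lag.map { case (tp, l) =>
          // With no lag at all this divides 0 by 0.0f, yielding Float.NaN,
          // and Math.round(NaN) is 0 -- which is why maxMessagesPerPartition
          // returns None when every partition is caught up.
          val share = Math.round(l / totalLag.toFloat * rate)
          tp -> (if (maxRatePerPartition > 0) math.min(share, maxRatePerPartition) else share)
        }
      }

      def main(args: Array[String]): Unit = {
        val tp0 = TopicAndPartition("t", 0)
        val tp1 = TopicAndPartition("t", 1)
        // Lags of 250 and 750 out of 1000 split a 100 msg/sec rate 25/75.
        println(allocate(Map(tp0 -> 0L, tp1 -> 0L), Map(tp0 -> 250L, tp1 -> 750L), 100, 0))
      }
    }
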
@@ -134,9 +143,12 @@ class DirectKafkaInputDStream[
   // limits the maximum number of messages per partition
   protected def clamp(
     leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = {
-    maxMessagesPerPartition.map { mmp =>
-      leaderOffsets.map { case (tp, lo) =>
-        tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset))
+    val offsets = leaderOffsets.mapValues(lo => lo.offset)
+
+    maxMessagesPerPartition(offsets).map { mmp =>
+      mmp.map { case (tp, messages) =>
+        val lo = leaderOffsets(tp)
+        tp -> lo.copy(offset = Math.min(currentOffsets(tp) + messages, lo.offset))
       }
     }.getOrElse(leaderOffsets)
   }
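
clamp now looks up each partition's own message budget instead of applying one
global number. A worked example with hypothetical values (current offset 100,
budget of 10 messages, leader at offset 500), showing the batch boundary for
one partition:

    // Hypothetical numbers, not from the patch.
    val currentOffset = 100L
    val budget = 10L
    val leaderOffset = 500L
    val untilOffset = math.min(currentOffset + budget, leaderOffset)  // 110

The partition reads offsets [100, 110): never past its leader and never more
than its budget.
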
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index a76fa6671a4b08e293a052f55859c4f7895eb923..a5ea1d6d2848d3ce6b5f04eb2654e8af5aa43502 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -152,12 +152,15 @@ private[kafka] class KafkaTestUtils extends Logging {
   }
 
   /** Create a Kafka topic and wait until it is propagated to the whole cluster */
-  def createTopic(topic: String): Unit = {
-    AdminUtils.createTopic(zkClient, topic, 1, 1)
+  def createTopic(topic: String, partitions: Int): Unit = {
+    AdminUtils.createTopic(zkClient, topic, partitions, 1)
     // wait until metadata is propagated
-    waitUntilMetadataIsPropagated(topic, 0)
+    (0 until partitions).foreach { p => waitUntilMetadataIsPropagated(topic, p) }
   }
 
+  /** Create a topic with a single partition; single-argument overload kept for backwards compatibility */
+  def createTopic(topic: String): Unit = createTopic(topic, 1)
+
   /** Java-friendly function for sending messages to the Kafka broker */
   def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
     sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*))
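
The one-argument overload keeps every existing call site compiling. A usage
sketch (topic names illustrative; it sits in the same package because
KafkaTestUtils is private[kafka]):

    package org.apache.spark.streaming.kafka

    object TopicSetupSketch {
      def setupTopics(utils: KafkaTestUtils): Unit = {
        utils.createTopic("backpressure", 2) // two partitions, waits for each
        utils.createTopic("plain")           // equivalent to createTopic("plain", 1)
      }
    }
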
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
index 4891e4f4a17bc5d4dd3fd38efb6104c7b51d6599..fa6b0dbc8c2197b4087dd717246966a366913571 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -168,7 +168,7 @@ public class JavaDirectKafkaStreamSuite implements Serializable {
 
   private  String[] createTopicAndSendData(String topic) {
     String[] data = { topic + "-1", topic + "-2", topic + "-3"};
-    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.createTopic(topic, 1);
     kafkaTestUtils.sendMessages(topic, data);
     return data;
   }
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
index afcc6cfccd39a8ade454e22cf226d15c8d7bc728..c41b6297b0481a78d80a234fa239cfe090b81e58 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -149,7 +149,7 @@ public class JavaKafkaRDDSuite implements Serializable {
 
   private  String[] createTopicAndSendData(String topic) {
     String[] data = { topic + "-1", topic + "-2", topic + "-3"};
-    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.createTopic(topic, 1);
     kafkaTestUtils.sendMessages(topic, data);
     return data;
   }
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 617c92a008fc54c2ac2a3f421201dfca424dab19..868df64e8c94449ad5cd39d1c85f9ba68b636fd7 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -76,7 +76,7 @@ public class JavaKafkaStreamSuite implements Serializable {
     sent.put("b", 3);
     sent.put("c", 10);
 
-    kafkaTestUtils.createTopic(topic);
+    kafkaTestUtils.createTopic(topic, 1);
     kafkaTestUtils.sendMessages(topic, sent);
 
     Map<String, String> kafkaParams = new HashMap<>();
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 8398178e9b79b8f3e7fc5752ee070425b6dd6330..b2c81d1534ee687b005c83098d42b227dbdca1bb 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -353,10 +353,38 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }
 
+  test("maxMessagesPerPartition with backpressure disabled") {
+    val topic = "maxMessagesPerPartition"
+    val kafkaStream = getDirectKafkaStream(topic, None)
+
+    val input = Map(TopicAndPartition(topic, 0) -> 50L, TopicAndPartition(topic, 1) -> 50L)
+    assert(kafkaStream.maxMessagesPerPartition(input).get ==
+      Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L))
+  }
+
+  test("maxMessagesPerPartition with no lag") {
+    val topic = "maxMessagesPerPartition"
+    val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100))
+    val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+    val input = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L)
+    assert(kafkaStream.maxMessagesPerPartition(input).isEmpty)
+  }
+
+  test("maxMessagesPerPartition respects max rate") {
+    val topic = "maxMessagesPerPartition"
+    val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000))
+    val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+    val input = Map(TopicAndPartition(topic, 0) -> 1000L, TopicAndPartition(topic, 1) -> 1000L)
+    assert(kafkaStream.maxMessagesPerPartition(input).get ==
+      Map(TopicAndPartition(topic, 0) -> 10L, TopicAndPartition(topic, 1) -> 10L))
+  }
+
   test("using rate controller") {
     val topic = "backpressure"
-    val topicPartition = TopicAndPartition(topic, 0)
-    kafkaTestUtils.createTopic(topic)
+    val topicPartitions = Set(TopicAndPartition(topic, 0), TopicAndPartition(topic, 1))
+    kafkaTestUtils.createTopic(topic, 2)
     val kafkaParams = Map(
       "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
       "auto.offset.reset" -> "smallest"
@@ -364,8 +392,8 @@ class DirectKafkaStreamSuite
 
     val batchIntervalMilliseconds = 100
     val estimator = new ConstantEstimator(100)
-    val messageKeys = (1 to 200).map(_.toString)
-    val messages = messageKeys.map((_, 1)).toMap
+    val messages = Map("foo" -> 200)
+    kafkaTestUtils.sendMessages(topic, messages)
 
     val sparkConf = new SparkConf()
       // Safe, even with streaming, because we're using the direct API.
@@ -380,11 +408,11 @@ class DirectKafkaStreamSuite
     val kafkaStream = withClue("Error creating direct stream") {
       val kc = new KafkaCluster(kafkaParams)
       val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
-      val m = kc.getEarliestLeaderOffsets(Set(topicPartition))
+      val m = kc.getEarliestLeaderOffsets(topicPartitions)
         .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset))
 
       new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
-        ssc, kafkaParams, m, messageHandler) {
+          ssc, kafkaParams, m, messageHandler) {
         override protected[streaming] val rateController =
           Some(new DirectKafkaRateController(id, estimator))
       }
@@ -405,13 +433,12 @@ class DirectKafkaStreamSuite
     ssc.start()
 
     // Try different rate limits.
-    // Send data to Kafka and wait for arrays of data to appear matching the rate.
+    // Wait for arrays of data to appear matching the rate.
     Seq(100, 50, 20).foreach { rate =>
       collectedData.clear()       // Empty this buffer on each pass.
       estimator.updateRate(rate)  // Set a new rate.
       // Expect blocks of data equal to "rate", scaled by the interval length in secs.
       val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
-      kafkaTestUtils.sendMessages(topic, messages)
       eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
         // Assert that rate estimator values are used to determine maxMessagesPerPartition.
         // Funky "-" in message makes the complete assertion message read better.
@@ -430,6 +457,25 @@ class DirectKafkaStreamSuite
       rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges
     }.toSeq.sortBy { _._1 }
   }
+
+  private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = {
+    val batchIntervalMilliseconds = 100
+
+    val sparkConf = new SparkConf()
+      .setMaster("local[1]")
+      .setAppName(this.getClass.getSimpleName)
+      .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+    // Set up the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+    val earliestOffsets = Map(TopicAndPartition(topic, 0) -> 0L, TopicAndPartition(topic, 1) -> 0L)
+    val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+    new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+      ssc, Map[String, String](), earliestOffsets, messageHandler) {
+      override protected[streaming] val rateController = mockRateController
+    }
+  }
 }
 
 object DirectKafkaStreamSuite {
@@ -468,3 +514,9 @@ private[streaming] class ConstantEstimator(@volatile private var rate: Long)
       processingDelay: Long,
       schedulingDelay: Long): Option[Double] = Some(rate)
 }
+
+private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long)
+  extends RateController(id, estimator) {
+  override def publish(rate: Long): Unit = ()
+  override def getLatestRate(): Long = rate
+}
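
The expected 10L in the first and third tests follows from the 100 ms batch
interval set in getDirectKafkaStream and an effective per-partition rate of
100 messages/sec (the configured maximum in one case, the backpressure cap in
the other). A quick check of that arithmetic:

    val secsPerBatch = 100.0 / 1000      // 100 ms batch interval, in seconds
    val effectiveRate = 100              // messages/sec per partition, after capping
    val maxPerBatch = (secsPerBatch * effectiveRate).toLong
    assert(maxPerBatch == 10L)           // matches the asserted 10L per partition
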
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 9ce37fc753c46a20f34db4b1dd6ca75a55fa41b9..983f71684c38b46aa6d1de5ce3a9db9eb2ad18c1 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -288,6 +288,10 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
+      ) ++ Seq(
+        // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
       )
     case v if v.startsWith("1.6") =>
       Seq(