diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index 48a1933d92f85950965900528df140076d531c1d..8a177077775c632a4b78572fbfc822ba880440ea 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -29,7 +29,8 @@ import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
-import org.apache.spark.streaming.scheduler.StreamInputInfo
+import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
 
 /**
  *  A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where
@@ -61,7 +62,7 @@ class DirectKafkaInputDStream[
     val kafkaParams: Map[String, String],
     val fromOffsets: Map[TopicAndPartition, Long],
     messageHandler: MessageAndMetadata[K, V] => R
-) extends InputDStream[R](ssc_) with Logging {
+  ) extends InputDStream[R](ssc_) with Logging {
   val maxRetries = context.sparkContext.getConf.getInt(
     "spark.streaming.kafka.maxRetries", 1)
 
@@ -71,14 +72,35 @@ class DirectKafkaInputDStream[
   protected[streaming] override val checkpointData =
     new DirectKafkaInputDStreamCheckpointData
 
+  /**
+   * Asynchronously maintains the latest rate limit computed by the RateEstimator from
+   * batch completion updates; maxMessagesPerPartition uses it to cap each batch.
+   */
+  override protected[streaming] val rateController: Option[RateController] = {
+    if (RateController.isBackPressureEnabled(ssc.conf)) {
+      Some(new DirectKafkaRateController(id,
+        RateEstimator.create(ssc.conf, ssc_.graph.batchDuration)))
+    } else {
+      None
+    }
+  }
+
   protected val kc = new KafkaCluster(kafkaParams)
 
-  protected val maxMessagesPerPartition: Option[Long] = {
-    val ratePerSec = context.sparkContext.getConf.getInt(
+  private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
       "spark.streaming.kafka.maxRatePerPartition", 0)
-    if (ratePerSec > 0) {
+  protected def maxMessagesPerPartition: Option[Long] = {
+    val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
+    val numPartitions = currentOffsets.keys.size
+
+    val effectiveRateLimitPerPartition = estimatedRateLimit
+      .filter(_ > 0)
+      .map(limit => Math.min(maxRateLimitPerPartition, (limit / numPartitions)))
+      .getOrElse(maxRateLimitPerPartition)
+
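+    // Illustrative example (hypothetical numbers): with 4 partitions, an estimated
+    // rate of 1000 records/sec and maxRateLimitPerPartition = 300, the effective
+    // limit is min(300, 1000 / 4) = 250 records/sec per partition; a 2 second batch
+    // then yields Some(500) messages per partition.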
+    if (effectiveRateLimitPerPartition > 0) {
       val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
-      Some((secsPerBatch * ratePerSec).toLong)
+      Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
     } else {
       None
     }
@@ -170,11 +192,18 @@ class DirectKafkaInputDStream[
       val leaders = KafkaCluster.checkErrors(kc.findLeaders(topics))
 
       batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) =>
-          logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
-          generatedRDDs += t -> new KafkaRDD[K, V, U, T, R](
-            context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler)
+         logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
+         generatedRDDs += t -> new KafkaRDD[K, V, U, T, R](
+           context.sparkContext, kafkaParams, b.map(OffsetRange(_)), leaders, messageHandler)
       }
     }
   }
 
+  /**
+   * A RateController that retrieves the latest rate from the RateEstimator.
+   */
+  private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator)
+    extends RateController(id, estimator) {
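+    // There is no receiver to push the rate to; the latest rate is read back via
+    // getLatestRate() in maxMessagesPerPartition, so publish is intentionally a no-op.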
+    override def publish(rate: Long): Unit = ()
+  }
 }
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index 5b3c79444aa6848b4d6fe20578153007445ae6ce..02225d5aa7cc5d4d12f92009946086fa6b55b563 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -20,6 +20,9 @@ package org.apache.spark.streaming.kafka
 import java.io.File
 import java.util.concurrent.atomic.AtomicLong
 
+import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
@@ -350,6 +353,77 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }
 
+  test("using rate controller") {
+    val topic = "backpressure"
+    val topicPartition = TopicAndPartition(topic, 0)
+    kafkaTestUtils.createTopic(topic)
+    val kafkaParams = Map(
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+      "auto.offset.reset" -> "smallest"
+    )
+
+    val batchIntervalMilliseconds = 100
+    val estimator = new ConstantEstimator(100)
+    val messageKeys = (1 to 200).map(_.toString)
+    val messages = messageKeys.map((_, 1)).toMap
+
+    val sparkConf = new SparkConf()
+      // local[1] is safe even for streaming because the direct API needs no receiver core;
+      // a single core also makes the test more predictable.
+      .setMaster("local[1]")
+      .setAppName(this.getClass.getSimpleName)
+      .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+    // Setup the streaming context
+    ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+    val kafkaStream = withClue("Error creating direct stream") {
+      val kc = new KafkaCluster(kafkaParams)
+      val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message)
+      val m = kc.getEarliestLeaderOffsets(Set(topicPartition))
+        .fold(e => Map.empty[TopicAndPartition, Long], m => m.mapValues(lo => lo.offset))
+
+      new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)](
+        ssc, kafkaParams, m, messageHandler) {
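+        // Substitute a controller backed by the test's ConstantEstimator so that
+        // maxMessagesPerPartition follows the rate set via updateRate below.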
+        override protected[streaming] val rateController =
+          Some(new DirectKafkaRateController(id, estimator))
+      }
+    }
+
+    val collectedData =
+      new mutable.ArrayBuffer[Array[String]]() with mutable.SynchronizedBuffer[Array[String]]
+
+    // Used for assertion failure messages.
+    def dataToString: String =
+      collectedData.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}")
+
+    // This is to collect the raw data received from Kafka
+    kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
+      val data = rdd.map { _._2 }.collect()
+      collectedData += data
+    }
+
+    ssc.start()
+
+    // Try different rate limits.
+    // Send data to Kafka and wait for arrays of data to appear matching the rate.
+    Seq(100, 50, 20).foreach { rate =>
+      collectedData.clear()       // Empty this buffer on each pass.
+      estimator.updateRate(rate)  // Set a new rate.
+      // Expect blocks of data equal to "rate", scaled by the interval length in secs.
+      val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
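+      // For example, rate 100 with a 100 ms batch gives Math.round(100 * 100 * 0.001) = 10.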
+      kafkaTestUtils.sendMessages(topic, messages)
+      eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
+        // Assert that rate estimator values are used to determine maxMessagesPerPartition.
+        // Funky "-" in message makes the complete assertion message read better.
+        assert(collectedData.exists(_.size == expectedSize),
+          s" - No arrays of size $expectedSize for rate $rate found in $dataToString")
+      }
+    }
+
+    ssc.stop()
+  }
+
   /** Get the generated offset ranges from the DirectKafkaStream */
   private def getOffsetRanges[K, V](
       kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
@@ -381,3 +455,18 @@ object DirectKafkaStreamSuite {
     }
   }
 }
+
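+/**
+ * A RateEstimator that always returns the current value of `rate`, which the test can
+ * change via updateRate() to drive back pressure deterministically.
+ */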
+private[streaming] class ConstantEstimator(@volatile private var rate: Long)
+  extends RateEstimator {
+
+  def updateRate(newRate: Long): Unit = {
+    rate = newRate
+  }
+
+  def compute(
+      time: Long,
+      elements: Long,
+      processingDelay: Long,
+      schedulingDelay: Long): Option[Double] = Some(rate)
+}