diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
index 7e57bb18cbd50fe2a87e98467f1ad167f45948b0..794f53c5abfd04e7137bd9e2612172eb586c2b39 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
@@ -57,7 +57,8 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator
 private[spark] class DirectKafkaInputDStream[K, V](
     _ssc: StreamingContext,
     locationStrategy: LocationStrategy,
-    consumerStrategy: ConsumerStrategy[K, V]
+    consumerStrategy: ConsumerStrategy[K, V],
+    ppc: PerPartitionConfig
   ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets {
 
   val executorKafkaParams = {
@@ -128,12 +129,9 @@ private[spark] class DirectKafkaInputDStream[K, V](
     }
   }
 
-  private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
-    "spark.streaming.kafka.maxRatePerPartition", 0)
-
   protected[streaming] def maxMessagesPerPartition(
     offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = {
-    val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
+    val estimatedRateLimit = rateController.map(_.getLatestRate())
 
     // calculate a per-partition rate limit based on current lag
     val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
@@ -144,11 +142,12 @@ private[spark] class DirectKafkaInputDStream[K, V](
         val totalLag = lagPerPartition.values.sum
 
         lagPerPartition.map { case (tp, lag) =>
+          val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp)
           val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
           tp -> (if (maxRateLimitPerPartition > 0) {
             Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
         }
-      case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+      case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) }
     }
 
     if (effectiveRateLimitPerPartition.values.sum > 0) {
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
index b2190bfa05a3a0b3c880f0be30c42a971397c655..c11917f59d5b8e287f58dc3eae269a25aa5bcc7c 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
@@ -123,7 +123,31 @@ object KafkaUtils extends Logging {
       locationStrategy: LocationStrategy,
       consumerStrategy: ConsumerStrategy[K, V]
     ): InputDStream[ConsumerRecord[K, V]] = {
-    new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy)
+    val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf)
+    createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc)
+  }
+
+  /**
+   * :: Experimental ::
+   * Scala constructor for a DStream where
+   * each given Kafka topic/partition corresponds to an RDD partition.
+   * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent,
+   *   see [[LocationStrategies]] for more details.
+   * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe,
+   *   see [[ConsumerStrategies]] for more details.
+   * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis.
+   *   see [[PerPartitionConfig]] for more details.
+   * @tparam K type of Kafka message key
+   * @tparam V type of Kafka message value
+   */
+  @Experimental
+  def createDirectStream[K, V](
+      ssc: StreamingContext,
+      locationStrategy: LocationStrategy,
+      consumerStrategy: ConsumerStrategy[K, V],
+      perPartitionConfig: PerPartitionConfig
+    ): InputDStream[ConsumerRecord[K, V]] = {
+    new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig)
   }
 
   /**
@@ -150,6 +174,33 @@ object KafkaUtils extends Logging {
         jssc.ssc, locationStrategy, consumerStrategy))
   }
 
+  /**
+   * :: Experimental ::
+   * Java constructor for a DStream where
+   * each given Kafka topic/partition corresponds to an RDD partition.
+   * @param keyClass Class of the keys in the Kafka records
+   * @param valueClass Class of the values in the Kafka records
+   * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent,
+   *   see [[LocationStrategies]] for more details.
+   * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe,
+   *   see [[ConsumerStrategies]] for more details
+   * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis.
+   *   see [[PerPartitionConfig]] for more details.
+   * @tparam K type of Kafka message key
+   * @tparam V type of Kafka message value
+   */
+  @Experimental
+  def createDirectStream[K, V](
+      jssc: JavaStreamingContext,
+      locationStrategy: LocationStrategy,
+      consumerStrategy: ConsumerStrategy[K, V],
+      perPartitionConfig: PerPartitionConfig
+    ): JavaInputDStream[ConsumerRecord[K, V]] = {
+    new JavaInputDStream(
+      createDirectStream[K, V](
+        jssc.ssc, locationStrategy, consumerStrategy, perPartitionConfig))
+  }
+
   /**
    * Tweak kafka params to prevent issues on executors
    */
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala
new file mode 100644
index 0000000000000000000000000000000000000000..4792f2a95511090de4e93a0e16a84d782c7fb9e5
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.SparkConf
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Interface for user-supplied configurations that can't otherwise be set via Spark properties,
+ * because they need tweaking on a per-partition basis,
+ */
+@Experimental
+abstract class PerPartitionConfig extends Serializable {
+  /**
+   *  Maximum rate (number of records per second) at which data will be read
+   *  from each Kafka partition.
+   */
+  def maxRatePerPartition(topicPartition: TopicPartition): Long
+}
+
+/**
+ * Default per-partition configuration
+ */
+private class DefaultPerPartitionConfig(conf: SparkConf)
+    extends PerPartitionConfig {
+  val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0)
+
+  def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate
+}
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
index 02aec43c3b34f7b527593ff48dc9a70697ae5f0a..f36e0a901f7b01d24b7a359f73e7158f38bc240a 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
@@ -252,7 +252,8 @@ class DirectKafkaStreamSuite
       val s = new DirectKafkaInputDStream[String, String](
         ssc,
         preferredHosts,
-        ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala))
+        ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala),
+        new DefaultPerPartitionConfig(sparkConf))
       s.consumer.poll(0)
       assert(
         s.consumer.position(topicPartition) >= offsetBeforeStart,
@@ -306,7 +307,8 @@ class DirectKafkaStreamSuite
         ConsumerStrategies.Assign[String, String](
           List(topicPartition),
           kafkaParams.asScala,
-          Map(topicPartition -> 11L)))
+          Map(topicPartition -> 11L)),
+        new DefaultPerPartitionConfig(sparkConf))
       s.consumer.poll(0)
       assert(
         s.consumer.position(topicPartition) >= offsetBeforeStart,
@@ -518,7 +520,7 @@ class DirectKafkaStreamSuite
 
   test("maxMessagesPerPartition with backpressure disabled") {
     val topic = "maxMessagesPerPartition"
-    val kafkaStream = getDirectKafkaStream(topic, None)
+    val kafkaStream = getDirectKafkaStream(topic, None, None)
 
     val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L)
     assert(kafkaStream.maxMessagesPerPartition(input).get ==
@@ -528,7 +530,7 @@ class DirectKafkaStreamSuite
   test("maxMessagesPerPartition with no lag") {
     val topic = "maxMessagesPerPartition"
     val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100))
-    val kafkaStream = getDirectKafkaStream(topic, rateController)
+    val kafkaStream = getDirectKafkaStream(topic, rateController, None)
 
     val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L)
     assert(kafkaStream.maxMessagesPerPartition(input).isEmpty)
@@ -537,11 +539,19 @@ class DirectKafkaStreamSuite
   test("maxMessagesPerPartition respects max rate") {
     val topic = "maxMessagesPerPartition"
     val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000))
-    val kafkaStream = getDirectKafkaStream(topic, rateController)
+    val ppc = Some(new PerPartitionConfig {
+      def maxRatePerPartition(tp: TopicPartition) =
+        if (tp.topic == topic && tp.partition == 0) {
+          50
+        } else {
+          100
+        }
+    })
+    val kafkaStream = getDirectKafkaStream(topic, rateController, ppc)
 
     val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L)
     assert(kafkaStream.maxMessagesPerPartition(input).get ==
-      Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L))
+      Map(new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L))
   }
 
   test("using rate controller") {
@@ -570,7 +580,9 @@ class DirectKafkaStreamSuite
       new DirectKafkaInputDStream[String, String](
         ssc,
         preferredHosts,
-        ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) {
+        ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala),
+        new DefaultPerPartitionConfig(sparkConf)
+      ) {
         override protected[streaming] val rateController =
           Some(new DirectKafkaRateController(id, estimator))
       }.map(r => (r.key, r.value))
@@ -616,7 +628,10 @@ class DirectKafkaStreamSuite
     }.toSeq.sortBy { _._1 }
   }
 
-  private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = {
+  private def getDirectKafkaStream(
+      topic: String,
+      mockRateController: Option[RateController],
+      ppc: Option[PerPartitionConfig]) = {
     val batchIntervalMilliseconds = 100
 
     val sparkConf = new SparkConf()
@@ -643,7 +658,8 @@ class DirectKafkaStreamSuite
           tps.foreach(tp => consumer.seek(tp, 0))
           consumer
         }
-      }
+      },
+      ppc.getOrElse(new DefaultPerPartitionConfig(sparkConf))
     ) {
         override protected[streaming] val rateController = mockRateController
     }
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index c3c799375bbeb3ae6822e78dab4fe8f87122dc0c..d52c230eb78496c4251ffcf4795750c76454895c 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -88,12 +88,12 @@ class DirectKafkaInputDStream[
 
   protected val kc = new KafkaCluster(kafkaParams)
 
-  private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
+  private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong(
       "spark.streaming.kafka.maxRatePerPartition", 0)
 
   protected[streaming] def maxMessagesPerPartition(
       offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = {
-    val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
+    val estimatedRateLimit = rateController.map(_.getLatestRate())
 
     // calculate a per-partition rate limit based on current lag
     val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {