diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
index 6ae56a68ad88c340779bdc0e45a0ea6155add6b9..84a3ca9d74e58f626fed30a615774d349a6a96bd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.streaming.scheduler.rate
 
+import org.apache.spark.Logging
+
 /**
  * Implements a proportional-integral-derivative (PID) controller which acts on
  * the speed of ingestion of elements into Spark Streaming. A PID controller works
@@ -26,7 +28,7 @@ package org.apache.spark.streaming.scheduler.rate
  *
  * @see https://en.wikipedia.org/wiki/PID_controller
  *
- * @param batchDurationMillis the batch duration, in milliseconds
+ * @param batchIntervalMillis the batch duration, in milliseconds
  * @param proportional how much the correction should depend on the current
  *        error. This term usually provides the bulk of correction and should be positive or zero.
  *        A value too large would make the controller overshoot the setpoint, while a small value
@@ -39,13 +41,17 @@ package org.apache.spark.streaming.scheduler.rate
  *        of future errors, based on current rate of change. This value should be positive or 0.
  *        This term is not used very often, as it impacts stability of the system. The default
  *        value is 0.
+ * @param minRate what is the minimum rate that can be estimated.
+ *        This must be greater than zero, so that the system always receives some data for rate
+ *        estimation to work.
  */
 private[streaming] class PIDRateEstimator(
     batchIntervalMillis: Long,
-    proportional: Double = 1D,
-    integral: Double = .2D,
-    derivative: Double = 0D)
-  extends RateEstimator {
+    proportional: Double,
+    integral: Double,
+    derivative: Double,
+    minRate: Double
+  ) extends RateEstimator with Logging {
 
   private var firstRun: Boolean = true
   private var latestTime: Long = -1L
@@ -64,16 +70,23 @@ private[streaming] class PIDRateEstimator(
   require(
     derivative >= 0,
     s"Derivative term $derivative in PIDRateEstimator should be >= 0.")
+  require(
+    minRate > 0,
+    s"Minimum rate in PIDRateEstimator should be > 0")
 
+  logInfo(s"Created PIDRateEstimator with proportional = $proportional, integral = $integral, " +
+    s"derivative = $derivative, min rate = $minRate")
 
-  def compute(time: Long, // in milliseconds
+  def compute(
+      time: Long, // in milliseconds
       numElements: Long,
       processingDelay: Long, // in milliseconds
       schedulingDelay: Long // in milliseconds
     ): Option[Double] = {
-
+    logTrace(s"\ntime = $time, # records = $numElements, " +
+      s"processing time = $processingDelay, scheduling delay = $schedulingDelay")
     this.synchronized {
-      if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) {
+      if (time > latestTime && numElements > 0 && processingDelay > 0) {
 
         // in seconds, should be close to batchDuration
         val delaySinceUpdate = (time - latestTime).toDouble / 1000
@@ -104,21 +117,30 @@ private[streaming] class PIDRateEstimator(
 
         val newRate = (latestRate - proportional * error -
                                     integral * historicalError -
-                                    derivative * dError).max(0.0)
+                                    derivative * dError).max(minRate)
+        logTrace(s"""
+            | latestRate = $latestRate, error = $error
+            | latestError = $latestError, historicalError = $historicalError
+            | delaySinceUpdate = $delaySinceUpdate, dError = $dError
+            """.stripMargin)
+
         latestTime = time
         if (firstRun) {
           latestRate = processingRate
           latestError = 0D
           firstRun = false
-
+          logTrace("First run, rate estimation skipped")
           None
         } else {
           latestRate = newRate
           latestError = error
-
+          logTrace(s"New rate = $newRate")
           Some(newRate)
         }
-      } else None
+      } else {
+        logTrace("Rate estimation skipped")
+        None
+      }
     }
   }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
index 17ccebc1ed41baedf48d89e396992303718ab0ed..d7210f64fcc362a07ebc23fb4a132f18c04bc246 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.streaming.scheduler.rate
 
 import org.apache.spark.SparkConf
-import org.apache.spark.SparkException
 import org.apache.spark.streaming.Duration
 
 /**
@@ -61,7 +60,8 @@ object RateEstimator {
         val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0)
         val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2)
         val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0)
-        new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived)
+        val minRate = conf.getDouble("spark.streaming.backpressure.pid.minRate", 100)
+        new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate)
 
       case estimator =>
         throw new IllegalArgumentException(s"Unkown rate estimator: $estimator")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
index 97c32d8f2d59e9da219ee260623373f6fed3568c..a1af95be81c8e809e309dd34391d37a30527324c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala
@@ -36,72 +36,89 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
 
   test("estimator checks ranges") {
     intercept[IllegalArgumentException] {
-      new PIDRateEstimator(0, 1, 2, 3)
+      new PIDRateEstimator(batchIntervalMillis = 0, 1, 2, 3, 10)
     }
     intercept[IllegalArgumentException] {
-      new PIDRateEstimator(100, -1, 2, 3)
+      new PIDRateEstimator(100, proportional = -1, 2, 3, 10)
     }
     intercept[IllegalArgumentException] {
-      new PIDRateEstimator(100, 0, -1, 3)
+      new PIDRateEstimator(100, 0, integral = -1, 3, 10)
     }
     intercept[IllegalArgumentException] {
-      new PIDRateEstimator(100, 0, 0, -1)
+      new PIDRateEstimator(100, 0, 0, derivative = -1, 10)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, 0, 0, minRate = 0)
+    }
+    intercept[IllegalArgumentException] {
+      new PIDRateEstimator(100, 0, 0, 0, minRate = -10)
     }
   }
 
-  private def createDefaultEstimator: PIDRateEstimator = {
-    new PIDRateEstimator(20, 1D, 0D, 0D)
-  }
-
-  test("first bound is None") {
-    val p = createDefaultEstimator
+  test("first estimate is None") {
+    val p = createDefaultEstimator()
     p.compute(0, 10, 10, 0) should equal(None)
   }
 
-  test("second bound is rate") {
-    val p = createDefaultEstimator
+  test("second estimate is not None") {
+    val p = createDefaultEstimator()
     p.compute(0, 10, 10, 0)
     // 1000 elements / s
     p.compute(10, 10, 10, 0) should equal(Some(1000))
   }
 
-  test("works even with no time between updates") {
-    val p = createDefaultEstimator
+  test("no estimate when no time difference between successive calls") {
+    val p = createDefaultEstimator()
+    p.compute(0, 10, 10, 0)
+    p.compute(time = 10, 10, 10, 0) shouldNot equal(None)
+    p.compute(time = 10, 10, 10, 0) should equal(None)
+  }
+
+  test("no estimate when no records in previous batch") {
+    val p = createDefaultEstimator()
     p.compute(0, 10, 10, 0)
-    p.compute(10, 10, 10, 0)
-    p.compute(10, 10, 10, 0) should equal(None)
+    p.compute(10, numElements = 0, 10, 0) should equal(None)
+    p.compute(20, numElements = -10, 10, 0) should equal(None)
   }
 
-  test("bound is never negative") {
-    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+  test("no estimate when there is no processing delay") {
+    val p = createDefaultEstimator()
+    p.compute(0, 10, 10, 0)
+    p.compute(10, 10, processingDelay = 0, 0) should equal(None)
+    p.compute(20, 10, processingDelay = -10, 0) should equal(None)
+  }
+
+  test("estimate is never less than min rate") {
+    val minRate = 5D
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D, minRate)
     // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing
     // this might point the estimator to try and decrease the bound, but we test it never
-    // goes below zero, which would be nonsensical.
+    // goes below the min rate, which would be nonsensical.
     val times = List.tabulate(50)(x => x * 20) // every 20ms
-    val elements = List.fill(50)(0) // no processing
+    val elements = List.fill(50)(1) // no processing
     val proc = List.fill(50)(20) // 20ms of processing
     val sched = List.fill(50)(100) // strictly positive accumulation
     val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
     res.head should equal(None)
-    res.tail should equal(List.fill(49)(Some(0D)))
+    res.tail should equal(List.fill(49)(Some(minRate)))
   }
 
   test("with no accumulated or positive error, |I| > 0, follow the processing speed") {
-    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10)
     // prepare a series of batch updates, one every 20ms with an increasing number of processed
     // elements in each batch, but constant processing time, and no accumulated error. Even though
     // the integral part is non-zero, the estimated rate should follow only the proportional term
     val times = List.tabulate(50)(x => x * 20) // every 20ms
-    val elements = List.tabulate(50)(x => x * 20) // increasing
+    val elements = List.tabulate(50)(x => (x + 1) * 20) // increasing
     val proc = List.fill(50)(20) // 20ms of processing
     val sched = List.fill(50)(0)
     val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
     res.head should equal(None)
-    res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail)
+    res.tail should equal(List.tabulate(50)(x => Some((x + 1) * 1000D)).tail)
   }
 
   test("with no accumulated but some positive error, |I| > 0, follow the processing speed") {
-    val p = new PIDRateEstimator(20, 1D, 1D, 0D)
+    val p = new PIDRateEstimator(20, 1D, 1D, 0D, 10)
     // prepare a series of batch updates, one every 20ms with an decreasing number of processed
     // elements in each batch, but constant processing time, and no accumulated error. Even though
     // the integral part is non-zero, the estimated rate should follow only the proportional term,
@@ -116,13 +133,14 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
   }
 
   test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") {
-    val p = new PIDRateEstimator(20, 1D, .01D, 0D)
+    val minRate = 10D
+    val p = new PIDRateEstimator(20, 1D, .01D, 0D, minRate)
     val times = List.tabulate(50)(x => x * 20) // every 20ms
     val rng = new Random()
-    val elements = List.tabulate(50)(x => rng.nextInt(1000))
+    val elements = List.tabulate(50)(x => rng.nextInt(1000) + 1000)
     val procDelayMs = 20
     val proc = List.fill(50)(procDelayMs) // 20ms of processing
-    val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait
+    val sched = List.tabulate(50)(x => rng.nextInt(19) + 1) // random wait
     val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000)
 
     val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i))
@@ -131,7 +149,12 @@ class PIDRateEstimatorSuite extends SparkFunSuite with Matchers {
       res(n) should not be None
       if (res(n).get > 0 && sched(n) > 0) {
         res(n).get should be < speeds(n)
+        res(n).get should be >= minRate
       }
     }
   }
+
+  private def createDefaultEstimator(): PIDRateEstimator = {
+    new PIDRateEstimator(20, 1D, 0D, 0D, 10)
+  }
 }