diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
index e4f6ba626ebbf3ce0c39e2a72d13aa52b4a07b29..97db9ded83367377be72d5cdddd99c3d5220070d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.streaming.receiver
 
 import org.apache.spark.{Logging, SparkConf}
-import java.util.concurrent.TimeUnit._
+import com.google.common.util.concurrent.{RateLimiter=>GuavaRateLimiter}
 
 /** Provides waitToPush() method to limit the rate at which receivers consume data.
   *
@@ -33,37 +33,12 @@ import java.util.concurrent.TimeUnit._
   */
 private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging {
 
-  private var lastSyncTime = System.nanoTime
-  private var messagesWrittenSinceSync = 0L
   private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0)
-  private val SYNC_INTERVAL = NANOSECONDS.convert(10, SECONDS)
+  private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate)
 
   def waitToPush() {
-    if( desiredRate <= 0 ) {
-      return
-    }
-    val now = System.nanoTime
-    val elapsedNanosecs = math.max(now - lastSyncTime, 1)
-    val rate = messagesWrittenSinceSync.toDouble * 1000000000 / elapsedNanosecs
-    if (rate < desiredRate) {
-      // It's okay to write; just update some variables and return
-      messagesWrittenSinceSync += 1
-      if (now > lastSyncTime + SYNC_INTERVAL) {
-        // Sync interval has passed; let's resync
-        lastSyncTime = now
-        messagesWrittenSinceSync = 1
-      }
-    } else {
-      // Calculate how much time we should sleep to bring ourselves to the desired rate.
-      val targetTimeInMillis = messagesWrittenSinceSync * 1000 / desiredRate
-      val elapsedTimeInMillis = elapsedNanosecs / 1000000
-      val sleepTimeInMillis = targetTimeInMillis - elapsedTimeInMillis
-      if (sleepTimeInMillis > 0) {
-        logTrace("Natural rate is " + rate + " per second but desired rate is " +
-          desiredRate + ", sleeping for " + sleepTimeInMillis + " ms to compensate.")
-        Thread.sleep(sleepTimeInMillis)
-      }
-      waitToPush()
+    if (desiredRate > 0) {
+      rateLimiter.acquire()
     }
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index 91261a9db736070af6ac7466d905299e5e24fc80..e7aee6eadbfc7e1af8b5735e21f05376132b90b7 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -158,7 +158,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
   test("block generator throttling") {
     val blockGeneratorListener = new FakeBlockGeneratorListener
     val blockIntervalMs = 100
-    val maxRate = 100
+    val maxRate = 1001
     val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms").
       set("spark.streaming.receiver.maxRate", maxRate.toString)
     val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
@@ -176,7 +176,6 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
       blockGenerator.addData(count)
       generatedData += count
       count += 1
-      Thread.sleep(1)
     }
     blockGenerator.stop()
 
@@ -185,25 +184,31 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
     assert(blockGeneratorListener.arrayBuffers.size > 0, "No blocks received")
     assert(recordedData.toSet === generatedData.toSet, "Received data not same")
 
-    // recordedData size should be close to the expected rate
-    val minExpectedMessages = expectedMessages - 3
-    val maxExpectedMessages = expectedMessages + 1
+    // recordedData size should be close to the expected rate; use an error margin proportional to
+    // the value, so that rate changes don't cause a brittle test
+    val minExpectedMessages = expectedMessages - 0.05 * expectedMessages
+    val maxExpectedMessages = expectedMessages + 0.05 * expectedMessages
     val numMessages = recordedData.size
     assert(
       numMessages >= minExpectedMessages && numMessages <= maxExpectedMessages,
       s"#records received = $numMessages, not between $minExpectedMessages and $maxExpectedMessages"
     )
 
-    val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3
-    val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1
+    // XXX Checking every block would require an even distribution of messages across blocks,
+    // which throttling code does not control. Therefore, test against the average.
+    val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 0.05 * expectedMessagesPerBlock
+    val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 0.05 * expectedMessagesPerBlock
     val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",")
+
+    // the first and last block may be incomplete, so we slice them out
+    val validBlocks = recordedBlocks.drop(1).dropRight(1)
+    val averageBlockSize = validBlocks.map(block => block.size).sum / validBlocks.size
+
     assert(
-      // the first and last block may be incomplete, so we slice them out
-      recordedBlocks.drop(1).dropRight(1).forall { block =>
-        block.size >= minExpectedMessagesPerBlock && block.size <= maxExpectedMessagesPerBlock
-      },
+      averageBlockSize >= minExpectedMessagesPerBlock &&
+        averageBlockSize <= maxExpectedMessagesPerBlock,
       s"# records in received blocks = [$receivedBlockSizes], not between " +
-        s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock"
+        s"$minExpectedMessagesPerBlock and $maxExpectedMessagesPerBlock, on average"
     )
   }