diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index 5d9a8ac0d92971276760bf771b72d9097484e8ed..dacff69d55dd249cafa09a1dc6fa2d83cd9bcd42 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -193,12 +193,15 @@ private[streaming] class ReceivedBlockTracker(
       getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
     }
 
-    // Insert the recovered block-to-batch allocations and clear the queue of received blocks
-    // (when the blocks were originally allocated to the batch, the queue must have been cleared).
+    // Insert the recovered block-to-batch allocations and removes them from queue of
+    // received blocks.
     def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) {
       logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " +
         s"${allocatedBlocks.streamIdToAllocatedBlocks}")
-      streamIdToUnallocatedBlockQueues.values.foreach { _.clear() }
+      allocatedBlocks.streamIdToAllocatedBlocks.foreach {
+        case (streamId, allocatedBlocksInStream) =>
+          getReceivedBlockQueue(streamId).dequeueAll(allocatedBlocksInStream.toSet)
+      }
       timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
       lastAllocatedBatchTime = batchTime
     }
@@ -227,7 +230,7 @@ private[streaming] class ReceivedBlockTracker(
   }
 
   /** Write an update to the tracker to the write ahead log */
-  private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = {
+  private[streaming] def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = {
     if (isWriteAheadLogEnabled) {
       logTrace(s"Writing record: $record")
       try {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index 107c3f5dcc08dbaceaf7842cdecbe7d3610ecf9e..4fa236bd3966332aec3d3c1265f1b090aa115676 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
 import org.apache.spark.internal.Logging
 import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
-import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.scheduler.{AllocatedBlocks, _}
 import org.apache.spark.streaming.util._
 import org.apache.spark.streaming.util.WriteAheadLogSuite._
 import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
@@ -94,6 +94,27 @@ class ReceivedBlockTrackerSuite
     receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
   }
 
+  test("recovery with write ahead logs should remove only allocated blocks from received queue") {
+    val manualClock = new ManualClock
+    val batchTime = manualClock.getTimeMillis()
+
+    val tracker1 = createTracker(clock = manualClock)
+    tracker1.isWriteAheadLogEnabled should be (true)
+
+    val allocatedBlockInfos = generateBlockInfos()
+    val unallocatedBlockInfos = generateBlockInfos()
+    val receivedBlockInfos = allocatedBlockInfos ++ unallocatedBlockInfos
+    receivedBlockInfos.foreach { b => tracker1.writeToLog(BlockAdditionEvent(b)) }
+    val allocatedBlocks = AllocatedBlocks(Map(streamId -> allocatedBlockInfos))
+    tracker1.writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))
+    tracker1.stop()
+
+    val tracker2 = createTracker(clock = manualClock, recoverFromWriteAheadLog = true)
+    tracker2.getBlocksOfBatch(batchTime) shouldEqual allocatedBlocks.streamIdToAllocatedBlocks
+    tracker2.getUnallocatedBlocks(streamId) shouldEqual unallocatedBlockInfos
+    tracker2.stop()
+  }
+
   test("recovery and cleanup with write ahead logs") {
     val manualClock = new ManualClock
     // Set the time increment level to twice the rotation interval so that every increment creates