diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index bb47d373de63d0cda0244d538ec91164079b295d..3e67161363e5084ed5f51a86eef096050d3d16d0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -17,15 +17,14 @@
 
 package org.apache.spark.streaming.dstream
 
-import scala.collection.mutable.HashMap
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.{BlockRDD, RDD}
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.streaming._
-import org.apache.spark.streaming.receiver.{WriteAheadLogBasedStoreResult, BlockManagerBasedStoreResult, Receiver}
+import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD
+import org.apache.spark.streaming.receiver.{Receiver, WriteAheadLogBasedStoreResult}
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
-import org.apache.spark.SparkException
 
 /**
  * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -40,9 +39,6 @@ import org.apache.spark.SparkException
 abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext)
   extends InputDStream[T](ssc_) {
 
-  /** Keeps all received blocks information */
-  private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]]
-
   /** This is an unique identifier for the network input stream. */
   val id = ssc.getNewReceiverStreamId()
 
@@ -58,24 +54,45 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
 
   def stop() {}
 
-  /** Ask ReceiverInputTracker for received data blocks and generates RDDs with them. */
+  /**
+   * Generates RDDs with blocks received by the receiver of this stream. */
   override def compute(validTime: Time): Option[RDD[T]] = {
-    // If this is called for any time before the start time of the context,
-    // then this returns an empty RDD. This may happen when recovering from a
-    // master failure
-    if (validTime >= graph.startTime) {
-      val blockInfo = ssc.scheduler.receiverTracker.getReceivedBlockInfo(id)
-      receivedBlockInfo(validTime) = blockInfo
-      val blockIds = blockInfo.map { _.blockStoreResult.blockId.asInstanceOf[BlockId] }
-      Some(new BlockRDD[T](ssc.sc, blockIds))
-    } else {
-      Some(new BlockRDD[T](ssc.sc, Array.empty))
-    }
-  }
+    val blockRDD = {
 
-  /** Get information on received blocks. */
-  private[streaming] def getReceivedBlockInfo(time: Time) = {
-    receivedBlockInfo.get(time).getOrElse(Array.empty[ReceivedBlockInfo])
+      if (validTime < graph.startTime) {
+        // If this is called for any time before the start time of the context,
+        // then this returns an empty RDD. This may happen when recovering from a
+        // driver failure without any write ahead log to recover pre-failure data.
+        new BlockRDD[T](ssc.sc, Array.empty)
+      } else {
+        // Otherwise, ask the tracker for all the blocks that have been allocated to this stream
+        // for this batch
+        val blockInfos =
+          ssc.scheduler.receiverTracker.getBlocksOfBatch(validTime).get(id).getOrElse(Seq.empty)
+        val blockStoreResults = blockInfos.map { _.blockStoreResult }
+        val blockIds = blockStoreResults.map { _.blockId.asInstanceOf[BlockId] }.toArray
+
+        // Check whether all the results are of the same type
+        val resultTypes = blockStoreResults.map { _.getClass }.distinct
+        if (resultTypes.size > 1) {
+          logWarning("Multiple result types in block information, WAL information will be ignored.")
+        }
+
+        // If all the results are of type WriteAheadLogBasedStoreResult, then create
+        // WriteAheadLogBackedBlockRDD else create simple BlockRDD.
+        if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) {
+          val logSegments = blockStoreResults.map {
+            _.asInstanceOf[WriteAheadLogBasedStoreResult].segment
+          }.toArray
+          // Since storeInBlockManager = false, the storage level does not matter.
+          new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext,
+            blockIds, logSegments, storeInBlockManager = true, StorageLevel.MEMORY_ONLY_SER)
+        } else {
+          new BlockRDD[T](ssc.sc, blockIds)
+        }
+      }
+    }
+    Some(blockRDD)
   }
 
   /**
@@ -86,10 +103,6 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
    */
   private[streaming] override def clearMetadata(time: Time) {
     super.clearMetadata(time)
-    val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration))
-    receivedBlockInfo --= oldReceivedBlocks.keys
-    logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " +
-      (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", "))
+    ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration)
   }
 }
-
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index 23295bf6587126661b196ca144b10fb8334a0ba5..dd1e96334952ff95974d3bf37e732a0e1f49b9f3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -48,7 +48,6 @@ class WriteAheadLogBackedBlockRDDPartition(
  * If it does not find them, it looks up the corresponding file segment.
  *
  * @param sc SparkContext
- * @param hadoopConfig Hadoop configuration
  * @param blockIds Ids of the blocks that contains this RDD's data
  * @param segments Segments in write ahead logs that contain this RDD's data
  * @param storeInBlockManager Whether to store in the block manager after reading from the segment
@@ -58,7 +57,6 @@ class WriteAheadLogBackedBlockRDDPartition(
 private[streaming]
 class WriteAheadLogBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
-    @transient hadoopConfig: Configuration,
     @transient blockIds: Array[BlockId],
     @transient segments: Array[WriteAheadLogFileSegment],
     storeInBlockManager: Boolean,
@@ -71,6 +69,7 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
       s"the same as number of segments (${segments.length}})!")
 
   // Hadoop configuration is not serializable, so broadcast it as a serializable.
+  @transient private val hadoopConfig = sc.hadoopConfiguration
   private val broadcastedHadoopConf = new SerializableWritable(hadoopConfig)
 
   override def getPartitions: Array[Partition] = {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 7d73ada12d107f67ea0e2609aa1331c8654a8933..39b66e1130768e704e10c40f14dac922941a2d88 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -112,7 +112,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       // Wait until all the received blocks in the network input tracker has
       // been consumed by network input DStreams, and jobs have been generated with them
       logInfo("Waiting for all received blocks to be consumed for job generation")
-      while(!hasTimedOut && jobScheduler.receiverTracker.hasMoreReceivedBlockIds) {
+      while(!hasTimedOut && jobScheduler.receiverTracker.hasUnallocatedBlocks) {
         Thread.sleep(pollTime)
       }
       logInfo("Waited for all received blocks to be consumed for job generation")
@@ -217,14 +217,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
   /** Generate jobs and perform checkpoint for the given `time`.  */
   private def generateJobs(time: Time) {
-    Try(graph.generateJobs(time)) match {
+    // Set the SparkEnv in this thread, so that job generation code can access the environment
+    // Example: BlockRDDs are created in this thread, and it needs to access BlockManager
+    // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
+    SparkEnv.set(ssc.env)
+    Try {
+      jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
+      graph.generateJobs(time) // generate jobs using allocated block
+    } match {
       case Success(jobs) =>
-        val receivedBlockInfo = graph.getReceiverInputStreams.map { stream =>
-          val streamId = stream.id
-          val receivedBlockInfo = stream.getReceivedBlockInfo(time)
-          (streamId, receivedBlockInfo)
-        }.toMap
-        jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo))
+        val receivedBlockInfos =
+          jobScheduler.receiverTracker.getBlocksOfBatch(time).mapValues { _.toArray }
+        jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfos))
       case Failure(e) =>
         jobScheduler.reportError("Error generating jobs for time " + time, e)
     }
@@ -234,6 +238,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   /** Clear DStream metadata for the given `time`. */
   private def clearMetadata(time: Time) {
     ssc.graph.clearMetadata(time)
+    jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration)
 
     // If checkpointing is enabled, then checkpoint,
     // else mark batch to be fully processed
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5f5e1909908d54e5542bb9f68a40b8937cf85df7
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.scheduler
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.language.implicitConversions
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{SparkException, Logging, SparkConf}
+import org.apache.spark.streaming.Time
+import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager}
+import org.apache.spark.util.Utils
+
+/** Trait representing any event in the ReceivedBlockTracker that updates its state. */
+private[streaming] sealed trait ReceivedBlockTrackerLogEvent
+
+private[streaming] case class BlockAdditionEvent(receivedBlockInfo: ReceivedBlockInfo)
+  extends ReceivedBlockTrackerLogEvent
+private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: AllocatedBlocks)
+  extends ReceivedBlockTrackerLogEvent
+private[streaming] case class BatchCleanupEvent(times: Seq[Time])
+  extends ReceivedBlockTrackerLogEvent
+
+
+/** Class representing the blocks of all the streams allocated to a batch */
+private[streaming]
+case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) {
+  def getBlocksOfStream(streamId: Int): Seq[ReceivedBlockInfo] = {
+    streamIdToAllocatedBlocks.get(streamId).getOrElse(Seq.empty)
+  }
+}
+
+/**
+ * Class that keep track of all the received blocks, and allocate them to batches
+ * when required. All actions taken by this class can be saved to a write ahead log
+ * (if a checkpoint directory has been provided), so that the state of the tracker
+ * (received blocks and block-to-batch allocations) can be recovered after driver failure.
+ *
+ * Note that when any instance of this class is created with a checkpoint directory,
+ * it will try reading events from logs in the directory.
+ */
+private[streaming] class ReceivedBlockTracker(
+    conf: SparkConf,
+    hadoopConf: Configuration,
+    streamIds: Seq[Int],
+    clock: Clock,
+    checkpointDirOption: Option[String])
+  extends Logging {
+
+  private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo]
+  
+  private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue]
+  private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks]
+
+  private val logManagerRollingIntervalSecs = conf.getInt(
+    "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60)
+  private val logManagerOption = checkpointDirOption.map { checkpointDir =>
+    new WriteAheadLogManager(
+      ReceivedBlockTracker.checkpointDirToLogDir(checkpointDir),
+      hadoopConf,
+      rollingIntervalSecs = logManagerRollingIntervalSecs,
+      callerName = "ReceivedBlockHandlerMaster",
+      clock = clock
+    )
+  }
+
+  private var lastAllocatedBatchTime: Time = null
+
+  // Recover block information from write ahead logs
+  recoverFromWriteAheadLogs()
+
+  /** Add received block. This event will get written to the write ahead log (if enabled). */
+  def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized {
+    try {
+      writeToLog(BlockAdditionEvent(receivedBlockInfo))
+      getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+      logDebug(s"Stream ${receivedBlockInfo.streamId} received " +
+        s"block ${receivedBlockInfo.blockStoreResult.blockId}")
+      true
+    } catch {
+      case e: Exception =>
+        logError(s"Error adding block $receivedBlockInfo", e)
+        false
+    }
+  }
+
+  /**
+   * Allocate all unallocated blocks to the given batch.
+   * This event will get written to the write ahead log (if enabled).
+   */
+  def allocateBlocksToBatch(batchTime: Time): Unit = synchronized {
+    if (lastAllocatedBatchTime == null || batchTime > lastAllocatedBatchTime) {
+      val streamIdToBlocks = streamIds.map { streamId =>
+          (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true))
+      }.toMap
+      val allocatedBlocks = AllocatedBlocks(streamIdToBlocks)
+      writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))
+      timeToAllocatedBlocks(batchTime) = allocatedBlocks
+      lastAllocatedBatchTime = batchTime
+      allocatedBlocks
+    } else {
+      throw new SparkException(s"Unexpected allocation of blocks, " +
+        s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime  ")
+    }
+  }
+
+  /** Get the blocks allocated to the given batch. */
+  def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = synchronized {
+    timeToAllocatedBlocks.get(batchTime).map { _.streamIdToAllocatedBlocks }.getOrElse(Map.empty)
+  }
+
+  /** Get the blocks allocated to the given batch and stream. */
+  def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
+    synchronized {
+      timeToAllocatedBlocks.get(batchTime).map {
+        _.getBlocksOfStream(streamId)
+      }.getOrElse(Seq.empty)
+    }
+  }
+
+  /** Check if any blocks are left to be allocated to batches. */
+  def hasUnallocatedReceivedBlocks: Boolean = synchronized {
+    !streamIdToUnallocatedBlockQueues.values.forall(_.isEmpty)
+  }
+
+  /**
+   * Get blocks that have been added but not yet allocated to any batch. This method
+   * is primarily used for testing.
+   */
+  def getUnallocatedBlocks(streamId: Int): Seq[ReceivedBlockInfo] = synchronized {
+    getReceivedBlockQueue(streamId).toSeq
+  }
+
+  /** Clean up block information of old batches. */
+  def cleanupOldBatches(cleanupThreshTime: Time): Unit = synchronized {
+    assert(cleanupThreshTime.milliseconds < clock.currentTime())
+    val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq
+    logInfo("Deleting batches " + timesToCleanup)
+    writeToLog(BatchCleanupEvent(timesToCleanup))
+    timeToAllocatedBlocks --= timesToCleanup
+    logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds))
+    log
+  }
+
+  /** Stop the block tracker. */
+  def stop() {
+    logManagerOption.foreach { _.stop() }
+  }
+
+  /**
+   * Recover all the tracker actions from the write ahead logs to recover the state (unallocated
+   * and allocated block info) prior to failure.
+   */
+  private def recoverFromWriteAheadLogs(): Unit = synchronized {
+    // Insert the recovered block information
+    def insertAddedBlock(receivedBlockInfo: ReceivedBlockInfo) {
+      logTrace(s"Recovery: Inserting added block $receivedBlockInfo")
+      getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo
+    }
+
+    // Insert the recovered block-to-batch allocations and clear the queue of received blocks
+    // (when the blocks were originally allocated to the batch, the queue must have been cleared).
+    def insertAllocatedBatch(batchTime: Time, allocatedBlocks: AllocatedBlocks) {
+      logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " +
+        s"${allocatedBlocks.streamIdToAllocatedBlocks}")
+      streamIdToUnallocatedBlockQueues.values.foreach { _.clear() }
+      lastAllocatedBatchTime = batchTime
+      timeToAllocatedBlocks.put(batchTime, allocatedBlocks)
+    }
+
+    // Cleanup the batch allocations
+    def cleanupBatches(batchTimes: Seq[Time]) {
+      logTrace(s"Recovery: Cleaning up batches $batchTimes")
+      timeToAllocatedBlocks --= batchTimes
+    }
+
+    logManagerOption.foreach { logManager =>
+      logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}")
+      logManager.readFromLog().foreach { byteBuffer =>
+        logTrace("Recovering record " + byteBuffer)
+        Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match {
+          case BlockAdditionEvent(receivedBlockInfo) =>
+            insertAddedBlock(receivedBlockInfo)
+          case BatchAllocationEvent(time, allocatedBlocks) =>
+            insertAllocatedBatch(time, allocatedBlocks)
+          case BatchCleanupEvent(batchTimes) =>
+            cleanupBatches(batchTimes)
+        }
+      }
+    }
+  }
+
+  /** Write an update to the tracker to the write ahead log */
+  private def writeToLog(record: ReceivedBlockTrackerLogEvent) {
+    logDebug(s"Writing to log $record")
+    logManagerOption.foreach { logManager =>
+        logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record)))
+    }
+  }
+
+  /** Get the queue of received blocks belonging to a particular stream */
+  private def getReceivedBlockQueue(streamId: Int): ReceivedBlockQueue = {
+    streamIdToUnallocatedBlockQueues.getOrElseUpdate(streamId, new ReceivedBlockQueue)
+  }
+}
+
+private[streaming] object ReceivedBlockTracker {
+  def checkpointDirToLogDir(checkpointDir: String): String = {
+    new Path(checkpointDir, "receivedBlockMetadata").toString
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index d696563bcee83fe614957c5f9f3dbb2a381e2e22..1c3984d968d205887015dafb4fed48fe355a2be4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark.streaming.scheduler
 
-import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue}
+
+import scala.collection.mutable.{HashMap, SynchronizedMap}
 import scala.language.existentials
 
 import akka.actor._
-import org.apache.spark.{SerializableWritable, Logging, SparkEnv, SparkException}
+
+import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
 import org.apache.spark.SparkContext._
 import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
-import org.apache.spark.util.AkkaUtils
 
 /**
  * Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -48,23 +49,28 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err
  * This class manages the execution of the receivers of NetworkInputDStreams. Instance of
  * this class must be created after all input streams have been added and StreamingContext.start()
  * has been called because it needs the final set of input streams at the time of instantiation.
+ *
+ * @param skipReceiverLaunch Do not launch the receiver. This is useful for testing.
  */
 private[streaming]
-class ReceiverTracker(ssc: StreamingContext) extends Logging {
+class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false) extends Logging {
 
-  val receiverInputStreams = ssc.graph.getReceiverInputStreams()
-  val receiverInputStreamMap = Map(receiverInputStreams.map(x => (x.id, x)): _*)
-  val receiverExecutor = new ReceiverLauncher()
-  val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
-  val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
-    with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]]
-  val timeout = AkkaUtils.askTimeout(ssc.conf)
-  val listenerBus = ssc.scheduler.listenerBus
+  private val receiverInputStreams = ssc.graph.getReceiverInputStreams()
+  private val receiverInputStreamIds = receiverInputStreams.map { _.id }
+  private val receiverExecutor = new ReceiverLauncher()
+  private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo]
+  private val receivedBlockTracker = new ReceivedBlockTracker(
+    ssc.sparkContext.conf,
+    ssc.sparkContext.hadoopConfiguration,
+    receiverInputStreamIds,
+    ssc.scheduler.clock,
+    Option(ssc.checkpointDir)
+  )
+  private val listenerBus = ssc.scheduler.listenerBus
 
   // actor is created when generator starts.
   // This not being null means the tracker has been started and not stopped
-  var actor: ActorRef = null
-  var currentTime: Time = null
+  private var actor: ActorRef = null
 
   /** Start the actor and receiver execution thread. */
   def start() = synchronized {
@@ -75,7 +81,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
     if (!receiverInputStreams.isEmpty) {
       actor = ssc.env.actorSystem.actorOf(Props(new ReceiverTrackerActor),
         "ReceiverTracker")
-      receiverExecutor.start()
+      if (!skipReceiverLaunch) receiverExecutor.start()
       logInfo("ReceiverTracker started")
     }
   }
@@ -84,45 +90,59 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
   def stop() = synchronized {
     if (!receiverInputStreams.isEmpty && actor != null) {
       // First, stop the receivers
-      receiverExecutor.stop()
+      if (!skipReceiverLaunch) receiverExecutor.stop()
 
       // Finally, stop the actor
       ssc.env.actorSystem.stop(actor)
       actor = null
+      receivedBlockTracker.stop()
       logInfo("ReceiverTracker stopped")
     }
   }
 
-  /** Return all the blocks received from a receiver. */
-  def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = {
-    val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true)
-    logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks")
-    receivedBlockInfo.toArray
+  /** Allocate all unallocated blocks to the given batch. */
+  def allocateBlocksToBatch(batchTime: Time): Unit = {
+    if (receiverInputStreams.nonEmpty) {
+      receivedBlockTracker.allocateBlocksToBatch(batchTime)
+    }
+  }
+
+  /** Get the blocks for the given batch and all input streams. */
+  def getBlocksOfBatch(batchTime: Time): Map[Int, Seq[ReceivedBlockInfo]] = {
+    receivedBlockTracker.getBlocksOfBatch(batchTime)
   }
 
-  private def getReceivedBlockInfoQueue(streamId: Int) = {
-    receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo])
+  /** Get the blocks allocated to the given batch and stream. */
+  def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = {
+    synchronized {
+      receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId)
+    }
+  }
+
+    /** Clean up metadata older than the given threshold time */
+  def cleanupOldMetadata(cleanupThreshTime: Time) {
+    receivedBlockTracker.cleanupOldBatches(cleanupThreshTime)
   }
 
   /** Register a receiver */
-  def registerReceiver(
+  private def registerReceiver(
       streamId: Int,
       typ: String,
       host: String,
       receiverActor: ActorRef,
       sender: ActorRef
     ) {
-    if (!receiverInputStreamMap.contains(streamId)) {
-      throw new Exception("Register received for unexpected id " + streamId)
+    if (!receiverInputStreamIds.contains(streamId)) {
+      throw new SparkException("Register received for unexpected id " + streamId)
     }
     receiverInfo(streamId) = ReceiverInfo(
       streamId, s"${typ}-${streamId}", receiverActor, true, host)
-    ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
+    listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
     logInfo("Registered receiver for stream " + streamId + " from " + sender.path.address)
   }
 
   /** Deregister a receiver */
-  def deregisterReceiver(streamId: Int, message: String, error: String) {
+  private def deregisterReceiver(streamId: Int, message: String, error: String) {
     val newReceiverInfo = receiverInfo.get(streamId) match {
       case Some(oldInfo) =>
         oldInfo.copy(actor = null, active = false, lastErrorMessage = message, lastError = error)
@@ -131,7 +151,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
         ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error)
     }
     receiverInfo(streamId) = newReceiverInfo
-    ssc.scheduler.listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId)))
+    listenerBus.post(StreamingListenerReceiverStopped(receiverInfo(streamId)))
     val messageWithError = if (error != null && !error.isEmpty) {
       s"$message - $error"
     } else {
@@ -141,14 +161,12 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
   }
 
   /** Add new blocks for the given stream */
-  def addBlocks(receivedBlockInfo: ReceivedBlockInfo) {
-    getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo
-    logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " +
-      receivedBlockInfo.blockStoreResult.blockId)
+  private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = {
+    receivedBlockTracker.addBlock(receivedBlockInfo)
   }
 
   /** Report error sent by a receiver */
-  def reportError(streamId: Int, message: String, error: String) {
+  private def reportError(streamId: Int, message: String, error: String) {
     val newReceiverInfo = receiverInfo.get(streamId) match {
       case Some(oldInfo) =>
         oldInfo.copy(lastErrorMessage = message, lastError = error)
@@ -157,7 +175,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
         ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, lastError = error)
     }
     receiverInfo(streamId) = newReceiverInfo
-    ssc.scheduler.listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
+    listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId)))
     val messageWithError = if (error != null && !error.isEmpty) {
       s"$message - $error"
     } else {
@@ -167,8 +185,8 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
   }
 
   /** Check if any blocks are left to be processed */
-  def hasMoreReceivedBlockIds: Boolean = {
-    !receivedBlockInfo.values.forall(_.isEmpty)
+  def hasUnallocatedBlocks: Boolean = {
+    receivedBlockTracker.hasUnallocatedReceivedBlocks
   }
 
   /** Actor to receive messages from the receivers. */
@@ -178,8 +196,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
         registerReceiver(streamId, typ, host, receiverActor, sender)
         sender ! true
       case AddBlock(receivedBlockInfo) =>
-        addBlocks(receivedBlockInfo)
-        sender ! true
+        sender ! addBlock(receivedBlockInfo)
       case ReportError(streamId, message, error) =>
         reportError(streamId, message, error)
       case DeregisterReceiver(streamId, message, error) =>
@@ -194,6 +211,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
     @transient val thread  = new Thread() {
       override def run() {
         try {
+          SparkEnv.set(env)
           startReceivers()
         } catch {
           case ie: InterruptedException => logInfo("ReceiverLauncher interrupted")
@@ -267,7 +285,7 @@ class ReceiverTracker(ssc: StreamingContext) extends Logging {
 
       // Distribute the receivers and start them
       logInfo("Starting " + receivers.length + " receivers")
-      ssc.sparkContext.runJob(tempRDD, startReceiver)
+      ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
       logInfo("All of the receivers have been terminated")
     }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index 6c8bb501453677317a042058142e1b4cae758b5c..dbab685dc351145d8f2a653395f6dd6408a50681 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -17,18 +17,19 @@
 
 package org.apache.spark.streaming
 
-import org.apache.spark.streaming.StreamingContext._
-
-import org.apache.spark.rdd.{BlockRDD, RDD}
-import org.apache.spark.SparkContext._
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.language.existentials
+import scala.reflect.ClassTag
 
 import util.ManualClock
-import org.apache.spark.{SparkException, SparkConf}
-import org.apache.spark.streaming.dstream.{WindowedDStream, DStream}
-import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
-import scala.reflect.ClassTag
+
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
-import scala.collection.mutable
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
 
 class BasicOperationsSuite extends TestSuiteBase {
   test("map") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..fd9c97f551c62f6f8b10700e40156fc9361d68ec
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -0,0 +1,242 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.{implicitConversions, postfixOps}
+import scala.util.Random
+
+import com.google.common.io.Files
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader}
+import org.apache.spark.streaming.util.WriteAheadLogSuite._
+import org.apache.spark.util.Utils
+
+class ReceivedBlockTrackerSuite
+  extends FunSuite with BeforeAndAfter with Matchers with Logging {
+
+  val conf = new SparkConf().setMaster("local[2]").setAppName("ReceivedBlockTrackerSuite")
+  conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1")
+
+  val hadoopConf = new Configuration()
+  val akkaTimeout = 10 seconds
+  val streamId = 1
+
+  var allReceivedBlockTrackers = new ArrayBuffer[ReceivedBlockTracker]()
+  var checkpointDirectory: File = null
+
+  before {
+    checkpointDirectory = Files.createTempDir()
+  }
+
+  after {
+    allReceivedBlockTrackers.foreach { _.stop() }
+    if (checkpointDirectory != null && checkpointDirectory.exists()) {
+      FileUtils.deleteDirectory(checkpointDirectory)
+      checkpointDirectory = null
+    }
+  }
+
+  test("block addition, and block to batch allocation") {
+    val receivedBlockTracker = createTracker(enableCheckpoint = false)
+    receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
+
+    val blockInfos = generateBlockInfos()
+    blockInfos.map(receivedBlockTracker.addBlock)
+
+    // Verify added blocks are unallocated blocks
+    receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
+
+    // Allocate the blocks to a batch and verify that all of them have been allocated
+    receivedBlockTracker.allocateBlocksToBatch(1)
+    receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos
+    receivedBlockTracker.getUnallocatedBlocks(streamId) shouldBe empty
+
+    // Allocate no blocks to another batch
+    receivedBlockTracker.allocateBlocksToBatch(2)
+    receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty
+
+    // Verify that batch 2 cannot be allocated again
+    intercept[SparkException] {
+      receivedBlockTracker.allocateBlocksToBatch(2)
+    }
+
+    // Verify that older batches cannot be allocated again
+    intercept[SparkException] {
+      receivedBlockTracker.allocateBlocksToBatch(1)
+    }
+  }
+
+  test("block addition, block to batch allocation and cleanup with write ahead log") {
+    val manualClock = new ManualClock
+    conf.getInt(
+      "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", -1) should be (1)
+
+    // Set the time increment level to twice the rotation interval so that every increment creates
+    // a new log file
+    val timeIncrementMillis = 2000L
+    def incrementTime() {
+      manualClock.addToTime(timeIncrementMillis)
+    }
+
+    // Generate and add blocks to the given tracker
+    def addBlockInfos(tracker: ReceivedBlockTracker): Seq[ReceivedBlockInfo] = {
+      val blockInfos = generateBlockInfos()
+      blockInfos.map(tracker.addBlock)
+      blockInfos
+    }
+
+    // Print the data present in the log ahead files in the log directory
+    def printLogFiles(message: String) {
+      val fileContents = getWriteAheadLogFiles().map { file =>
+        (s"\n>>>>> $file: <<<<<\n${getWrittenLogData(file).mkString("\n")}")
+      }.mkString("\n")
+      logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n")
+    }
+
+    // Start tracker and add blocks
+    val tracker1 = createTracker(enableCheckpoint = true, clock = manualClock)
+    val blockInfos1 = addBlockInfos(tracker1)
+    tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
+
+    // Verify whether write ahead log has correct contents
+    val expectedWrittenData1 = blockInfos1.map(BlockAdditionEvent)
+    getWrittenLogData() shouldEqual expectedWrittenData1
+    getWriteAheadLogFiles() should have size 1
+
+    // Restart tracker and verify recovered list of unallocated blocks
+    incrementTime()
+    val tracker2 = createTracker(enableCheckpoint = true, clock = manualClock)
+    tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
+
+    // Allocate blocks to batch and verify whether the unallocated blocks got allocated
+    val batchTime1 = manualClock.currentTime
+    tracker2.allocateBlocksToBatch(batchTime1)
+    tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1
+
+    // Add more blocks and allocate to another batch
+    incrementTime()
+    val batchTime2 = manualClock.currentTime
+    val blockInfos2 = addBlockInfos(tracker2)
+    tracker2.allocateBlocksToBatch(batchTime2)
+    tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
+
+    // Verify whether log has correct contents
+    val expectedWrittenData2 = expectedWrittenData1 ++
+      Seq(createBatchAllocation(batchTime1, blockInfos1)) ++
+      blockInfos2.map(BlockAdditionEvent) ++
+      Seq(createBatchAllocation(batchTime2, blockInfos2))
+    getWrittenLogData() shouldEqual expectedWrittenData2
+
+    // Restart tracker and verify recovered state
+    incrementTime()
+    val tracker3 = createTracker(enableCheckpoint = true, clock = manualClock)
+    tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1
+    tracker3.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
+    tracker3.getUnallocatedBlocks(streamId) shouldBe empty
+
+    // Cleanup first batch but not second batch
+    val oldestLogFile = getWriteAheadLogFiles().head
+    incrementTime()
+    tracker3.cleanupOldBatches(batchTime2)
+
+    // Verify that the batch allocations have been cleaned, and the act has been written to log
+    tracker3.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual Seq.empty
+    getWrittenLogData(getWriteAheadLogFiles().last) should contain(createBatchCleanup(batchTime1))
+
+    // Verify that at least one log file gets deleted
+    eventually(timeout(10 seconds), interval(10 millisecond)) {
+      getWriteAheadLogFiles() should not contain oldestLogFile
+    }
+    printLogFiles("After cleanup")
+
+    // Restart tracker and verify recovered state, specifically whether info about the first
+    // batch has been removed, but not the second batch
+    incrementTime()
+    val tracker4 = createTracker(enableCheckpoint = true, clock = manualClock)
+    tracker4.getUnallocatedBlocks(streamId) shouldBe empty
+    tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty  // should be cleaned
+    tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
+  }
+
+  /**
+   * Create tracker object with the optional provided clock. Use fake clock if you
+   * want to control time by manually incrementing it to test log cleanup.
+   */
+  def createTracker(enableCheckpoint: Boolean, clock: Clock = new SystemClock): ReceivedBlockTracker = {
+    val cpDirOption = if (enableCheckpoint) Some(checkpointDirectory.toString) else None
+    val tracker = new ReceivedBlockTracker(conf, hadoopConf, Seq(streamId), clock, cpDirOption)
+    allReceivedBlockTrackers += tracker
+    tracker
+  }
+
+  /** Generate blocks infos using random ids */
+  def generateBlockInfos(): Seq[ReceivedBlockInfo] = {
+    List.fill(5)(ReceivedBlockInfo(streamId, 0,
+      BlockManagerBasedStoreResult(StreamBlockId(streamId, math.abs(Random.nextInt)))))
+  }
+
+  /** Get all the data written in the given write ahead log file. */
+  def getWrittenLogData(logFile: String): Seq[ReceivedBlockTrackerLogEvent] = {
+    getWrittenLogData(Seq(logFile))
+  }
+
+  /**
+   * Get all the data written in the given write ahead log files. By default, it will read all
+   * files in the test log directory.
+   */
+  def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles): Seq[ReceivedBlockTrackerLogEvent] = {
+    logFiles.flatMap {
+      file => new WriteAheadLogReader(file, hadoopConf).toSeq
+    }.map { byteBuffer =>
+      Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array)
+    }.toList
+  }
+
+  /** Get all the write ahead log files in the test directory */
+  def getWriteAheadLogFiles(): Seq[String] = {
+    import ReceivedBlockTracker._
+    val logDir = checkpointDirToLogDir(checkpointDirectory.toString)
+    getLogFilesInDirectory(logDir).map { _.toString }
+  }
+
+  /** Create batch allocation object from the given info */
+  def createBatchAllocation(time: Long, blockInfos: Seq[ReceivedBlockInfo]): BatchAllocationEvent = {
+    BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos))))
+  }
+
+  /** Create batch cleanup object from the given info */
+  def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = {
+    BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply))
+  }
+
+  implicit def millisToTime(milliseconds: Long): Time = Time(milliseconds)
+
+  implicit def timeToMillis(time: Time): Long = time.milliseconds
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index 10160244bcc91518b3142ba34df0dbc6564ad059..d2b983c4b4d1a3a48ca217cb6eca34152bfa6d31 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -117,12 +117,12 @@ class WriteAheadLogBackedBlockRDDSuite extends FunSuite with BeforeAndAfterAll {
     )
 
     // Create the RDD and verify whether the returned data is correct
-    val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray,
+    val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray,
       segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY)
     assert(rdd.collect() === data.flatten)
 
     if (testStoreInBM) {
-      val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, hadoopConf, blockIds.toArray,
+      val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray,
         segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY)
       assert(rdd2.collect() === data.flatten)
       assert(