diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index e59c24adb84af8a8c183ef93a6bd56c5430fc51d..0e285d6088ec1ad044eb378d9b7ae48b61c5eece 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
     }
   }
 
+  /**
+   * Get the maximum remember duration across all the input streams. This is a conservative but
+   * safe remember duration which can be used to perform cleanup operations.
+   */
+  def getMaxInputStreamRememberDuration(): Duration = {
+    inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds }
+  }
+
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
     logDebug("DStreamGraph.writeObject used")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index afd3c4bc4c4fefc630dffbde79affba3d1a5b1de..8be04314c4285ca79568229ff3bd99d7816ddc28 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
     }
     Some(blockRDD)
   }
-
-  /**
-   * Clear metadata that are older than `rememberDuration` of this DStream.
-   * This is an internal method that should not be called directly. This
-   * implementation overrides the default implementation to clear received
-   * block information.
-   */
-  private[streaming] override def clearMetadata(time: Time) {
-    super.clearMetadata(time)
-    ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration)
-  }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
index ab9fa192191aac34adaf9a7524d1a1327bd92ee2..7bf3c33319491cb9c099bd8a4e89dfef003cf0f4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.streaming.receiver
 
-/** Messages sent to the NetworkReceiver. */
+import org.apache.spark.streaming.Time
+
+/** Messages sent to the Receiver. */
 private[streaming] sealed trait ReceiverMessage extends Serializable
 private[streaming] object StopReceiver extends ReceiverMessage
+private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index d7229c2b96d0b83510c20676800211a894956f99..716cf2c7f32fcb923a5eb8f85776f5a316c6a220 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.{Logging, SparkEnv, SparkException}
 import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.Time
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.{AkkaUtils, Utils}
 
@@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl(
         case StopReceiver =>
           logInfo("Received stop signal")
           stop("Stopped by driver", None)
+        case CleanupOldBlocks(threshTime) =>
+          logDebug("Received delete old batch signal")
+          cleanupOldBlocks(threshTime)
       }
 
       def ref = self
@@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl(
 
   /** Generate new block ID */
   private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement)
+
+  private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = {
+    logDebug(s"Cleaning up blocks older then $cleanupThreshTime")
+    receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds)
+  }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 39b66e1130768e704e10c40f14dac922941a2d88..d86f852aba97ea9835e285c5cb3b1847250ffe00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -238,13 +238,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   /** Clear DStream metadata for the given `time`. */
   private def clearMetadata(time: Time) {
     ssc.graph.clearMetadata(time)
-    jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration)
 
     // If checkpointing is enabled, then checkpoint,
     // else mark batch to be fully processed
     if (shouldCheckpoint) {
       eventActor ! DoCheckpoint(time)
     } else {
+      // If checkpointing is not enabled, then delete metadata information about
+      // received blocks (block data not saved in any case). Otherwise, wait for
+      // checkpointing of this batch to complete.
+      val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+      jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
       markBatchFullyProcessed(time)
     }
   }
@@ -252,6 +256,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   /** Clear DStream checkpoint data for the given `time`. */
   private def clearCheckpointData(time: Time) {
     ssc.graph.clearCheckpointData(time)
+
+    // All the checkpoint information about which batches have been processed, etc have
+    // been saved to checkpoints, so its safe to delete block metadata and data WAL files
+    val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+    jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
     markBatchFullyProcessed(time)
   }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index c3d9d7b6813d325a6f5807812d6f4012e5b2bb53..ef23b5c79f2e17118b78448121ef4fd0e327f2b0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -150,7 +150,6 @@ private[streaming] class ReceivedBlockTracker(
     writeToLog(BatchCleanupEvent(timesToCleanup))
     timeToAllocatedBlocks --= timesToCleanup
     logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion))
-    log
   }
 
   /** Stop the block tracker. */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 8dbb42a86e3bdb359888c6ae222f645095ec2b93..4f998869731ed46ceda0684dfbfb37e5378b7298 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -24,9 +24,8 @@ import scala.language.existentials
 import akka.actor._
 
 import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
-import org.apache.spark.SparkContext._
 import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
+import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver}
 
 /**
  * Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     }
   }
 
-    /** Clean up metadata older than the given threshold time */
-  def cleanupOldMetadata(cleanupThreshTime: Time) {
+  /**
+   * Clean up the data and metadata of blocks and batches that are strictly
+   * older than the threshold time. Note that this does not
+   */
+  def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) {
+    // Clean up old block and batch metadata
     receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false)
+
+    // Signal the receivers to delete old block data
+    if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+      logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
+      receiverInfo.values.flatMap { info => Option(info.actor) }
+        .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) }
+    }
   }
 
   /** Register a receiver */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
index 27a28bab83ed5f89992f6ee9acf1b1179702c143..858ba3c9eb4e5ed67685a0167594dfa8707c9b00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
@@ -63,7 +63,7 @@ private[streaming] object HdfsUtils {
   }
 
   def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = {
-    // For local file systems, return the raw loca file system, such calls to flush()
+    // For local file systems, return the raw local file system, such calls to flush()
     // actually flushes the stream.
     val fs = path.getFileSystem(conf)
     fs match {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index e26c0c6859e570458a53347bb2fabf9445319ff6..e8c34a9ee40b9b3b5580206525565e6d87521336 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -17,21 +17,26 @@
 
 package org.apache.spark.streaming
 
+import java.io.File
 import java.nio.ByteBuffer
 import java.util.concurrent.Semaphore
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.{StorageLevel, StreamBlockId}
-import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor}
-import org.scalatest.FunSuite
+import com.google.common.io.Files
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.receiver._
+import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
+
 /** Testsuite for testing the network receiver behavior */
-class ReceiverSuite extends FunSuite with Timeouts {
+class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
 
   test("receiver life cycle") {
 
@@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts {
     val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3
     val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1
     val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",")
-    println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes)
     assert(
       // the first and last block may be incomplete, so we slice them out
       recordedBlocks.drop(1).dropRight(1).forall { block =>
@@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts {
     )
   }
 
-
   /**
-   * An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
+   * Test whether write ahead logs are generated by received,
+   * and automatically cleaned up. The clean up must be aware of the
+   * remember duration of the input streams. E.g., input streams on which window()
+   * has been applied must remember the data for longer, and hence corresponding
+   * WALs should be cleaned later.
    */
-  class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
-    @volatile var otherThread: Thread = null
-    @volatile var receiving = false
-    @volatile var onStartCalled = false
-    @volatile var onStopCalled = false
-
-    def onStart() {
-      otherThread = new Thread() {
-        override def run() {
-          receiving = true
-          while(!isStopped()) {
-            Thread.sleep(10)
-          }
+  test("write ahead log - generating and cleaning") {
+    val sparkConf = new SparkConf()
+      .setMaster("local[4]")  // must be at least 3 as we are going to start 2 receivers
+      .setAppName(framework)
+      .set("spark.ui.enabled", "true")
+      .set("spark.streaming.receiver.writeAheadLog.enable", "true")
+      .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+    val batchDuration = Milliseconds(500)
+    val tempDirectory = Files.createTempDir()
+    val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0))
+    val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1))
+    val allLogFiles1 = new mutable.HashSet[String]()
+    val allLogFiles2 = new mutable.HashSet[String]()
+    logInfo("Temp checkpoint directory = " + tempDirectory)
+
+    def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = {
+      (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2))
+    }
+
+    def getCurrentLogFiles(logDirectory: File): Seq[String] = {
+      try {
+        if (logDirectory.exists()) {
+          logDirectory1.listFiles().filter { _.getName.startsWith("log") }.map { _.toString }
+        } else {
+          Seq.empty
         }
+      } catch {
+        case e: Exception =>
+          Seq.empty
       }
-      onStartCalled = true
-      otherThread.start()
-
     }
 
-    def onStop() {
-      onStopCalled = true
-      otherThread.join()
+    def printLogFiles(message: String, files: Seq[String]) {
+      logInfo(s"$message (${files.size} files):\n" + files.mkString("\n"))
     }
 
-    def reset() {
-      receiving = false
-      onStartCalled = false
-      onStopCalled = false
+    withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc =>
+      tempDirectory.deleteOnExit()
+      val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+      val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+      val receiverStream1 = ssc.receiverStream(receiver1)
+      val receiverStream2 = ssc.receiverStream(receiver2)
+      receiverStream1.register()
+      receiverStream2.window(batchDuration * 6).register()  // 3 second window
+      ssc.checkpoint(tempDirectory.getAbsolutePath())
+      ssc.start()
+
+      // Run until sufficient WAL files have been generated and
+      // the first WAL files has been deleted
+      eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) {
+        val (logFiles1, logFiles2) = getBothCurrentLogFiles()
+        allLogFiles1 ++= logFiles1
+        allLogFiles2 ++= logFiles2
+        if (allLogFiles1.size > 0) {
+          assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head))
+        }
+        if (allLogFiles2.size > 0) {
+          assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head))
+        }
+        assert(allLogFiles1.size >= 7)
+        assert(allLogFiles2.size >= 7)
+      }
+      ssc.stop(stopSparkContext = true, stopGracefully = true)
+
+      val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted
+      val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted
+      val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles()
+
+      printLogFiles("Receiver 0: all", sortedAllLogFiles1)
+      printLogFiles("Receiver 0: left", leftLogFiles1)
+      printLogFiles("Receiver 1: all", sortedAllLogFiles2)
+      printLogFiles("Receiver 1: left", leftLogFiles2)
+
+      // Verify that necessary latest log files are not deleted
+      //   receiverStream1 needs to retain just the last batch = 1 log file
+      //   receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files
+      assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains))
+      assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains))
     }
   }
 
@@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts {
   }
 }
 
+/**
+ * An implementation of Receiver that is used for testing a receiver's life cycle.
+ */
+class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+  @volatile var otherThread: Thread = null
+  @volatile var receiving = false
+  @volatile var onStartCalled = false
+  @volatile var onStopCalled = false
+
+  def onStart() {
+    otherThread = new Thread() {
+      override def run() {
+        receiving = true
+        var count = 0
+        while(!isStopped()) {
+          if (sendData) {
+            store(count)
+            count += 1
+          }
+          Thread.sleep(10)
+        }
+      }
+    }
+    onStartCalled = true
+    otherThread.start()
+  }
+
+  def onStop() {
+    onStopCalled = true
+    otherThread.join()
+  }
+
+  def reset() {
+    receiving = false
+    onStartCalled = false
+    onStopCalled = false
+  }
+}
+