From 1868bd40dcce23990b98748b0239bd00452b1ca5 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Wed, 29 Apr 2015 13:06:11 -0700
Subject: [PATCH] [SPARK-7056] [STREAMING] Make the Write Ahead Log pluggable

Users may want the WAL data to be written to non-HDFS data storage systems. To allow that, we have to make the WAL pluggable. The following design doc outlines the plan.

https://docs.google.com/a/databricks.com/document/d/1A2XaOLRFzvIZSi18i_luNw5Rmm9j2j4AigktXxIYxmY/edit?usp=sharing

Things to add.
* Unit tests for WriteAheadLogUtils

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #5645 from tdas/wal-pluggable and squashes the following commits:

2c431fd [Tathagata Das] Minor fixes.
c2bc7384 [Tathagata Das] More changes based on PR comments.
569a416 [Tathagata Das] fixed long line
bde26b1 [Tathagata Das] Renamed segment to record handle everywhere
b65e155 [Tathagata Das] More changes based on PR comments.
d7cd15b [Tathagata Das] Fixed test
1a32a4b [Tathagata Das] Fixed test
e0d19fb [Tathagata Das] Fixed defaults
9310cbf [Tathagata Das] style fix.
86abcb1 [Tathagata Das] Refactored WriteAheadLogUtils, and consolidated all WAL related configuration into it.
84ce469 [Tathagata Das] Added unit test and fixed compilation error.
bce5e75 [Tathagata Das] Fixed long lines.
837c4f5 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into wal-pluggable
754fbf8 [Tathagata Das] Added license and docs.
09bc6fe [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into wal-pluggable
7dd2d4b [Tathagata Das] Added pluggable WriteAheadLog interface, and refactored all code along with it
---
 .../spark/streaming/kafka/KafkaUtils.scala    |   3 +-
 .../spark/streaming/util/WriteAheadLog.java   |  60 ++++++
 .../util/WriteAheadLogRecordHandle.java       |  30 +++
 .../dstream/ReceiverInputDStream.scala        |   2 +-
 .../rdd/WriteAheadLogBackedBlockRDD.scala     |  79 +++++--
 .../receiver/ReceivedBlockHandler.scala       |  38 ++--
 .../receiver/ReceiverSupervisorImpl.scala     |   5 +-
 .../scheduler/ReceivedBlockTracker.scala      |  38 ++--
 .../streaming/scheduler/ReceiverTracker.scala |   3 +-
 ...ger.scala => FileBasedWriteAheadLog.scala} |  76 ++++---
 ... FileBasedWriteAheadLogRandomReader.scala} |   8 +-
 ...ala => FileBasedWriteAheadLogReader.scala} |   4 +-
 ...la => FileBasedWriteAheadLogSegment.scala} |   3 +-
 ...ala => FileBasedWriteAheadLogWriter.scala} |   9 +-
 .../streaming/util/WriteAheadLogUtils.scala   | 129 ++++++++++++
 .../streaming/JavaWriteAheadLogSuite.java     | 129 ++++++++++++
 .../streaming/ReceivedBlockHandlerSuite.scala |  18 +-
 .../streaming/ReceivedBlockTrackerSuite.scala |  28 +--
 .../spark/streaming/ReceiverSuite.scala       |   2 +-
 .../streaming/StreamingContextSuite.scala     |   4 +-
 .../WriteAheadLogBackedBlockRDDSuite.scala    |  31 +--
 .../streaming/util/WriteAheadLogSuite.scala   | 194 ++++++++++++------
 22 files changed, 686 insertions(+), 207 deletions(-)
 create mode 100644 streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java
 create mode 100644 streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java
 rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogManager.scala => FileBasedWriteAheadLog.scala} (79%)
 rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogRandomReader.scala => FileBasedWriteAheadLogRandomReader.scala} (83%)
 rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogReader.scala => FileBasedWriteAheadLogReader.scala} (93%)
 rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogFileSegment.scala => FileBasedWriteAheadLogSegment.scala} (86%)
 rename streaming/src/main/scala/org/apache/spark/streaming/util/{WriteAheadLogWriter.scala => FileBasedWriteAheadLogWriter.scala} (88%)
 create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
 create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java

diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 0721ddaf70..d7cf500577 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -31,6 +31,7 @@ import kafka.message.MessageAndMetadata
 import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
 
 import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.streaming.util.WriteAheadLogUtils
 import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -80,7 +81,7 @@ object KafkaUtils {
       topics: Map[String, Int],
       storageLevel: StorageLevel
     ): ReceiverInputDStream[(K, V)] = {
-    val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)
+    val walEnabled = WriteAheadLogUtils.enableReceiverLog(ssc.conf)
     new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel)
   }
 
diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java
new file mode 100644
index 0000000000..8c0fdfa9c7
--- /dev/null
+++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+
+/**
+ * This abstract class represents a write ahead log (aka journal) that is used by Spark Streaming
+ * to save the received data (by receivers) and associated metadata to a reliable storage, so that
+ * they can be recovered after driver failures. See the Spark documentation for more information
+ * on how to plug in your own custom implementation of a write ahead log.
+ */
+@org.apache.spark.annotation.DeveloperApi
+public abstract class WriteAheadLog {
+  /**
+   * Write the record to the log and return a record handle, which contains all the information
+   * necessary to read back the written record. The time is used to the index the record,
+   * such that it can be cleaned later. Note that implementations of this abstract class must
+   * ensure that the written data is durable and readable (using the record handle) by the
+   * time this function returns.
+   */
+  abstract public WriteAheadLogRecordHandle write(ByteBuffer record, long time);
+
+  /**
+   * Read a written record based on the given record handle.
+   */
+  abstract public ByteBuffer read(WriteAheadLogRecordHandle handle);
+
+  /**
+   * Read and return an iterator of all the records that have been written but not yet cleaned up.
+   */
+  abstract public Iterator<ByteBuffer> readAll();
+
+  /**
+   * Clean all the records that are older than the threshold time. It can wait for
+   * the completion of the deletion.
+   */
+  abstract public void clean(long threshTime, boolean waitForCompletion);
+
+  /**
+   * Close this log and release any resources.
+   */
+  abstract public void close();
+}
diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java
new file mode 100644
index 0000000000..02324189b7
--- /dev/null
+++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLogRecordHandle.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util;
+
+/**
+ * This abstract class represents a handle that refers to a record written in a
+ * {@link org.apache.spark.streaming.util.WriteAheadLog WriteAheadLog}.
+ * It must contain all the information necessary for the record to be read and returned by
+ * an implemenation of the WriteAheadLog class.
+ *
+ * @see org.apache.spark.streaming.util.WriteAheadLog
+ */
+@org.apache.spark.annotation.DeveloperApi
+public abstract class WriteAheadLogRecordHandle implements java.io.Serializable {
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index 8be04314c4..4c7fd2c57c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -82,7 +82,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
         // WriteAheadLogBackedBlockRDD else create simple BlockRDD.
         if (resultTypes.size == 1 && resultTypes.head == classOf[WriteAheadLogBasedStoreResult]) {
           val logSegments = blockStoreResults.map {
-            _.asInstanceOf[WriteAheadLogBasedStoreResult].segment
+            _.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle
           }.toArray
           // Since storeInBlockManager = false, the storage level does not matter.
           new WriteAheadLogBackedBlockRDD[T](ssc.sparkContext,
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
index 93caa4ba35..ebdf418f4a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala
@@ -16,14 +16,17 @@
  */
 package org.apache.spark.streaming.rdd
 
+import java.nio.ByteBuffer
+
 import scala.reflect.ClassTag
+import scala.util.control.NonFatal
 
-import org.apache.hadoop.conf.Configuration
+import org.apache.commons.io.FileUtils
 
 import org.apache.spark._
 import org.apache.spark.rdd.BlockRDD
 import org.apache.spark.storage.{BlockId, StorageLevel}
-import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, WriteAheadLogRandomReader}
+import org.apache.spark.streaming.util._
 
 /**
  * Partition class for [[org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD]].
@@ -31,26 +34,27 @@ import org.apache.spark.streaming.util.{HdfsUtils, WriteAheadLogFileSegment, Wri
  * the segment of the write ahead log that backs the partition.
  * @param index index of the partition
  * @param blockId id of the block having the partition data
- * @param segment segment of the write ahead log having the partition data
+ * @param walRecordHandle Handle of the record in a write ahead log having the partition data
  */
 private[streaming]
 class WriteAheadLogBackedBlockRDDPartition(
     val index: Int,
     val blockId: BlockId,
-    val segment: WriteAheadLogFileSegment)
+    val walRecordHandle: WriteAheadLogRecordHandle)
   extends Partition
 
 
 /**
  * This class represents a special case of the BlockRDD where the data blocks in
- * the block manager are also backed by segments in write ahead logs. For reading
+ * the block manager are also backed by data in write ahead logs. For reading
  * the data, this RDD first looks up the blocks by their ids in the block manager.
- * If it does not find them, it looks up the corresponding file segment.
+ * If it does not find them, it looks up the corresponding data in the write ahead log.
  *
  * @param sc SparkContext
  * @param blockIds Ids of the blocks that contains this RDD's data
- * @param segments Segments in write ahead logs that contain this RDD's data
- * @param storeInBlockManager Whether to store in the block manager after reading from the segment
+ * @param walRecordHandles Record handles in write ahead logs that contain this RDD's data
+ * @param storeInBlockManager Whether to store in the block manager after reading
+ *                            from the WAL record
  * @param storageLevel storage level to store when storing in block manager
  *                     (applicable when storeInBlockManager = true)
  */
@@ -58,15 +62,15 @@ private[streaming]
 class WriteAheadLogBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
     @transient blockIds: Array[BlockId],
-    @transient segments: Array[WriteAheadLogFileSegment],
+    @transient walRecordHandles: Array[WriteAheadLogRecordHandle],
     storeInBlockManager: Boolean,
     storageLevel: StorageLevel)
   extends BlockRDD[T](sc, blockIds) {
 
   require(
-    blockIds.length == segments.length,
+    blockIds.length == walRecordHandles.length,
     s"Number of block ids (${blockIds.length}) must be " +
-      s"the same as number of segments (${segments.length}})!")
+      s"the same as number of WAL record handles (${walRecordHandles.length}})!")
 
   // Hadoop configuration is not serializable, so broadcast it as a serializable.
   @transient private val hadoopConfig = sc.hadoopConfiguration
@@ -75,13 +79,13 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
   override def getPartitions: Array[Partition] = {
     assertValid()
     Array.tabulate(blockIds.size) { i =>
-      new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), segments(i))
+      new WriteAheadLogBackedBlockRDDPartition(i, blockIds(i), walRecordHandles(i))
     }
   }
 
   /**
    * Gets the partition data by getting the corresponding block from the block manager.
-   * If the block does not exist, then the data is read from the corresponding segment
+   * If the block does not exist, then the data is read from the corresponding record
    * in write ahead log files.
    */
   override def compute(split: Partition, context: TaskContext): Iterator[T] = {
@@ -96,10 +100,35 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
         logDebug(s"Read partition data of $this from block manager, block $blockId")
         iterator
       case None => // Data not found in Block Manager, grab it from write ahead log file
-        val reader = new WriteAheadLogRandomReader(partition.segment.path, hadoopConf)
-        val dataRead = reader.read(partition.segment)
-        reader.close()
-        logInfo(s"Read partition data of $this from write ahead log, segment ${partition.segment}")
+        var dataRead: ByteBuffer = null
+        var writeAheadLog: WriteAheadLog = null
+        try {
+          // The WriteAheadLogUtils.createLog*** method needs a directory to create a
+          // WriteAheadLog object as the default FileBasedWriteAheadLog needs a directory for
+          // writing log data. However, the directory is not needed if data needs to be read, hence
+          // a dummy path is provided to satisfy the method parameter requirements.
+          // FileBasedWriteAheadLog will not create any file or directory at that path.
+          val dummyDirectory = FileUtils.getTempDirectoryPath()
+          writeAheadLog = WriteAheadLogUtils.createLogForReceiver(
+            SparkEnv.get.conf, dummyDirectory, hadoopConf)
+          dataRead = writeAheadLog.read(partition.walRecordHandle)
+        } catch {
+          case NonFatal(e) =>
+            throw new SparkException(
+              s"Could not read data from write ahead log record ${partition.walRecordHandle}", e)
+        } finally {
+          if (writeAheadLog != null) {
+            writeAheadLog.close()
+            writeAheadLog = null
+          }
+        }
+        if (dataRead == null) {
+          throw new SparkException(
+            s"Could not read data from write ahead log record ${partition.walRecordHandle}, " +
+              s"read returned null")
+        }
+        logInfo(s"Read partition data of $this from write ahead log, record handle " +
+          partition.walRecordHandle)
         if (storeInBlockManager) {
           blockManager.putBytes(blockId, dataRead, storageLevel)
           logDebug(s"Stored partition data of $this into block manager with level $storageLevel")
@@ -111,14 +140,20 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag](
 
   /**
    * Get the preferred location of the partition. This returns the locations of the block
-   * if it is present in the block manager, else it returns the location of the
-   * corresponding segment in HDFS.
+   * if it is present in the block manager, else if FileBasedWriteAheadLogSegment is used,
+   * it returns the location of the corresponding file segment in HDFS .
    */
   override def getPreferredLocations(split: Partition): Seq[String] = {
     val partition = split.asInstanceOf[WriteAheadLogBackedBlockRDDPartition]
     val blockLocations = getBlockIdLocations().get(partition.blockId)
-    blockLocations.getOrElse(
-      HdfsUtils.getFileSegmentLocations(
-        partition.segment.path, partition.segment.offset, partition.segment.length, hadoopConfig))
+    blockLocations.getOrElse {
+      partition.walRecordHandle match {
+        case fileSegment: FileBasedWriteAheadLogSegment =>
+          HdfsUtils.getFileSegmentLocations(
+            fileSegment.path, fileSegment.offset, fileSegment.length, hadoopConfig)
+        case _ =>
+          Seq.empty
+      }
+    }
   }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index 297bf04c0c..4b3d9ee4b0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -17,18 +17,18 @@
 
 package org.apache.spark.streaming.receiver
 
-import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.concurrent.duration._
+import scala.concurrent.{Await, ExecutionContext, Future}
 import scala.language.{existentials, postfixOps}
 
-import WriteAheadLogBasedBlockHandler._
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{Logging, SparkConf, SparkException}
 import org.apache.spark.storage._
-import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager}
-import org.apache.spark.util.{ThreadUtils, Clock, SystemClock}
+import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
+import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils}
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
+import org.apache.spark.{Logging, SparkConf, SparkException}
 
 /** Trait that represents the metadata related to storage of blocks */
 private[streaming] trait ReceivedBlockStoreResult {
@@ -96,7 +96,7 @@ private[streaming] class BlockManagerBasedBlockHandler(
  */
 private[streaming] case class WriteAheadLogBasedStoreResult(
     blockId: StreamBlockId,
-    segment: WriteAheadLogFileSegment
+    walRecordHandle: WriteAheadLogRecordHandle
   ) extends ReceivedBlockStoreResult
 
 
@@ -116,10 +116,6 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
 
   private val blockStoreTimeout = conf.getInt(
     "spark.streaming.receiver.blockStoreTimeout", 30).seconds
-  private val rollingInterval = conf.getInt(
-    "spark.streaming.receiver.writeAheadLog.rollingInterval", 60)
-  private val maxFailures = conf.getInt(
-    "spark.streaming.receiver.writeAheadLog.maxFailures", 3)
 
   private val effectiveStorageLevel = {
     if (storageLevel.deserialized) {
@@ -139,13 +135,9 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
       s"$effectiveStorageLevel when write ahead log is enabled")
   }
 
-  // Manages rolling log files
-  private val logManager = new WriteAheadLogManager(
-    checkpointDirToLogDir(checkpointDir, streamId),
-    hadoopConf, rollingInterval, maxFailures,
-    callerName = this.getClass.getSimpleName,
-    clock = clock
-  )
+  // Write ahead log manages
+  private val writeAheadLog = WriteAheadLogUtils.createLogForReceiver(
+    conf, checkpointDirToLogDir(checkpointDir, streamId), hadoopConf)
 
   // For processing futures used in parallel block storing into block manager and write ahead log
   // # threads = 2, so that both writing to BM and WAL can proceed in parallel
@@ -183,21 +175,21 @@ private[streaming] class WriteAheadLogBasedBlockHandler(
 
     // Store the block in write ahead log
     val storeInWriteAheadLogFuture = Future {
-      logManager.writeToLog(serializedBlock)
+      writeAheadLog.write(serializedBlock, clock.getTimeMillis())
     }
 
-    // Combine the futures, wait for both to complete, and return the write ahead log segment
+    // Combine the futures, wait for both to complete, and return the write ahead log record handle
     val combinedFuture = storeInBlockManagerFuture.zip(storeInWriteAheadLogFuture).map(_._2)
-    val segment = Await.result(combinedFuture, blockStoreTimeout)
-    WriteAheadLogBasedStoreResult(blockId, segment)
+    val walRecordHandle = Await.result(combinedFuture, blockStoreTimeout)
+    WriteAheadLogBasedStoreResult(blockId, walRecordHandle)
   }
 
   def cleanupOldBlocks(threshTime: Long) {
-    logManager.cleanupOldLogs(threshTime, waitForCompletion = false)
+    writeAheadLog.clean(threshTime, false)
   }
 
   def stop() {
-    logManager.stop()
+    writeAheadLog.close()
   }
 }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index f2379366f3..93f047b910 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -25,12 +25,13 @@ import scala.collection.mutable.ArrayBuffer
 import com.google.common.base.Throwables
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.{Logging, SparkEnv, SparkException}
 import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.Time
 import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.util.WriteAheadLogUtils
 import org.apache.spark.util.{RpcUtils, Utils}
+import org.apache.spark.{Logging, SparkEnv, SparkException}
 
 /**
  * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]]
@@ -46,7 +47,7 @@ private[streaming] class ReceiverSupervisorImpl(
   ) extends ReceiverSupervisor(receiver, env.conf) with Logging {
 
   private val receivedBlockHandler: ReceivedBlockHandler = {
-    if (env.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+    if (WriteAheadLogUtils.enableReceiverLog(env.conf)) {
       if (checkpointDirOption.isEmpty) {
         throw new SparkException(
           "Cannot enable receiver write-ahead log without checkpoint directory set. " +
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index 200cf4ef4b..14e769a281 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -25,10 +25,10 @@ import scala.language.implicitConversions
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkException, Logging, SparkConf}
 import org.apache.spark.streaming.Time
-import org.apache.spark.streaming.util.WriteAheadLogManager
+import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils}
 import org.apache.spark.util.{Clock, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkException}
 
 /** Trait representing any event in the ReceivedBlockTracker that updates its state. */
 private[streaming] sealed trait ReceivedBlockTrackerLogEvent
@@ -70,7 +70,7 @@ private[streaming] class ReceivedBlockTracker(
 
   private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue]
   private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks]
-  private val logManagerOption = createLogManager()
+  private val writeAheadLogOption = createWriteAheadLog()
 
   private var lastAllocatedBatchTime: Time = null
 
@@ -155,12 +155,12 @@ private[streaming] class ReceivedBlockTracker(
     logInfo("Deleting batches " + timesToCleanup)
     writeToLog(BatchCleanupEvent(timesToCleanup))
     timeToAllocatedBlocks --= timesToCleanup
-    logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion))
+    writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion))
   }
 
   /** Stop the block tracker. */
   def stop() {
-    logManagerOption.foreach { _.stop() }
+    writeAheadLogOption.foreach { _.close() }
   }
 
   /**
@@ -190,9 +190,10 @@ private[streaming] class ReceivedBlockTracker(
       timeToAllocatedBlocks --= batchTimes
     }
 
-    logManagerOption.foreach { logManager =>
+    writeAheadLogOption.foreach { writeAheadLog =>
       logInfo(s"Recovering from write ahead logs in ${checkpointDirOption.get}")
-      logManager.readFromLog().foreach { byteBuffer =>
+      import scala.collection.JavaConversions._
+      writeAheadLog.readAll().foreach { byteBuffer =>
         logTrace("Recovering record " + byteBuffer)
         Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array) match {
           case BlockAdditionEvent(receivedBlockInfo) =>
@@ -208,10 +209,10 @@ private[streaming] class ReceivedBlockTracker(
 
   /** Write an update to the tracker to the write ahead log */
   private def writeToLog(record: ReceivedBlockTrackerLogEvent) {
-    if (isLogManagerEnabled) {
+    if (isWriteAheadLogEnabled) {
       logDebug(s"Writing to log $record")
-      logManagerOption.foreach { logManager =>
-        logManager.writeToLog(ByteBuffer.wrap(Utils.serialize(record)))
+      writeAheadLogOption.foreach { logManager =>
+        logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis())
       }
     }
   }
@@ -222,8 +223,8 @@ private[streaming] class ReceivedBlockTracker(
   }
 
   /** Optionally create the write ahead log manager only if the feature is enabled */
-  private def createLogManager(): Option[WriteAheadLogManager] = {
-    if (conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+  private def createWriteAheadLog(): Option[WriteAheadLog] = {
+    if (WriteAheadLogUtils.enableReceiverLog(conf)) {
       if (checkpointDirOption.isEmpty) {
         throw new SparkException(
           "Cannot enable receiver write-ahead log without checkpoint directory set. " +
@@ -231,19 +232,16 @@ private[streaming] class ReceivedBlockTracker(
             "See documentation for more details.")
       }
       val logDir = ReceivedBlockTracker.checkpointDirToLogDir(checkpointDirOption.get)
-      val rollingIntervalSecs = conf.getInt(
-        "spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", 60)
-      val logManager = new WriteAheadLogManager(logDir, hadoopConf,
-        rollingIntervalSecs = rollingIntervalSecs, clock = clock,
-        callerName = "ReceivedBlockHandlerMaster")
-      Some(logManager)
+
+      val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf)
+      Some(log)
     } else {
       None
     }
   }
 
-  /** Check if the log manager is enabled. This is only used for testing purposes. */
-  private[streaming] def isLogManagerEnabled: Boolean = logManagerOption.nonEmpty
+  /** Check if the write ahead log is enabled. This is only used for testing purposes. */
+  private[streaming] def isWriteAheadLogEnabled: Boolean = writeAheadLogOption.nonEmpty
 }
 
 private[streaming] object ReceivedBlockTracker {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index c4ead6f30a..1af65716d3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler
 import scala.collection.mutable.{HashMap, SynchronizedMap}
 import scala.language.existentials
 
+import org.apache.spark.streaming.util.WriteAheadLogUtils
 import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
 import org.apache.spark.rpc._
 import org.apache.spark.streaming.{StreamingContext, Time}
@@ -125,7 +126,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false)
 
     // Signal the receivers to delete old block data
-    if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+    if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) {
       logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
       receiverInfo.values.flatMap { info => Option(info.endpoint) }
         .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
similarity index 79%
rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 38a93cc3c9..9985fedc35 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.streaming.util
 
 import java.nio.ByteBuffer
+import java.util.{Iterator => JIterator}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.{Await, ExecutionContext, Future}
@@ -24,9 +25,9 @@ import scala.language.postfixOps
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
-import org.apache.spark.Logging
-import org.apache.spark.util.{ThreadUtils, Clock, SystemClock}
-import WriteAheadLogManager._
+
+import org.apache.spark.util.ThreadUtils
+import org.apache.spark.{Logging, SparkConf}
 
 /**
  * This class manages write ahead log files.
@@ -34,37 +35,32 @@ import WriteAheadLogManager._
  * - Recovers the log files and the reads the recovered records upon failures.
  * - Cleans up old log files.
  *
- * Uses [[org.apache.spark.streaming.util.WriteAheadLogWriter]] to write
- * and [[org.apache.spark.streaming.util.WriteAheadLogReader]] to read.
+ * Uses [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]] to write
+ * and [[org.apache.spark.streaming.util.FileBasedWriteAheadLogReader]] to read.
  *
  * @param logDirectory Directory when rotating log files will be created.
  * @param hadoopConf Hadoop configuration for reading/writing log files.
- * @param rollingIntervalSecs The interval in seconds with which logs will be rolled over.
- *                            Default is one minute.
- * @param maxFailures Max number of failures that is tolerated for every attempt to write to log.
- *                    Default is three.
- * @param callerName Optional name of the class who is using this manager.
- * @param clock Optional clock that is used to check for rotation interval.
  */
-private[streaming] class WriteAheadLogManager(
+private[streaming] class FileBasedWriteAheadLog(
+    conf: SparkConf,
     logDirectory: String,
     hadoopConf: Configuration,
-    rollingIntervalSecs: Int = 60,
-    maxFailures: Int = 3,
-    callerName: String = "",
-    clock: Clock = new SystemClock
-  ) extends Logging {
+    rollingIntervalSecs: Int,
+    maxFailures: Int
+  ) extends WriteAheadLog with Logging {
+
+  import FileBasedWriteAheadLog._
 
   private val pastLogs = new ArrayBuffer[LogInfo]
-  private val callerNameTag =
-    if (callerName.nonEmpty) s" for $callerName" else ""
+  private val callerNameTag = getCallerName.map(c => s" for $c").getOrElse("")
+
   private val threadpoolName = s"WriteAheadLogManager $callerNameTag"
   implicit private val executionContext = ExecutionContext.fromExecutorService(
     ThreadUtils.newDaemonSingleThreadExecutor(threadpoolName))
   override protected val logName = s"WriteAheadLogManager $callerNameTag"
 
   private var currentLogPath: Option[String] = None
-  private var currentLogWriter: WriteAheadLogWriter = null
+  private var currentLogWriter: FileBasedWriteAheadLogWriter = null
   private var currentLogWriterStartTime: Long = -1L
   private var currentLogWriterStopTime: Long = -1L
 
@@ -75,14 +71,14 @@ private[streaming] class WriteAheadLogManager(
    * ByteBuffer to HDFS. When this method returns, the data is guaranteed to have been flushed
    * to HDFS, and will be available for readers to read.
    */
-  def writeToLog(byteBuffer: ByteBuffer): WriteAheadLogFileSegment = synchronized {
-    var fileSegment: WriteAheadLogFileSegment = null
+  def write(byteBuffer: ByteBuffer, time: Long): FileBasedWriteAheadLogSegment = synchronized {
+    var fileSegment: FileBasedWriteAheadLogSegment = null
     var failures = 0
     var lastException: Exception = null
     var succeeded = false
     while (!succeeded && failures < maxFailures) {
       try {
-        fileSegment = getLogWriter(clock.getTimeMillis()).write(byteBuffer)
+        fileSegment = getLogWriter(time).write(byteBuffer)
         succeeded = true
       } catch {
         case ex: Exception =>
@@ -99,6 +95,19 @@ private[streaming] class WriteAheadLogManager(
     fileSegment
   }
 
+  def read(segment: WriteAheadLogRecordHandle): ByteBuffer = {
+    val fileSegment = segment.asInstanceOf[FileBasedWriteAheadLogSegment]
+    var reader: FileBasedWriteAheadLogRandomReader = null
+    var byteBuffer: ByteBuffer = null
+    try {
+      reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf)
+      byteBuffer = reader.read(fileSegment)
+    } finally {
+      reader.close()
+    }
+    byteBuffer
+  }
+
   /**
    * Read all the existing logs from the log directory.
    *
@@ -108,12 +117,14 @@ private[streaming] class WriteAheadLogManager(
    * the latest the records. This does not deal with currently active log files, and
    * hence the implementation is kept simple.
    */
-  def readFromLog(): Iterator[ByteBuffer] = synchronized {
+  def readAll(): JIterator[ByteBuffer] = synchronized {
+    import scala.collection.JavaConversions._
     val logFilesToRead = pastLogs.map{ _.path} ++ currentLogPath
     logInfo("Reading from the logs: " + logFilesToRead.mkString("\n"))
+
     logFilesToRead.iterator.map { file =>
       logDebug(s"Creating log reader with $file")
-      new WriteAheadLogReader(file, hadoopConf)
+      new FileBasedWriteAheadLogReader(file, hadoopConf)
     } flatMap { x => x }
   }
 
@@ -129,7 +140,7 @@ private[streaming] class WriteAheadLogManager(
    * deleted. This should be set to true only for testing. Else the files will be deleted
    * asynchronously.
    */
-  def cleanupOldLogs(threshTime: Long, waitForCompletion: Boolean): Unit = {
+  def clean(threshTime: Long, waitForCompletion: Boolean): Unit = {
     val oldLogFiles = synchronized { pastLogs.filter { _.endTime < threshTime } }
     logInfo(s"Attempting to clear ${oldLogFiles.size} old log files in $logDirectory " +
       s"older than $threshTime: ${oldLogFiles.map { _.path }.mkString("\n")}")
@@ -160,7 +171,7 @@ private[streaming] class WriteAheadLogManager(
 
 
   /** Stop the manager, close any open log writer */
-  def stop(): Unit = synchronized {
+  def close(): Unit = synchronized {
     if (currentLogWriter != null) {
       currentLogWriter.close()
     }
@@ -169,7 +180,7 @@ private[streaming] class WriteAheadLogManager(
   }
 
   /** Get the current log writer while taking care of rotation */
-  private def getLogWriter(currentTime: Long): WriteAheadLogWriter = synchronized {
+  private def getLogWriter(currentTime: Long): FileBasedWriteAheadLogWriter = synchronized {
     if (currentLogWriter == null || currentTime > currentLogWriterStopTime) {
       resetWriter()
       currentLogPath.foreach {
@@ -180,7 +191,7 @@ private[streaming] class WriteAheadLogManager(
       val newLogPath = new Path(logDirectory,
         timeToLogFile(currentLogWriterStartTime, currentLogWriterStopTime))
       currentLogPath = Some(newLogPath.toString)
-      currentLogWriter = new WriteAheadLogWriter(currentLogPath.get, hadoopConf)
+      currentLogWriter = new FileBasedWriteAheadLogWriter(currentLogPath.get, hadoopConf)
     }
     currentLogWriter
   }
@@ -207,7 +218,7 @@ private[streaming] class WriteAheadLogManager(
   }
 }
 
-private[util] object WriteAheadLogManager {
+private[streaming] object FileBasedWriteAheadLog {
 
   case class LogInfo(startTime: Long, endTime: Long, path: String)
 
@@ -217,6 +228,11 @@ private[util] object WriteAheadLogManager {
     s"log-$startTime-$stopTime"
   }
 
+  def getCallerName(): Option[String] = {
+    val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName)
+    stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption)
+  }
+
   /** Convert a sequence of files to a sequence of sorted LogInfo objects */
   def logFilesTologInfo(files: Seq[Path]): Seq[LogInfo] = {
     files.flatMap { file =>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala
similarity index 83%
rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala
index 003989092a..f7168229ec 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogRandomReader.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogRandomReader.scala
@@ -23,16 +23,16 @@ import org.apache.hadoop.conf.Configuration
 
 /**
  * A random access reader for reading write ahead log files written using
- * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. Given the file segment info,
- * this reads the record (bytebuffer) from the log file.
+ * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. Given the file segment info,
+ * this reads the record (ByteBuffer) from the log file.
  */
-private[streaming] class WriteAheadLogRandomReader(path: String, conf: Configuration)
+private[streaming] class FileBasedWriteAheadLogRandomReader(path: String, conf: Configuration)
   extends Closeable {
 
   private val instream = HdfsUtils.getInputStream(path, conf)
   private var closed = false
 
-  def read(segment: WriteAheadLogFileSegment): ByteBuffer = synchronized {
+  def read(segment: FileBasedWriteAheadLogSegment): ByteBuffer = synchronized {
     assertOpen()
     instream.seek(segment.offset)
     val nextLength = instream.readInt()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala
similarity index 93%
rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala
index 2afc0d1551..c3bb59f3fe 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogReader.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala
@@ -24,11 +24,11 @@ import org.apache.spark.Logging
 
 /**
  * A reader for reading write ahead log files written using
- * [[org.apache.spark.streaming.util.WriteAheadLogWriter]]. This reads
+ * [[org.apache.spark.streaming.util.FileBasedWriteAheadLogWriter]]. This reads
  * the records (bytebuffers) in the log file sequentially and return them as an
  * iterator of bytebuffers.
  */
-private[streaming] class WriteAheadLogReader(path: String, conf: Configuration)
+private[streaming] class FileBasedWriteAheadLogReader(path: String, conf: Configuration)
   extends Iterator[ByteBuffer] with Closeable with Logging {
 
   private val instream = HdfsUtils.getInputStream(path, conf)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala
similarity index 86%
rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala
index 1005a2c8ec..2e1f1528fa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogFileSegment.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogSegment.scala
@@ -17,4 +17,5 @@
 package org.apache.spark.streaming.util
 
 /** Class for representing a segment of data in a write ahead log file */
-private[streaming] case class WriteAheadLogFileSegment (path: String, offset: Long, length: Int)
+private[streaming] case class FileBasedWriteAheadLogSegment(path: String, offset: Long, length: Int)
+  extends WriteAheadLogRecordHandle
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
similarity index 88%
rename from streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala
rename to streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
index 679f6a6dfd..e146bec32a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogWriter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala
@@ -17,18 +17,17 @@
 package org.apache.spark.streaming.util
 
 import java.io._
-import java.net.URI
 import java.nio.ByteBuffer
 
 import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FSDataOutputStream, FileSystem}
+import org.apache.hadoop.fs.FSDataOutputStream
 
 /**
  * A writer for writing byte-buffers to a write ahead log file.
  */
-private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configuration)
+private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: Configuration)
   extends Closeable {
 
   private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf)
@@ -43,11 +42,11 @@ private[streaming] class WriteAheadLogWriter(path: String, hadoopConf: Configura
   private var closed = false
 
   /** Write the bytebuffer to the log file */
-  def write(data: ByteBuffer): WriteAheadLogFileSegment = synchronized {
+  def write(data: ByteBuffer): FileBasedWriteAheadLogSegment = synchronized {
     assertOpen()
     data.rewind() // Rewind to ensure all data in the buffer is retrieved
     val lengthToWrite = data.remaining()
-    val segment = new WriteAheadLogFileSegment(path, nextOffset, lengthToWrite)
+    val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite)
     stream.writeInt(lengthToWrite)
     if (data.hasArray) {
       stream.write(data.array())
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
new file mode 100644
index 0000000000..7f6ff12c58
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf, SparkException}
+
+/** A helper class with utility functions related to the WriteAheadLog interface */
+private[streaming] object WriteAheadLogUtils extends Logging {
+  val RECEIVER_WAL_ENABLE_CONF_KEY = "spark.streaming.receiver.writeAheadLog.enable"
+  val RECEIVER_WAL_CLASS_CONF_KEY = "spark.streaming.receiver.writeAheadLog.class"
+  val RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY =
+    "spark.streaming.receiver.writeAheadLog.rollingIntervalSecs"
+  val RECEIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.receiver.writeAheadLog.maxFailures"
+
+  val DRIVER_WAL_CLASS_CONF_KEY = "spark.streaming.driver.writeAheadLog.class"
+  val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY =
+    "spark.streaming.driver.writeAheadLog.rollingIntervalSecs"
+  val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures"
+
+  val DEFAULT_ROLLING_INTERVAL_SECS = 60
+  val DEFAULT_MAX_FAILURES = 3
+
+  def enableReceiverLog(conf: SparkConf): Boolean = {
+    conf.getBoolean(RECEIVER_WAL_ENABLE_CONF_KEY, false)
+  }
+
+  def getRollingIntervalSecs(conf: SparkConf, isDriver: Boolean): Int = {
+    if (isDriver) {
+      conf.getInt(DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS)
+    } else {
+      conf.getInt(RECEIVER_WAL_ROLLING_INTERVAL_CONF_KEY, DEFAULT_ROLLING_INTERVAL_SECS)
+    }
+  }
+
+  def getMaxFailures(conf: SparkConf, isDriver: Boolean): Int = {
+    if (isDriver) {
+      conf.getInt(DRIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES)
+    } else {
+      conf.getInt(RECEIVER_WAL_MAX_FAILURES_CONF_KEY, DEFAULT_MAX_FAILURES)
+    }
+  }
+
+  /**
+   * Create a WriteAheadLog for the driver. If configured with custom WAL class, it will try
+   * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog.
+   */
+  def createLogForDriver(
+      sparkConf: SparkConf,
+      fileWalLogDirectory: String,
+      fileWalHadoopConf: Configuration
+    ): WriteAheadLog = {
+    createLog(true, sparkConf, fileWalLogDirectory, fileWalHadoopConf)
+  }
+
+  /**
+   * Create a WriteAheadLog for the receiver. If configured with custom WAL class, it will try
+   * to create instance of that class, otherwise it will create the default FileBasedWriteAheadLog.
+   */
+  def createLogForReceiver(
+      sparkConf: SparkConf,
+      fileWalLogDirectory: String,
+      fileWalHadoopConf: Configuration
+    ): WriteAheadLog = {
+    createLog(false, sparkConf, fileWalLogDirectory, fileWalHadoopConf)
+  }
+
+  /**
+   * Create a WriteAheadLog based on the value of the given config key. The config key is used
+   * to get the class name from the SparkConf. If the class is configured, it will try to
+   * create instance of that class by first trying `new CustomWAL(sparkConf, logDir)` then trying
+   * `new CustomWAL(sparkConf)`. If either fails, it will fail. If no class is configured, then
+   * it will create the default FileBasedWriteAheadLog.
+   */
+  private def createLog(
+      isDriver: Boolean,
+      sparkConf: SparkConf,
+      fileWalLogDirectory: String,
+      fileWalHadoopConf: Configuration
+    ): WriteAheadLog = {
+
+    val classNameOption = if (isDriver) {
+      sparkConf.getOption(DRIVER_WAL_CLASS_CONF_KEY)
+    } else {
+      sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY)
+    }
+    classNameOption.map { className =>
+      try {
+        instantiateClass(
+          Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf)
+      } catch {
+        case NonFatal(e) =>
+          throw new SparkException(s"Could not create a write ahead log of class $className", e)
+      }
+    }.getOrElse {
+      new FileBasedWriteAheadLog(sparkConf, fileWalLogDirectory, fileWalHadoopConf,
+        getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver))
+    }
+  }
+
+  /** Instantiate the class, either using single arg constructor or zero arg constructor */
+  private def instantiateClass(cls: Class[_ <: WriteAheadLog], conf: SparkConf): WriteAheadLog = {
+    try {
+      cls.getConstructor(classOf[SparkConf]).newInstance(conf)
+    } catch {
+      case nsme: NoSuchMethodException =>
+        cls.getConstructor().newInstance()
+    }
+  }
+}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
new file mode 100644
index 0000000000..50e8f9fc15
--- /dev/null
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaWriteAheadLogSuite.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import java.util.ArrayList;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.Transformer;
+import org.apache.spark.SparkConf;
+import org.apache.spark.streaming.util.WriteAheadLog;
+import org.apache.spark.streaming.util.WriteAheadLogRecordHandle;
+import org.apache.spark.streaming.util.WriteAheadLogUtils;
+
+import org.junit.Test;
+import org.junit.Assert;
+
+class JavaWriteAheadLogSuiteHandle extends WriteAheadLogRecordHandle {
+  int index = -1;
+  public JavaWriteAheadLogSuiteHandle(int idx) {
+    index = idx;
+  }
+}
+
+public class JavaWriteAheadLogSuite extends WriteAheadLog {
+
+  class Record {
+    long time;
+    int index;
+    ByteBuffer buffer;
+
+    public Record(long tym, int idx, ByteBuffer buf) {
+      index = idx;
+      time = tym;
+      buffer = buf;
+    }
+  }
+  private int index = -1;
+  private ArrayList<Record> records = new ArrayList<Record>();
+
+
+  // Methods for WriteAheadLog
+  @Override
+  public WriteAheadLogRecordHandle write(java.nio.ByteBuffer record, long time) {
+    index += 1;
+    records.add(new org.apache.spark.streaming.JavaWriteAheadLogSuite.Record(time, index, record));
+    return new JavaWriteAheadLogSuiteHandle(index);
+  }
+
+  @Override
+  public java.nio.ByteBuffer read(WriteAheadLogRecordHandle handle) {
+    if (handle instanceof JavaWriteAheadLogSuiteHandle) {
+      int reqdIndex = ((JavaWriteAheadLogSuiteHandle) handle).index;
+      for (Record record: records) {
+        if (record.index == reqdIndex) {
+          return record.buffer;
+        }
+      }
+    }
+    return null;
+  }
+
+  @Override
+  public java.util.Iterator<java.nio.ByteBuffer> readAll() {
+    Collection<ByteBuffer> buffers = CollectionUtils.collect(records, new Transformer() {
+      @Override
+      public Object transform(Object input) {
+        return ((Record) input).buffer;
+      }
+    });
+    return buffers.iterator();
+  }
+
+  @Override
+  public void clean(long threshTime, boolean waitForCompletion) {
+    for (int i = 0; i < records.size(); i++) {
+      if (records.get(i).time < threshTime) {
+        records.remove(i);
+        i--;
+      }
+    }
+  }
+
+  @Override
+  public void close() {
+    records.clear();
+  }
+
+  @Test
+  public void testCustomWAL() {
+    SparkConf conf = new SparkConf();
+    conf.set("spark.streaming.driver.writeAheadLog.class", JavaWriteAheadLogSuite.class.getName());
+    WriteAheadLog wal = WriteAheadLogUtils.createLogForDriver(conf, null, null);
+
+    String data1 = "data1";
+    WriteAheadLogRecordHandle handle = wal.write(ByteBuffer.wrap(data1.getBytes()), 1234);
+    Assert.assertTrue(handle instanceof JavaWriteAheadLogSuiteHandle);
+    Assert.assertTrue(new String(wal.read(handle).array()).equals(data1));
+
+    wal.write(ByteBuffer.wrap("data2".getBytes()), 1235);
+    wal.write(ByteBuffer.wrap("data3".getBytes()), 1236);
+    wal.write(ByteBuffer.wrap("data4".getBytes()), 1237);
+    wal.clean(1236, false);
+
+    java.util.Iterator<java.nio.ByteBuffer> dataIterator = wal.readAll();
+    ArrayList<String> readData = new ArrayList<String>();
+    while (dataIterator.hasNext()) {
+      readData.add(new String(dataIterator.next().array()));
+    }
+    Assert.assertTrue(readData.equals(Arrays.asList("data3", "data4")));
+  }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index c090eaec29..23804237bd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -43,7 +43,7 @@ import WriteAheadLogSuite._
 
 class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matchers with Logging {
 
-  val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+  val conf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
   val hadoopConf = new Configuration()
   val storageLevel = StorageLevel.MEMORY_ONLY_SER
   val streamId = 1
@@ -130,10 +130,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
           "Unexpected store result type"
         )
         // Verify the data in write ahead log files is correct
-        val fileSegments = storeResults.map { _.asInstanceOf[WriteAheadLogBasedStoreResult].segment}
-        val loggedData = fileSegments.flatMap { segment =>
-          val reader = new WriteAheadLogRandomReader(segment.path, hadoopConf)
-          val bytes = reader.read(segment)
+        val walSegments = storeResults.map { result =>
+          result.asInstanceOf[WriteAheadLogBasedStoreResult].walRecordHandle
+        }
+        val loggedData = walSegments.flatMap { walSegment =>
+          val fileSegment = walSegment.asInstanceOf[FileBasedWriteAheadLogSegment]
+          val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf)
+          val bytes = reader.read(fileSegment)
           reader.close()
           blockManager.dataDeserialize(generateBlockId(), bytes).toList
         }
@@ -148,13 +151,13 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
     }
   }
 
-  test("WriteAheadLogBasedBlockHandler - cleanup old blocks") {
+  test("WriteAheadLogBasedBlockHandler - clean old blocks") {
     withWriteAheadLogBasedBlockHandler { handler =>
       val blocks = Seq.tabulate(10) { i => IteratorBlock(Iterator(1 to i)) }
       storeBlocks(handler, blocks)
 
       val preCleanupLogFiles = getWriteAheadLogFiles()
-      preCleanupLogFiles.size should be > 1
+      require(preCleanupLogFiles.size > 1)
 
       // this depends on the number of blocks inserted using generateAndStoreData()
       manualClock.getTimeMillis() shouldEqual 5000L
@@ -218,6 +221,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
 
   /** Instantiate a WriteAheadLogBasedBlockHandler and run a code with it */
   private def withWriteAheadLogBasedBlockHandler(body: WriteAheadLogBasedBlockHandler => Unit) {
+    require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = false) === 1)
     val receivedBlockHandler = new WriteAheadLogBasedBlockHandler(blockManager, 1,
       storageLevel, conf, hadoopConf, tempDirectory.toString, manualClock)
     try {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index b63b37d9f9..8317fb9720 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.{Logging, SparkConf, SparkException}
 import org.apache.spark.storage.StreamBlockId
 import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
 import org.apache.spark.streaming.scheduler._
-import org.apache.spark.streaming.util.WriteAheadLogReader
+import org.apache.spark.streaming.util.{WriteAheadLogUtils, FileBasedWriteAheadLogReader}
 import org.apache.spark.streaming.util.WriteAheadLogSuite._
 import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
 
@@ -59,7 +59,7 @@ class ReceivedBlockTrackerSuite
 
   test("block addition, and block to batch allocation") {
     val receivedBlockTracker = createTracker(setCheckpointDir = false)
-    receivedBlockTracker.isLogManagerEnabled should be (false)  // should be disable by default
+    receivedBlockTracker.isWriteAheadLogEnabled should be (false)  // should be disable by default
     receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual Seq.empty
 
     val blockInfos = generateBlockInfos()
@@ -88,7 +88,7 @@ class ReceivedBlockTrackerSuite
     receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
   }
 
-  test("block addition, block to batch allocation and cleanup with write ahead log") {
+  test("block addition, block to batch allocation and clean up with write ahead log") {
     val manualClock = new ManualClock
     // Set the time increment level to twice the rotation interval so that every increment creates
     // a new log file
@@ -113,11 +113,15 @@ class ReceivedBlockTrackerSuite
       logInfo(s"\n\n=====================\n$message\n$fileContents\n=====================\n")
     }
 
-    // Start tracker and add blocks
+    // Set WAL configuration
     conf.set("spark.streaming.receiver.writeAheadLog.enable", "true")
-    conf.set("spark.streaming.receivedBlockTracker.writeAheadLog.rotationIntervalSecs", "1")
+    conf.set("spark.streaming.driver.writeAheadLog.rollingIntervalSecs", "1")
+    require(WriteAheadLogUtils.enableReceiverLog(conf))
+    require(WriteAheadLogUtils.getRollingIntervalSecs(conf, isDriver = true) === 1)
+
+    // Start tracker and add blocks
     val tracker1 = createTracker(clock = manualClock)
-    tracker1.isLogManagerEnabled should be (true)
+    tracker1.isWriteAheadLogEnabled should be (true)
 
     val blockInfos1 = addBlockInfos(tracker1)
     tracker1.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
@@ -171,7 +175,7 @@ class ReceivedBlockTrackerSuite
     eventually(timeout(10 seconds), interval(10 millisecond)) {
       getWriteAheadLogFiles() should not contain oldestLogFile
     }
-    printLogFiles("After cleanup")
+    printLogFiles("After clean")
 
     // Restart tracker and verify recovered state, specifically whether info about the first
     // batch has been removed, but not the second batch
@@ -192,17 +196,17 @@ class ReceivedBlockTrackerSuite
   test("setting checkpoint dir but not enabling write ahead log") {
     // When WAL config is not set, log manager should not be enabled
     val tracker1 = createTracker(setCheckpointDir = true)
-    tracker1.isLogManagerEnabled should be (false)
+    tracker1.isWriteAheadLogEnabled should be (false)
 
     // When WAL is explicitly disabled, log manager should not be enabled
     conf.set("spark.streaming.receiver.writeAheadLog.enable", "false")
     val tracker2 = createTracker(setCheckpointDir = true)
-    tracker2.isLogManagerEnabled should be(false)
+    tracker2.isWriteAheadLogEnabled should be(false)
   }
 
   /**
    * Create tracker object with the optional provided clock. Use fake clock if you
-   * want to control time by manually incrementing it to test log cleanup.
+   * want to control time by manually incrementing it to test log clean.
    */
   def createTracker(
       setCheckpointDir: Boolean = true,
@@ -231,7 +235,7 @@ class ReceivedBlockTrackerSuite
   def getWrittenLogData(logFiles: Seq[String] = getWriteAheadLogFiles)
     : Seq[ReceivedBlockTrackerLogEvent] = {
     logFiles.flatMap {
-      file => new WriteAheadLogReader(file, hadoopConf).toSeq
+      file => new FileBasedWriteAheadLogReader(file, hadoopConf).toSeq
     }.map { byteBuffer =>
       Utils.deserialize[ReceivedBlockTrackerLogEvent](byteBuffer.array)
     }.toList
@@ -250,7 +254,7 @@ class ReceivedBlockTrackerSuite
     BatchAllocationEvent(time, AllocatedBlocks(Map((streamId -> blockInfos))))
   }
 
-  /** Create batch cleanup object from the given info */
+  /** Create batch clean object from the given info */
   def createBatchCleanup(time: Long, moreTimes: Long*): BatchCleanupEvent = {
     BatchCleanupEvent((Seq(time) ++ moreTimes).map(Time.apply))
   }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index b84129fd70..393a360cfe 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -225,7 +225,7 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
       .setAppName(framework)
       .set("spark.ui.enabled", "true")
       .set("spark.streaming.receiver.writeAheadLog.enable", "true")
-      .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+      .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1")
     val batchDuration = Milliseconds(500)
     val tempDirectory = Utils.createTempDir()
     val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 58353a5f97..09440b1e79 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -363,7 +363,7 @@ class TestReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging
   }
 
   def onStop() {
-    // no cleanup to be done, the receiving thread should stop on it own
+    // no clean to be done, the receiving thread should stop on it own
   }
 }
 
@@ -396,7 +396,7 @@ class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int)
   def onStop() {
     // Simulate slow receiver by waiting for all records to be produced
     while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100)
-    // no cleanup to be done, the receiving thread should stop on it own
+    // no clean to be done, the receiving thread should stop on it own
   }
 }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
index c3602a5b73..8b300d8dd3 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala
@@ -21,12 +21,12 @@ import java.io.File
 import scala.util.Random
 
 import org.apache.hadoop.conf.Configuration
-import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
-import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
-import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogWriter}
+import org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter}
 import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkContext}
 
 class WriteAheadLogBackedBlockRDDSuite
   extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
@@ -100,9 +100,10 @@ class WriteAheadLogBackedBlockRDDSuite
       blockManager.putIterator(blockId, block.iterator, StorageLevel.MEMORY_ONLY_SER)
     }
 
-    // Generate write ahead log segments
-    val segments = generateFakeSegments(numPartitionsInBM) ++
-      writeLogSegments(data.takeRight(numPartitionsInWAL), blockIds.takeRight(numPartitionsInWAL))
+    // Generate write ahead log file segments
+    val recordHandles = generateFakeRecordHandles(numPartitionsInBM) ++
+      generateWALRecordHandles(data.takeRight(numPartitionsInWAL),
+        blockIds.takeRight(numPartitionsInWAL))
 
     // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not
     require(
@@ -116,24 +117,24 @@ class WriteAheadLogBackedBlockRDDSuite
 
     // Make sure that the right `numPartitionsInWAL` blocks are in WALs, and other are not
     require(
-      segments.takeRight(numPartitionsInWAL).forall(s =>
+      recordHandles.takeRight(numPartitionsInWAL).forall(s =>
         new File(s.path.stripPrefix("file://")).exists()),
       "Expected blocks not in write ahead log"
     )
     require(
-      segments.take(numPartitionsInBM).forall(s =>
+      recordHandles.take(numPartitionsInBM).forall(s =>
         !new File(s.path.stripPrefix("file://")).exists()),
       "Unexpected blocks in write ahead log"
     )
 
     // Create the RDD and verify whether the returned data is correct
     val rdd = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray,
-      segments.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY)
+      recordHandles.toArray, storeInBlockManager = false, StorageLevel.MEMORY_ONLY)
     assert(rdd.collect() === data.flatten)
 
     if (testStoreInBM) {
       val rdd2 = new WriteAheadLogBackedBlockRDD[String](sparkContext, blockIds.toArray,
-        segments.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY)
+        recordHandles.toArray, storeInBlockManager = true, StorageLevel.MEMORY_ONLY)
       assert(rdd2.collect() === data.flatten)
       assert(
         blockIds.forall(blockManager.get(_).nonEmpty),
@@ -142,12 +143,12 @@ class WriteAheadLogBackedBlockRDDSuite
     }
   }
 
-  private def writeLogSegments(
+  private def generateWALRecordHandles(
       blockData: Seq[Seq[String]],
       blockIds: Seq[BlockId]
-    ): Seq[WriteAheadLogFileSegment] = {
+    ): Seq[FileBasedWriteAheadLogSegment] = {
     require(blockData.size === blockIds.size)
-    val writer = new WriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
+    val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf)
     val segments = blockData.zip(blockIds).map { case (data, id) =>
       writer.write(blockManager.dataSerialize(id, data.iterator))
     }
@@ -155,7 +156,7 @@ class WriteAheadLogBackedBlockRDDSuite
     segments
   }
 
-  private def generateFakeSegments(count: Int): Seq[WriteAheadLogFileSegment] = {
-    Array.fill(count)(new WriteAheadLogFileSegment("random", 0L, 0))
+  private def generateFakeRecordHandles(count: Int): Seq[FileBasedWriteAheadLogSegment] = {
+    Array.fill(count)(new FileBasedWriteAheadLogSegment("random", 0L, 0))
   }
 }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index a3919c43b9..79098bcf48 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -18,33 +18,38 @@ package org.apache.spark.streaming.util
 
 import java.io._
 import java.nio.ByteBuffer
+import java.util
 
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.language.{implicitConversions, postfixOps}
+import scala.reflect.ClassTag
 
-import WriteAheadLogSuite._
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
-import org.apache.spark.util.{ManualClock, Utils}
-import org.scalatest.{BeforeAndAfter, FunSuite}
 import org.scalatest.concurrent.Eventually._
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.{SparkConf, SparkException}
 
 class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
 
+  import WriteAheadLogSuite._
+  
   val hadoopConf = new Configuration()
   var tempDir: File = null
   var testDir: String = null
   var testFile: String = null
-  var manager: WriteAheadLogManager = null
+  var writeAheadLog: FileBasedWriteAheadLog = null
 
   before {
     tempDir = Utils.createTempDir()
     testDir = tempDir.toString
     testFile = new File(tempDir, "testFile").toString
-    if (manager != null) {
-      manager.stop()
-      manager = null
+    if (writeAheadLog != null) {
+      writeAheadLog.close()
+      writeAheadLog = null
     }
   }
 
@@ -52,16 +57,60 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     Utils.deleteRecursively(tempDir)
   }
 
-  test("WriteAheadLogWriter - writing data") {
+  test("WriteAheadLogUtils - log selection and creation") {
+    val logDir = Utils.createTempDir().getAbsolutePath()
+
+    def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = {
+      val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf)
+      assert(log.getClass === implicitly[ClassTag[T]].runtimeClass)
+      log
+    }
+
+    def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = {
+      val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf)
+      assert(log.getClass === implicitly[ClassTag[T]].runtimeClass)
+      log
+    }
+
+    val emptyConf = new SparkConf()  // no log configuration
+    assertDriverLogClass[FileBasedWriteAheadLog](emptyConf)
+    assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf)
+
+    // Verify setting driver WAL class
+    val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class",
+      classOf[MockWriteAheadLog0].getName())
+    assertDriverLogClass[MockWriteAheadLog0](conf1)
+    assertReceiverLogClass[FileBasedWriteAheadLog](conf1)
+
+    // Verify setting receiver WAL class
+    val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
+      classOf[MockWriteAheadLog0].getName())
+    assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf)
+    assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf)
+
+    // Verify setting receiver WAL class with 1-arg constructor
+    val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
+      classOf[MockWriteAheadLog1].getName())
+    assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2)
+
+    // Verify failure setting receiver WAL class with 2-arg constructor
+    intercept[SparkException] {
+      val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class",
+        classOf[MockWriteAheadLog2].getName())
+      assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3)
+    }
+  }
+
+  test("FileBasedWriteAheadLogWriter - writing data") {
     val dataToWrite = generateRandomData()
     val segments = writeDataUsingWriter(testFile, dataToWrite)
     val writtenData = readDataManually(segments)
     assert(writtenData === dataToWrite)
   }
 
-  test("WriteAheadLogWriter - syncing of data by writing and reading immediately") {
+  test("FileBasedWriteAheadLogWriter - syncing of data by writing and reading immediately") {
     val dataToWrite = generateRandomData()
-    val writer = new WriteAheadLogWriter(testFile, hadoopConf)
+    val writer = new FileBasedWriteAheadLogWriter(testFile, hadoopConf)
     dataToWrite.foreach { data =>
       val segment = writer.write(stringToByteBuffer(data))
       val dataRead = readDataManually(Seq(segment)).head
@@ -70,10 +119,10 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     writer.close()
   }
 
-  test("WriteAheadLogReader - sequentially reading data") {
+  test("FileBasedWriteAheadLogReader - sequentially reading data") {
     val writtenData = generateRandomData()
     writeDataManually(writtenData, testFile)
-    val reader = new WriteAheadLogReader(testFile, hadoopConf)
+    val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf)
     val readData = reader.toSeq.map(byteBufferToString)
     assert(readData === writtenData)
     assert(reader.hasNext === false)
@@ -83,14 +132,14 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     reader.close()
   }
 
-  test("WriteAheadLogReader - sequentially reading data written with writer") {
+  test("FileBasedWriteAheadLogReader - sequentially reading data written with writer") {
     val dataToWrite = generateRandomData()
     writeDataUsingWriter(testFile, dataToWrite)
     val readData = readDataUsingReader(testFile)
     assert(readData === dataToWrite)
   }
 
-  test("WriteAheadLogReader - reading data written with writer after corrupted write") {
+  test("FileBasedWriteAheadLogReader - reading data written with writer after corrupted write") {
     // Write data manually for testing the sequential reader
     val dataToWrite = generateRandomData()
     writeDataUsingWriter(testFile, dataToWrite)
@@ -113,38 +162,38 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     assert(readDataUsingReader(testFile) === (dataToWrite.dropRight(1)))
   }
 
-  test("WriteAheadLogRandomReader - reading data using random reader") {
+  test("FileBasedWriteAheadLogRandomReader - reading data using random reader") {
     // Write data manually for testing the random reader
     val writtenData = generateRandomData()
     val segments = writeDataManually(writtenData, testFile)
 
     // Get a random order of these segments and read them back
     val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten
-    val reader = new WriteAheadLogRandomReader(testFile, hadoopConf)
+    val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf)
     writtenDataAndSegments.foreach { case (data, segment) =>
       assert(data === byteBufferToString(reader.read(segment)))
     }
     reader.close()
   }
 
-  test("WriteAheadLogRandomReader - reading data using random reader written with writer") {
+  test("FileBasedWriteAheadLogRandomReader- reading data using random reader written with writer") {
     // Write data using writer for testing the random reader
     val data = generateRandomData()
     val segments = writeDataUsingWriter(testFile, data)
 
     // Read a random sequence of segments and verify read data
     val dataAndSegments = data.zip(segments).toSeq.permutations.take(10).flatten
-    val reader = new WriteAheadLogRandomReader(testFile, hadoopConf)
+    val reader = new FileBasedWriteAheadLogRandomReader(testFile, hadoopConf)
     dataAndSegments.foreach { case (data, segment) =>
       assert(data === byteBufferToString(reader.read(segment)))
     }
     reader.close()
   }
 
-  test("WriteAheadLogManager - write rotating logs") {
-    // Write data using manager
+  test("FileBasedWriteAheadLog - write rotating logs") {
+    // Write data with rotation using WriteAheadLog class
     val dataToWrite = generateRandomData()
-    writeDataUsingManager(testDir, dataToWrite)
+    writeDataUsingWriteAheadLog(testDir, dataToWrite)
 
     // Read data manually to verify the written data
     val logFiles = getLogFilesInDirectory(testDir)
@@ -153,8 +202,8 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     assert(writtenData === dataToWrite)
   }
 
-  test("WriteAheadLogManager - read rotating logs") {
-    // Write data manually for testing reading through manager
+  test("FileBasedWriteAheadLog - read rotating logs") {
+    // Write data manually for testing reading through WriteAheadLog
     val writtenData = (1 to 10).map { i =>
       val data = generateRandomData()
       val file = testDir + s"/log-$i-$i"
@@ -167,25 +216,25 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     assert(fileSystem.exists(logDirectoryPath) === true)
 
     // Read data using manager and verify
-    val readData = readDataUsingManager(testDir)
+    val readData = readDataUsingWriteAheadLog(testDir)
     assert(readData === writtenData)
   }
 
-  test("WriteAheadLogManager - recover past logs when creating new manager") {
+  test("FileBasedWriteAheadLog - recover past logs when creating new manager") {
     // Write data with manager, recover with new manager and verify
     val dataToWrite = generateRandomData()
-    writeDataUsingManager(testDir, dataToWrite)
+    writeDataUsingWriteAheadLog(testDir, dataToWrite)
     val logFiles = getLogFilesInDirectory(testDir)
     assert(logFiles.size > 1)
-    val readData = readDataUsingManager(testDir)
+    val readData = readDataUsingWriteAheadLog(testDir)
     assert(dataToWrite === readData)
   }
 
-  test("WriteAheadLogManager - cleanup old logs") {
+  test("FileBasedWriteAheadLog - clean old logs") {
     logCleanUpTest(waitForCompletion = false)
   }
 
-  test("WriteAheadLogManager - cleanup old logs synchronously") {
+  test("FileBasedWriteAheadLog - clean old logs synchronously") {
     logCleanUpTest(waitForCompletion = true)
   }
 
@@ -193,11 +242,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     // Write data with manager, recover with new manager and verify
     val manualClock = new ManualClock
     val dataToWrite = generateRandomData()
-    manager = writeDataUsingManager(testDir, dataToWrite, manualClock, stopManager = false)
+    writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false)
     val logFiles = getLogFilesInDirectory(testDir)
     assert(logFiles.size > 1)
 
-    manager.cleanupOldLogs(manualClock.getTimeMillis() / 2, waitForCompletion)
+    writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion)
 
     if (waitForCompletion) {
       assert(getLogFilesInDirectory(testDir).size < logFiles.size)
@@ -208,11 +257,11 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     }
   }
 
-  test("WriteAheadLogManager - handling file errors while reading rotating logs") {
+  test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") {
     // Generate a set of log files
     val manualClock = new ManualClock
     val dataToWrite1 = generateRandomData()
-    writeDataUsingManager(testDir, dataToWrite1, manualClock)
+    writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock)
     val logFiles1 = getLogFilesInDirectory(testDir)
     assert(logFiles1.size > 1)
 
@@ -220,12 +269,12 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     // Recover old files and generate a second set of log files
     val dataToWrite2 = generateRandomData()
     manualClock.advance(100000)
-    writeDataUsingManager(testDir, dataToWrite2, manualClock)
+    writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock)
     val logFiles2 = getLogFilesInDirectory(testDir)
     assert(logFiles2.size > logFiles1.size)
 
     // Read the files and verify that all the written data can be read
-    val readData1 = readDataUsingManager(testDir)
+    val readData1 = readDataUsingWriteAheadLog(testDir)
     assert(readData1 === (dataToWrite1 ++ dataToWrite2))
 
     // Corrupt the first set of files so that they are basically unreadable
@@ -236,25 +285,51 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
     }
 
     // Verify that the corrupted files do not prevent reading of the second set of data
-    val readData = readDataUsingManager(testDir)
+    val readData = readDataUsingWriteAheadLog(testDir)
     assert(readData === dataToWrite2)
   }
+
+  test("FileBasedWriteAheadLog - do not create directories or files unless write") {
+    val nonexistentTempPath = File.createTempFile("test", "")
+    nonexistentTempPath.delete()
+    assert(!nonexistentTempPath.exists())
+
+    val writtenSegment = writeDataManually(generateRandomData(), testFile)
+    val wal = new FileBasedWriteAheadLog(
+      new SparkConf(), tempDir.getAbsolutePath, new Configuration(), 1, 1)
+    assert(!nonexistentTempPath.exists(), "Directory created just by creating log object")
+    wal.read(writtenSegment.head)
+    assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment")
+  }
 }
 
 object WriteAheadLogSuite {
 
+  class MockWriteAheadLog0() extends WriteAheadLog {
+    override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null }
+    override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null }
+    override def readAll(): util.Iterator[ByteBuffer] = { null }
+    override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { }
+    override def close(): Unit = { }
+  }
+
+  class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0()
+
+  class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0()
+
+
   private val hadoopConf = new Configuration()
 
   /** Write data to a file directly and return an array of the file segments written. */
-  def writeDataManually(data: Seq[String], file: String): Seq[WriteAheadLogFileSegment] = {
-    val segments = new ArrayBuffer[WriteAheadLogFileSegment]()
+  def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = {
+    val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]()
     val writer = HdfsUtils.getOutputStream(file, hadoopConf)
     data.foreach { item =>
       val offset = writer.getPos
       val bytes = Utils.serialize(item)
       writer.writeInt(bytes.size)
       writer.write(bytes)
-      segments += WriteAheadLogFileSegment(file, offset, bytes.size)
+      segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size)
     }
     writer.close()
     segments
@@ -263,8 +338,11 @@ object WriteAheadLogSuite {
   /**
    * Write data to a file using the writer class and return an array of the file segments written.
    */
-  def writeDataUsingWriter(filePath: String, data: Seq[String]): Seq[WriteAheadLogFileSegment] = {
-    val writer = new WriteAheadLogWriter(filePath, hadoopConf)
+  def writeDataUsingWriter(
+      filePath: String,
+      data: Seq[String]
+    ): Seq[FileBasedWriteAheadLogSegment] = {
+    val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf)
     val segments = data.map {
       item => writer.write(item)
     }
@@ -272,27 +350,27 @@ object WriteAheadLogSuite {
     segments
   }
 
-  /** Write data to rotating files in log directory using the manager class. */
-  def writeDataUsingManager(
+  /** Write data to rotating files in log directory using the WriteAheadLog class. */
+  def writeDataUsingWriteAheadLog(
       logDirectory: String,
       data: Seq[String],
       manualClock: ManualClock = new ManualClock,
-      stopManager: Boolean = true
-    ): WriteAheadLogManager = {
+      closeLog: Boolean = true
+    ): FileBasedWriteAheadLog = {
     if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000)
-    val manager = new WriteAheadLogManager(logDirectory, hadoopConf,
-      rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock)
+    val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1)
+    
     // Ensure that 500 does not get sorted after 2000, so put a high base value.
     data.foreach { item =>
       manualClock.advance(500)
-      manager.writeToLog(item)
+      wal.write(item, manualClock.getTimeMillis())
     }
-    if (stopManager) manager.stop()
-    manager
+    if (closeLog) wal.close()
+    wal
   }
 
   /** Read data from a segments of a log file directly and return the list of byte buffers. */
-  def readDataManually(segments: Seq[WriteAheadLogFileSegment]): Seq[String] = {
+  def readDataManually(segments: Seq[FileBasedWriteAheadLogSegment]): Seq[String] = {
     segments.map { segment =>
       val reader = HdfsUtils.getInputStream(segment.path, hadoopConf)
       try {
@@ -331,18 +409,18 @@ object WriteAheadLogSuite {
 
   /** Read all the data from a log file using reader class and return the list of byte buffers. */
   def readDataUsingReader(file: String): Seq[String] = {
-    val reader = new WriteAheadLogReader(file, hadoopConf)
+    val reader = new FileBasedWriteAheadLogReader(file, hadoopConf)
     val readData = reader.toList.map(byteBufferToString)
     reader.close()
     readData
   }
 
-  /** Read all the data in the log file in a directory using the manager class. */
-  def readDataUsingManager(logDirectory: String): Seq[String] = {
-    val manager = new WriteAheadLogManager(logDirectory, hadoopConf,
-      callerName = "WriteAheadLogSuite")
-    val data = manager.readFromLog().map(byteBufferToString).toSeq
-    manager.stop()
+  /** Read all the data in the log file in a directory using the WriteAheadLog class. */
+  def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = {
+    import scala.collection.JavaConversions._
+    val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1)
+    val data = wal.readAll().map(byteBufferToString).toSeq
+    wal.close()
     data
   }
 
-- 
GitLab