From 0b7a132d03d5a0106d85a8cca1ab28d6af9c8b55 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Wed, 8 Jan 2014 03:22:06 -0800
Subject: [PATCH] Modified checkpoint file clearing policy.

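Checkpoint files of a DStream are now cleared only after the checkpoint that no longer
references them has been written successfully. The CheckpointWriter notifies the JobGenerator
through onCheckpointCompletion(), which sends a ClearCheckpointData event; handling that event
deletes, for each DStream, all checkpoint files older than the oldest file still referenced by
the completed checkpoint. Also renamed clearOldMetadata to clearMetadata and moved the clearing
of checkpoint data out of updateCheckpointData into a separate clearCheckpointData call.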
---
 .../apache/spark/streaming/Checkpoint.scala   |  7 ++-
 .../org/apache/spark/streaming/DStream.scala  | 15 +++--
 .../streaming/DStreamCheckpointData.scala     | 63 +++++++++++++------
 .../apache/spark/streaming/DStreamGraph.scala | 30 +++++----
 .../streaming/dstream/FileInputDStream.scala  |  8 +--
 .../streaming/scheduler/JobGenerator.scala    | 23 +++++--
 .../spark/streaming/CheckpointSuite.scala     | 10 +--
 7 files changed, 104 insertions(+), 52 deletions(-)

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index 09b184b9cf..155d5bc02e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -28,6 +28,7 @@ import org.apache.spark.{SparkException, Logging}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.util.MetadataCleaner
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.streaming.scheduler.JobGenerator
 
 
 private[streaming]
@@ -58,7 +59,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
  * Convenience class to handle the writing of graph checkpoint to file
  */
 private[streaming]
-class CheckpointWriter(checkpointDir: String, hadoopConf: Configuration) extends Logging {
+class CheckpointWriter(jobGenerator: JobGenerator, checkpointDir: String, hadoopConf: Configuration) extends Logging {
   val file = new Path(checkpointDir, "graph")
   val MAX_ATTEMPTS = 3
   val executor = Executors.newFixedThreadPool(1)
@@ -80,7 +81,7 @@ class CheckpointWriter(checkpointDir: String, hadoopConf: Configuration) extends
       while (attempts < MAX_ATTEMPTS && !stopped) {
         attempts += 1
         try {
-          logDebug("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
+          logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
           // This is inherently thread unsafe, so alleviating it by writing to '.new' and
           // then moving it to the final file
           val fos = fs.create(writeFile)
@@ -96,6 +97,7 @@ class CheckpointWriter(checkpointDir: String, hadoopConf: Configuration) extends
           val finishTime = System.currentTimeMillis()
           logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
             "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds")
+          jobGenerator.onCheckpointCompletion(checkpointTime)
           return
         } catch {
           case ioe: IOException =>
@@ -116,6 +118,7 @@ class CheckpointWriter(checkpointDir: String, hadoopConf: Configuration) extends
     bos.close()
     try {
       executor.execute(new CheckpointWriteHandler(checkpoint.checkpointTime, bos.toByteArray))
+      logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
     } catch {
       case rej: RejectedExecutionException =>
         logError("Could not submit checkpoint task to the thread pool executor", rej)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index a78d3965ee..20074249d7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -331,13 +331,12 @@ abstract class DStream[T: ClassTag] (
    * implementation clears the old generated RDDs. Subclasses of DStream may override
    * this to clear their own metadata along with the generated RDDs.
    */
-  protected[streaming] def clearOldMetadata(time: Time) {
-    var numForgotten = 0
+  protected[streaming] def clearMetadata(time: Time) {
     val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration))
     generatedRDDs --= oldRDDs.keys
     logInfo("Cleared " + oldRDDs.size + " RDDs that were older than " +
       (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", "))
-    dependencies.foreach(_.clearOldMetadata(time))
+    dependencies.foreach(_.clearMetadata(time))
   }
 
   /* Adds metadata to the Stream while it is running.
@@ -358,12 +357,18 @@ abstract class DStream[T: ClassTag] (
    */
   protected[streaming] def updateCheckpointData(currentTime: Time) {
     logInfo("Updating checkpoint data for time " + currentTime)
-    checkpointData.update()
+    checkpointData.update(currentTime)
     dependencies.foreach(_.updateCheckpointData(currentTime))
-    checkpointData.cleanup()
     logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData)
   }
 
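+  /**
+   * Clear the checkpoint data of this DStream and its dependencies for the given time. This is
+   * called by the DStreamGraph once the checkpoint for that time has been written, so that
+   * checkpoint files that are no longer needed can be deleted.
+   */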
+  protected[streaming] def clearCheckpointData(time: Time) {
+    logInfo("Clearing checkpoint data")
+    checkpointData.cleanup(time)
+    dependencies.foreach(_.clearCheckpointData(time))
+    logInfo("Cleared checkpoint data")
+  }
+
   /**
    * Restore the RDDs in generatedRDDs from the checkpointData. This is an internal method
    * that should not be called directly. This is a default implementation that recreates RDDs
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
index 3fd5d52403..cc2f08a7d1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamCheckpointData.scala
@@ -17,15 +17,16 @@
 
 package org.apache.spark.streaming
 
+import scala.collection.mutable.{HashMap, HashSet}
+import scala.reflect.ClassTag
+
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.conf.Configuration
 
-import collection.mutable.HashMap
 import org.apache.spark.Logging
 
-import scala.collection.mutable.HashMap
-import scala.reflect.ClassTag
+import java.io.{ObjectInputStream, IOException}
 
 
 private[streaming]
@@ -33,35 +34,35 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
   extends Serializable with Logging {
   protected val data = new HashMap[Time, AnyRef]()
 
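+  // All checkpoint files of this DStream written so far, and for each checkpoint time, the
+  // time of the oldest checkpoint file still referenced by that checkpoint. Both are used by
+  // `cleanup` to decide which checkpoint files can be deleted.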
+  @transient private var allCheckpointFiles = new HashMap[Time, String]
+  @transient private var timeToLastCheckpointFileTime = new HashMap[Time, Time]
   @transient private var fileSystem : FileSystem = null
-  @transient private var lastCheckpointFiles: HashMap[Time, String] = null
 
-  protected[streaming] def checkpointFiles = data.asInstanceOf[HashMap[Time, String]]
+  protected[streaming] def currentCheckpointFiles = data.asInstanceOf[HashMap[Time, String]]
 
   /**
    * Updates the checkpoint data of the DStream. This gets called every time
    * the graph checkpoint is initiated. Default implementation records the
    * checkpoint files to which the generate RDDs of the DStream has been saved.
    */
-  def update() {
+  def update(time: Time) {
 
     // Get the checkpointed RDDs from the generated RDDs
-    val newCheckpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
+    val checkpointFiles = dstream.generatedRDDs.filter(_._2.getCheckpointFile.isDefined)
                                        .map(x => (x._1, x._2.getCheckpointFile.get))
 
-    // Make a copy of the existing checkpoint data (checkpointed RDDs)
-    lastCheckpointFiles = checkpointFiles.clone()
 
     // If the new checkpoint data has checkpoints then replace existing with the new one
-    if (newCheckpointFiles.size > 0) {
-      checkpointFiles.clear()
-      checkpointFiles ++= newCheckpointFiles
-    }
-
-    // TODO: remove this, this is just for debugging
-    newCheckpointFiles.foreach {
-      case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
+    if (checkpointFiles.size > 0) {
+      currentCheckpointFiles.clear()
+      currentCheckpointFiles ++= checkpointFiles
+      allCheckpointFiles ++= currentCheckpointFiles
+      timeToLastCheckpointFileTime(time) = currentCheckpointFiles.keys.min(Time.ordering)
     }
   }
 
   /**
@@ -69,7 +70,8 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
    * checkpoint is initiated, but after `update` is called. Default
    * implementation, cleans up old checkpoint files.
    */
-  def cleanup() {
+  def cleanup(time: Time) {
+    /*
     // If there is at least on checkpoint file in the current checkpoint files,
     // then delete the old checkpoint files.
     if (checkpointFiles.size > 0 && lastCheckpointFiles != null) {
@@ -89,6 +91,23 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
         }
       }
     }
+    */
+    // Delete all checkpoint files that are older than the oldest checkpoint file still
+    // referenced by the checkpoint written for the given time. If no RDD of this DStream
+    // has been checkpointed yet, there is nothing to delete.
+    timeToLastCheckpointFileTime.remove(time) match {
+      case Some(lastCheckpointFileTime) =>
+        allCheckpointFiles.filter(_._1 < lastCheckpointFileTime).foreach {
+          case (fileTime, file) =>
+            try {
+              val path = new Path(file)
+              if (fileSystem == null) {
+                fileSystem = path.getFileSystem(dstream.ssc.sparkContext.hadoopConfiguration)
+              }
+              fileSystem.delete(path, true)
+              allCheckpointFiles -= fileTime
+              logInfo("Deleted checkpoint file '" + file + "' for time " + fileTime)
+            } catch {
+              case e: Exception =>
+                logWarning("Error deleting old checkpoint file '" + file + "' for time " + fileTime, e)
+            }
+        }
+      case None =>
+        logDebug("Nothing to delete for time " + time)
+    }
   }
 
   /**
@@ -98,7 +117,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
    */
   def restore() {
     // Create RDDs from the checkpoint data
-    checkpointFiles.foreach {
+    currentCheckpointFiles.foreach {
       case(time, file) => {
         logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'")
         dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file)))
@@ -107,6 +126,12 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T])
   }
 
   override def toString() = {
-    "[\n" + checkpointFiles.size + " checkpoint files \n" + checkpointFiles.mkString("\n") + "\n]"
+    "[\n" + currentCheckpointFiles.size + " checkpoint files \n" + currentCheckpointFiles.mkString("\n") + "\n]"
+  }
+
+  @throws(classOf[IOException])
+  private def readObject(ois: ObjectInputStream) {
+    ois.defaultReadObject()
+    timeToLastCheckpointFileTime = new HashMap[Time, Time]
+    allCheckpointFiles = new HashMap[Time, String]
   }
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index daed7ff7c3..bfedef2e4e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -105,36 +105,44 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
   def getOutputStreams() = this.synchronized { outputStreams.toArray }
 
   def generateJobs(time: Time): Seq[Job] = {
+    logInfo("Generating jobs for time " + time)
     this.synchronized {
-      logInfo("Generating jobs for time " + time)
       val jobs = outputStreams.flatMap(outputStream => outputStream.generateJob(time))
       logInfo("Generated " + jobs.length + " jobs for time " + time)
       jobs
     }
   }
 
-  def clearOldMetadata(time: Time) {
+  def clearMetadata(time: Time) {
+    logInfo("Clearing metadata for time " + time)
     this.synchronized {
-      logInfo("Clearing old metadata for time " + time)
-      outputStreams.foreach(_.clearOldMetadata(time))
-      logInfo("Cleared old metadata for time " + time)
+      outputStreams.foreach(_.clearMetadata(time))
     }
+    logInfo("Cleared old metadata for time " + time)
   }
 
   def updateCheckpointData(time: Time) {
+    logInfo("Updating checkpoint data for time " + time)
     this.synchronized {
-      logInfo("Updating checkpoint data for time " + time)
       outputStreams.foreach(_.updateCheckpointData(time))
-      logInfo("Updated checkpoint data for time " + time)
     }
+    logInfo("Updated checkpoint data for time " + time)
+  }
+
+  def clearCheckpointData(time: Time) {
+    logInfo("Restoring checkpoint data")
+    this.synchronized {
+      outputStreams.foreach(_.clearCheckpointData(time))
+    }
+    logInfo("Restored checkpoint data")
   }
 
   def restoreCheckpointData() {
+    logInfo("Restoring checkpoint data")
     this.synchronized {
-      logInfo("Restoring checkpoint data")
       outputStreams.foreach(_.restoreCheckpointData())
-      logInfo("Restored checkpoint data")
     }
+    logInfo("Restored checkpoint data")
   }
 
   def validate() {
@@ -147,8 +155,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream) {
+    logDebug("DStreamGraph.writeObject used")
     this.synchronized {
-      logDebug("DStreamGraph.writeObject used")
       checkpointInProgress = true
       oos.defaultWriteObject()
       checkpointInProgress = false
@@ -157,8 +165,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
 
   @throws(classOf[IOException])
   private def readObject(ois: ObjectInputStream) {
+    logDebug("DStreamGraph.readObject used")
     this.synchronized {
-      logDebug("DStreamGraph.readObject used")
       checkpointInProgress = true
       ois.defaultReadObject()
       checkpointInProgress = false
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 0028422db9..4585e3f6bd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -90,8 +90,8 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
   }
 
   /** Clear the old time-to-files mappings along with old RDDs */
-  protected[streaming] override def clearOldMetadata(time: Time) {
-    super.clearOldMetadata(time)
+  protected[streaming] override def clearMetadata(time: Time) {
+    super.clearMetadata(time)
     val oldFiles = files.filter(_._1 <= (time - rememberDuration))
     files --= oldFiles.keys
     logInfo("Cleared " + oldFiles.size + " old files that were older than " +
@@ -172,12 +172,12 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
 
     def hadoopFiles = data.asInstanceOf[HashMap[Time, Array[String]]]
 
-    override def update() {
+    override def update(time: Time) {
       hadoopFiles.clear()
       hadoopFiles ++= files
     }
 
-    override def cleanup() { }
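+    // This checkpoint data does not write any files of its own, so there is nothing to
+    // clean up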
+    override def cleanup(time: Time) { }
 
     override def restore() {
       hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 5f48692df8..6fbe6da921 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -26,8 +26,9 @@ import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
 /** Event classes for JobGenerator */
 private[scheduler] sealed trait JobGeneratorEvent
 private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent
-private[scheduler] case class ClearOldMetadata(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearMetadata(time: Time) extends JobGeneratorEvent
 private[scheduler] case class DoCheckpoint(time: Time) extends JobGeneratorEvent
+private[scheduler] case class ClearCheckpointData(time: Time) extends JobGeneratorEvent
 
 /**
  * This class generates jobs from DStreams as well as drives checkpointing and cleaning
@@ -55,7 +56,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
     longTime => eventProcessorActor ! GenerateJobs(new Time(longTime)))
   lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
-    new CheckpointWriter(ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
+    new CheckpointWriter(this, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
   } else {
     null
   }
@@ -79,15 +80,20 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
    * On batch completion, clear old metadata and checkpoint computation.
    */
   private[scheduler] def onBatchCompletion(time: Time) {
-    eventProcessorActor ! ClearOldMetadata(time)
+    eventProcessorActor ! ClearMetadata(time)
+  }
+
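+  /**
+   * On checkpoint completion (i.e. after the checkpoint for the given time has been written),
+   * clear the checkpoint data that is no longer needed.
+   */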
+  private[streaming] def onCheckpointCompletion(time: Time) {
+    eventProcessorActor ! ClearCheckpointData(time)
   }
 
   /** Processes all events */
   private def processEvent(event: JobGeneratorEvent) {
     event match {
       case GenerateJobs(time) => generateJobs(time)
-      case ClearOldMetadata(time) => clearOldMetadata(time)
+      case ClearMetadata(time) => clearMetadata(time)
       case DoCheckpoint(time) => doCheckpoint(time)
+      case ClearCheckpointData(time) => clearCheckpointData(time)
     }
   }
 
@@ -143,11 +149,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
   }
 
   /** Clear DStream metadata for the given `time`. */
-  private def clearOldMetadata(time: Time) {
-    ssc.graph.clearOldMetadata(time)
+  private def clearMetadata(time: Time) {
+    ssc.graph.clearMetadata(time)
     eventProcessorActor ! DoCheckpoint(time)
   }
 
+  /** Clear DStream checkpoint data for the given `time`. */
+  private def clearCheckpointData(time: Time) {
+    ssc.graph.clearCheckpointData(time)
+  }
+
   /** Perform checkpoint for the give `time`. */
   private def doCheckpoint(time: Time) = synchronized {
     if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 4e25c9566c..53bc24ff7a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -84,9 +84,9 @@ class CheckpointSuite extends TestSuiteBase {
     ssc.start()
     advanceTimeWithRealDelay(ssc, firstNumBatches)
     logInfo("Checkpoint data of state stream = \n" + stateStream.checkpointData)
-    assert(!stateStream.checkpointData.checkpointFiles.isEmpty,
+    assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
       "No checkpointed RDDs in state stream before first failure")
-    stateStream.checkpointData.checkpointFiles.foreach {
+    stateStream.checkpointData.currentCheckpointFiles.foreach {
       case (time, file) => {
         assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
             " for state stream before first failure does not exist")
@@ -95,7 +95,7 @@ class CheckpointSuite extends TestSuiteBase {
 
     // Run till a further time such that previous checkpoint files in the stream would be deleted
     // and check whether the earlier checkpoint files are deleted
-    val checkpointFiles = stateStream.checkpointData.checkpointFiles.map(x => new File(x._2))
+    val checkpointFiles = stateStream.checkpointData.currentCheckpointFiles.map(x => new File(x._2))
     advanceTimeWithRealDelay(ssc, secondNumBatches)
     checkpointFiles.foreach(file =>
       assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
@@ -114,9 +114,9 @@ class CheckpointSuite extends TestSuiteBase {
     // is present in the checkpoint data or not
     ssc.start()
     advanceTimeWithRealDelay(ssc, 1)
-    assert(!stateStream.checkpointData.checkpointFiles.isEmpty,
+    assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty,
       "No checkpointed RDDs in state stream before second failure")
-    stateStream.checkpointData.checkpointFiles.foreach {
+    stateStream.checkpointData.currentCheckpointFiles.foreach {
       case (time, file) => {
         assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time +
           " for state stream before seconds failure does not exist")
-- 
GitLab