From 9d3d9c8251724712590f3178e69e78ea0b750e9c Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Fri, 10 Jan 2014 11:44:02 +0000
Subject: [PATCH] Refactored graph checkpoint file reading and writing code to
 make it cleaner and easier to debug.

---
 .../apache/spark/streaming/Checkpoint.scala   | 158 ++++++++++++------
 .../streaming/dstream/FileInputDStream.scala  |   1 +
 2 files changed, 108 insertions(+), 51 deletions(-)

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index d268b68f90..7366d8a7a4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -53,6 +53,58 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
   }
 }
 
+private[streaming]
+object Checkpoint extends Logging {
+  val PREFIX = "checkpoint-"
+  val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r
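+  // Matches names like "checkpoint-1389350400000" and "checkpoint-1389350400000.bk"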
+
+  /** Get the checkpoint file for the given checkpoint time */
+  def checkpointFile(checkpointDir: String, checkpointTime: Time) = {
+    new Path(checkpointDir, PREFIX + checkpointTime.milliseconds)
+  }
+
+  /** Get the checkpoint backup file for the given checkpoint time */
+  def checkpointBackupFile(checkpointDir: String, checkpointTime: Time) = {
+    new Path(checkpointDir, PREFIX + checkpointTime.milliseconds + ".bk")
+  }
+
+  /** Get the checkpoint files present in the given directory, ordered oldest-first */
+  def getCheckpointFiles(checkpointDir: String, fs: FileSystem): Seq[Path] = {
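+    // Sort earlier checkpoint times first; for the same time, the backup
+    // (.bk) file goes before the regular checkpoint file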
+    def sortFunc(path1: Path, path2: Path): Boolean = {
+      val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
+      val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) }
+      logInfo("Path 1: " + path1 + " -> " + time1 + ", " + bk1)
+      logInfo("Path 2: " + path2 + " -> " + time2 + ", " + bk2)
+      val precede = (time1 < time2) || (time1 == time2 && bk1) 
+      logInfo(precede.toString)
+      precede
+    }
+
+    val path = new Path(checkpointDir)
+    if (fs.exists(path)) {
+      val statuses = fs.listStatus(path)
+      if (statuses != null) {
+        val paths = statuses.map(_.getPath)
+        logInfo("Paths = " + paths.map(_.getName).mkString(", "))
+        val filtered = paths.filter(p => REGEX.findFirstIn(p.toString).nonEmpty)
+        logInfo("Filtered paths = " + filtered.map(_.getName).mkString(", "))
+        val sorted = filtered.sortWith(sortFunc)
+        logInfo("Sorted paths = " + sorted.map(_.getName).mkString(", "))
+        sorted
+      } else {
+        logWarning("Listing " + path + " returned null")
+        Seq.empty
+      }
+    } else {
+      logInfo("Checkpoint directory " + path + " does not exist")
+      Seq.empty
+    }
+  }
+}
+
 
 /**
  * Convenience class to handle the writing of graph checkpoint to file
@@ -60,14 +112,9 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
 private[streaming]
 class CheckpointWriter(jobGenerator: JobGenerator, conf: SparkConf, checkpointDir: String, hadoopConf: Configuration)
   extends Logging {
-  val file = new Path(checkpointDir, "graph")
   val MAX_ATTEMPTS = 3
   val executor = Executors.newFixedThreadPool(1)
   val compressionCodec = CompressionCodec.createCodec(conf)
-  // The file to which we actually write - and then "move" to file
-  val writeFile = new Path(file.getParent, file.getName + ".next")
-  // The file to which existing checkpoint is backed up (i.e. "moved")
-  val bakFile = new Path(file.getParent, file.getName + ".bk")
 
   private var stopped = false
   private var fs_ : FileSystem = _
@@ -78,40 +125,61 @@ class CheckpointWriter(jobGenerator: JobGenerator, conf: SparkConf, checkpointDi
     def run() {
       var attempts = 0
       val startTime = System.currentTimeMillis()
+      val tempFile = new Path(checkpointDir, "temp")
+      val checkpointFile = Checkpoint.checkpointFile(checkpointDir, checkpointTime)
+      val backupFile = Checkpoint.checkpointBackupFile(checkpointDir, checkpointTime)
+
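+      // Write to a temp file first, back up any existing checkpoint, then
+      // rename the temp file into place, so a failed write never corrupts
+      // the current checkpoint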
       while (attempts < MAX_ATTEMPTS && !stopped) {
         attempts += 1
         try {
-          logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + file + "'")
-          // This is inherently thread unsafe, so alleviating it by writing to '.next' and
-          // then moving it to the final file
-          val fos = fs.create(writeFile)
+          logInfo("Saving checkpoint for time " + checkpointTime + " to file '" + checkpointFile + "'")
+
+          // Write checkpoint to temp file
+          fs.delete(tempFile, true)   // just in case it exists
+          val fos = fs.create(tempFile)
           fos.write(bytes)
           fos.close()
 
-          // Back up existing checkpoint if it exists
-          if (fs.exists(file) && fs.rename(file, bakFile)) {
-            logDebug("Moved existing checkpoint file to " + bakFile)
+          // If the checkpoint file exists, back it up
+          // If the backup file also exists, delete it first, otherwise the rename will fail
+          if (fs.exists(checkpointFile)) {
+            fs.delete(backupFile, true) // just in case it exists
+            if (!fs.rename(checkpointFile, backupFile)) {
+              logWarning("Could not rename " + checkpointFile + " to " + backupFile)
+            }
+          }
+
+          // Rename temp file to the final checkpoint file
+          if (!fs.rename(tempFile, checkpointFile)) {
+            logWarning("Could not rename " + tempFile + " to " + checkpointFile)
           }
-          fs.delete(file, false) // paranoia
-
-          // Rename temp written file to the right location
-          if (fs.rename(writeFile, file)) {
-            val finishTime = System.currentTimeMillis()
-            logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file +
-              "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms")
-            jobGenerator.onCheckpointCompletion(checkpointTime)
-          } else {
-            throw new SparkException("Failed to rename checkpoint file from "
-              + writeFile + " to " + file)
+
+          // Delete old checkpoint files
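+          // (keep only the 10 most recent checkpoint and backup files)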
+          val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs)
+          if (allCheckpointFiles.size > 10) {
+            allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
+              logInfo("Deleting " + file)
+              fs.delete(file, true)
+            })
           }
+
+          // All done, log success and notify the job generator
+          val finishTime = System.currentTimeMillis()
+          logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile +
+            "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms")
+          jobGenerator.onCheckpointCompletion(checkpointTime)
           return
         } catch {
           case ioe: IOException =>
-            logWarning("Error writing checkpoint to file in " + attempts + " attempts", ioe)
+            logWarning("Error in attempt " + attempts + " of writing checkpoint to " + checkpointFile, ioe)
             reset()
         }
       }
-      logError("Could not write checkpoint for time " + checkpointTime + " to file '" + file + "'")
+      logWarning("Could not write checkpoint for time " + checkpointTime + " to file " + checkpointFile + "'")
     }
   }
 
@@ -147,7 +215,7 @@ class CheckpointWriter(jobGenerator: JobGenerator, conf: SparkConf, checkpointDi
   }
 
   private def fs = synchronized {
-    if (fs_ == null) fs_ = file.getFileSystem(hadoopConf)
+    if (fs_ == null) fs_ = new Path(checkpointDir).getFileSystem(hadoopConf)
     fs_
   }
 
@@ -160,36 +228,22 @@ class CheckpointWriter(jobGenerator: JobGenerator, conf: SparkConf, checkpointDi
 private[streaming]
 object CheckpointReader extends Logging {
 
-  private val graphFileNames = Seq("graph", "graph.bk")
-  
   def read(checkpointDir: String, conf: SparkConf, hadoopConf: Configuration): Option[Checkpoint] = {
     val checkpointPath = new Path(checkpointDir)
     def fs = checkpointPath.getFileSystem(hadoopConf)
     
-        // See if the checkpoint directory exists
-    if (!fs.exists(checkpointPath)) {
-      logInfo("Could not load checkpoint as path '" + checkpointPath + "' does not exist")
+    // Try to find the checkpoint files
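+    // getCheckpointFiles returns oldest-first, so reverse to try the newest file first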
+    val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse
+    if (checkpointFiles.isEmpty) {
       return None
     }
 
-    // Try to find the checkpoint data
-    val existingFiles = graphFileNames.map(new Path(checkpointPath, _)).filter(fs.exists)
-    if (existingFiles.isEmpty) {
-      logInfo("Could not load checkpoint as checkpoint data was not " +
-        "found in directory " + checkpointDir + "")
-      val statuses = fs.listStatus(checkpointPath)
-      if (statuses!=null) {
-        logInfo("Checkpoint directory " + checkpointDir + " contains the files:\n" +
-          statuses.mkString("\n"))
-      }
-      return None
-    }
-    logInfo("Checkpoint files found: " + existingFiles.mkString(","))
-
-    // Try to read the checkpoint data
+    // Try to read the checkpoint files in order, newest first
+    logInfo("Checkpoint files found: " + checkpointFiles.mkString(","))
     val compressionCodec = CompressionCodec.createCodec(conf)
-    existingFiles.foreach(file => {
-      logInfo("Attempting to load checkpoint from file '" + file + "'")
+    checkpointFiles.foreach(file => {
+      logInfo("Attempting to load checkpoint from file " + file)
       try {
         val fis = fs.open(file)
         // ObjectInputStream uses the last defined user-defined class loader in the stack
@@ -204,15 +254,17 @@ object CheckpointReader extends Logging {
         ois.close()
         fs.close()
         cp.validate()
-        logInfo("Checkpoint successfully loaded from file '" + file + "'")
+        logInfo("Checkpoint successfully loaded from file " + file)
         logInfo("Checkpoint was generated at time " + cp.checkpointTime)
         return Some(cp)
       } catch {
         case e: Exception =>
-          logWarning("Error reading checkpoint from file '" + file + "'", e)
+          logWarning("Error reading checkpoint from file " + file, e)
       }
     })
-    throw new SparkException("Failed to read checkpoint from directory '" + checkpointDir + "'")
+
+    // If none of the checkpoint files could be read, throw an exception
+    throw new SparkException("Failed to read checkpoint from directory " + checkpointPath)
   }
 }
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 4585e3f6bd..a79fe523a6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -113,6 +113,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas
     val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString)
     val timeTaken = System.currentTimeMillis - lastNewFileFindingTime
     logInfo("Finding new files took " + timeTaken + " ms")
+    logDebug("# cached file times = " + fileModTimes.size)
     if (timeTaken > slideDuration.milliseconds) {
       logWarning(
         "Time taken to find new files exceeds the batch size. " +
-- 
GitLab