From 16baea62bce62987158acce0595a0916c25b32b2 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Sun, 10 Feb 2013 19:14:49 -0800
Subject: [PATCH] Fixed bug in CheckpointRDD to prevent exception when the
 original RDD had zero splits.

---
 core/src/main/scala/spark/rdd/CheckpointRDD.scala |  4 ++--
 core/src/test/scala/spark/CheckpointSuite.scala   | 10 ++++++++++
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
index 96b593ba7c..a21338f85f 100644
--- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -24,8 +24,8 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
     val dirContents = fs.listStatus(new Path(checkpointPath))
     val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
     val numSplits = splitFiles.size
-    if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
-        !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
+    if (numSplits > 0 && (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
+        !splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1)))) {
       throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
     }
     Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala
index 0b74607fb8..4425949f46 100644
--- a/core/src/test/scala/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/spark/CheckpointSuite.scala
@@ -162,6 +162,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
       rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
   }
 
+  test("CheckpointRDD with zero partitions") {  // regression: checkpoint an RDD with no splits
+    val rdd = new BlockRDD[Int](sc, Array[String]())  // empty String array => zero splits (asserted below)
+    assert(rdd.splits.size === 0)
+    assert(rdd.isCheckpointed === false)
+    rdd.checkpoint()
+    assert(rdd.count() === 0)  // first action after checkpoint(); isCheckpointed is checked right after
+    assert(rdd.isCheckpointed === true)
+    assert(rdd.splits.size === 0)  // split count is still expected to be zero after checkpointing
+  }
+
   /**
    * Test checkpointing of the final RDD generated by the given operation. By default,
    * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
-- 
GitLab