From 32f741115bda5d7d7dbfcd9fe827ecbea7303ffa Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@databricks.com>
Date: Wed, 27 Jan 2016 13:27:32 -0800
Subject: [PATCH] [SPARK-13021][CORE] Fail fast when custom RDDs violate
 RDD.partition's API contract

Spark's `Partition` and `RDD.partitions` APIs have a contract which requires custom implementations of `RDD.partitions` to ensure that for all `x`, `rdd.partitions(x).index == x`; in other words, the `index` reported by a repartition needs to match its position in the partitions array.

If a custom RDD implementation violates this contract, then Spark has the potential to become stuck in an infinite recomputation loop when recomputing a subset of an RDD's partitions, since the tasks that are actually run will not correspond to the missing output partitions that triggered the recomputation. Here's a link to a notebook which demonstrates this problem: https://rawgit.com/JoshRosen/e520fb9a64c1c97ec985/raw/5e8a5aa8d2a18910a1607f0aa4190104adda3424/Violating%2520RDD.partitions%2520contract.html

In order to guard against this infinite loop behavior, this patch modifies Spark so that it fails fast and refuses to compute RDDs' whose `partitions` violate the API contract.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #10932 from JoshRosen/SPARK-13021.
---
 .../main/scala/org/apache/spark/rdd/RDD.scala  |  7 +++++++
 .../scala/org/apache/spark/rdd/RDDSuite.scala  | 18 ++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9dad794414..be47172581 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -112,6 +112,9 @@ abstract class RDD[T: ClassTag](
   /**
    * Implemented by subclasses to return the set of partitions in this RDD. This method will only
    * be called once, so it is safe to implement a time-consuming computation in it.
+   *
+   * The partitions in this array must satisfy the following property:
+   *   `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
    */
   protected def getPartitions: Array[Partition]
 
@@ -237,6 +240,10 @@ abstract class RDD[T: ClassTag](
     checkpointRDD.map(_.partitions).getOrElse {
       if (partitions_ == null) {
         partitions_ = getPartitions
+        partitions_.zipWithIndex.foreach { case (partition, index) =>
+          require(partition.index == index,
+            s"partitions($index).partition == ${partition.index}, but it should equal $index")
+        }
       }
       partitions_
     }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ef2ed44500..80347b800a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -914,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("RDD.partitions() fails fast when partitions indicies are incorrect (SPARK-13021)") {
+    class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) {
+
+      override def compute(part: Partition, context: TaskContext): Iterator[T] = {
+        prev.compute(part, context)
+      }
+
+      override protected def getPartitions: Array[Partition] = {
+        prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i`
+      }
+    }
+    val rdd = new BadRDD(sc.parallelize(1 to 100, 100))
+    val e = intercept[IllegalArgumentException] {
+      rdd.partitions
+    }
+    assert(e.getMessage.contains("partitions"))
+  }
+
   test("nested RDDs are not supported (SPARK-5063)") {
     val rdd: RDD[Int] = sc.parallelize(1 to 100)
     val rdd2: RDD[Int] = sc.parallelize(1 to 100)
-- 
GitLab