Commit 32f74111 authored by Josh Rosen, committed by Yin Huai

[SPARK-13021][CORE] Fail fast when custom RDDs violate RDD.partition's API contract

Spark's `Partition` and `RDD.partitions` APIs have a contract which requires custom implementations of `RDD.partitions` to ensure that for all `x`, `rdd.partitions(x).index == x`; in other words, the `index` reported by a partition must match its position in the partitions array.
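For illustration, here is a minimal sketch of a conforming implementation (the `IdentityPartition` and `ContractRDD` names are hypothetical, not part of this patch): element `i` of the returned array reports `index == i`.

```scala
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical example (not part of this patch): a Partition whose `index`
// matches its position in the partitions array.
class IdentityPartition(override val index: Int) extends Partition

class ContractRDD(sc: SparkContext, numPartitions: Int) extends RDD[Int](sc, Nil) {

  // Element i of the array reports index == i, so the contract holds:
  // rdd.partitions.zipWithIndex.forall { case (p, i) => p.index == i }
  override protected def getPartitions: Array[Partition] =
    Array.tabulate[Partition](numPartitions)(i => new IdentityPartition(i))

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator.single(split.index)
}
```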

If a custom RDD implementation violates this contract, then Spark can become stuck in an infinite recomputation loop when recomputing a subset of an RDD's partitions, because the tasks that actually run will not correspond to the missing output partitions that triggered the recomputation. Here is a notebook that demonstrates the problem: https://rawgit.com/JoshRosen/e520fb9a64c1c97ec985/raw/5e8a5aa8d2a18910a1607f0aa4190104adda3424/Violating%2520RDD.partitions%2520contract.html

To guard against this infinite-loop behavior, this patch modifies Spark so that it fails fast and refuses to compute RDDs whose `partitions` violate the API contract.
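As a hedged sketch of the new behavior (reusing the `BadRDD` helper from the test added below; the exact exception text is approximate):

```scala
// BadRDD reverses its parent's partitions array, so partitions(0) reports index 99.
val rdd = new BadRDD(sc.parallelize(1 to 100, 100))
rdd.partitions
// => java.lang.IllegalArgumentException: requirement failed:
//    partitions(0).index == 99, but it should equal 0
```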

Author: Josh Rosen <joshrosen@databricks.com>

Closes #10932 from JoshRosen/SPARK-13021.
parent 87abcf7d
@@ -112,6 +112,9 @@ abstract class RDD[T: ClassTag](
   /**
    * Implemented by subclasses to return the set of partitions in this RDD. This method will only
    * be called once, so it is safe to implement a time-consuming computation in it.
+   *
+   * The partitions in this array must satisfy the following property:
+   *   `rdd.partitions.zipWithIndex.forall { case (partition, index) => partition.index == index }`
    */
   protected def getPartitions: Array[Partition]
 
@@ -237,6 +240,10 @@ abstract class RDD[T: ClassTag](
     checkpointRDD.map(_.partitions).getOrElse {
       if (partitions_ == null) {
         partitions_ = getPartitions
+        partitions_.zipWithIndex.foreach { case (partition, index) =>
+          require(partition.index == index,
+            s"partitions($index).index == ${partition.index}, but it should equal $index")
+        }
       }
       partitions_
     }
@@ -914,6 +914,24 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("RDD.partitions() fails fast when partition indices are incorrect (SPARK-13021)") {
+    class BadRDD[T: ClassTag](prev: RDD[T]) extends RDD[T](prev) {
+
+      override def compute(part: Partition, context: TaskContext): Iterator[T] = {
+        prev.compute(part, context)
+      }
+
+      override protected def getPartitions: Array[Partition] = {
+        prev.partitions.reverse // breaks contract, which is that `rdd.partitions(i).index == i`
+      }
+    }
+    val rdd = new BadRDD(sc.parallelize(1 to 100, 100))
+    val e = intercept[IllegalArgumentException] {
+      rdd.partitions
+    }
+    assert(e.getMessage.contains("partitions"))
+  }
+
   test("nested RDDs are not supported (SPARK-5063)") {
     val rdd: RDD[Int] = sc.parallelize(1 to 100)
     val rdd2: RDD[Int] = sc.parallelize(1 to 100)