From bb46046154a438df4db30a0e1fd557bd3399ee7b Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Tue, 18 Nov 2014 16:25:44 -0800
Subject: [PATCH] [SPARK-4433] fix a race condition in zipWithIndex

Spark hangs with the following code:

~~~
sc.parallelize(1 to 10).zipWithIndex.repartition(10).count()
~~~

This is because `ZippedWithIndexRDD` triggers a job inside `getPartitions`, which causes a deadlock in `DAGScheduler.getPreferredLocs` (a synchronized method). The fix is to compute `startIndices` eagerly during construction instead.

This should be applied to branch-1.0, branch-1.1, and branch-1.2.

pwendell

Author: Xiangrui Meng <meng@databricks.com>

Closes #3291 from mengxr/SPARK-4433 and squashes the following commits:

c284d9f [Xiangrui Meng] fix a race condition in zipWithIndex
---
 .../apache/spark/rdd/ZippedWithIndexRDD.scala | 31 ++++++++++---------
 .../scala/org/apache/spark/rdd/RDDSuite.scala |  5 +++
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index e2c301603b..8c43a55940 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
 private[spark]
 class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
 
-  override def getPartitions: Array[Partition] = {
+  /** The start index of each partition. */
+  @transient private val startIndices: Array[Long] = {
     val n = prev.partitions.size
-    val startIndices: Array[Long] =
-      if (n == 0) {
-        Array[Long]()
-      } else if (n == 1) {
-        Array(0L)
-      } else {
-        prev.context.runJob(
-          prev,
-          Utils.getIteratorSize _,
-          0 until n - 1, // do not need to count the last partition
-          false
-        ).scanLeft(0L)(_ + _)
-      }
+    if (n == 0) {
+      Array[Long]()
+    } else if (n == 1) {
+      Array(0L)
+    } else {
+      prev.context.runJob(
+        prev,
+        Utils.getIteratorSize _,
+        0 until n - 1, // do not need to count the last partition
+        allowLocal = false
+      ).scanLeft(0L)(_ + _)
+    }
+  }
+
+  override def getPartitions: Array[Partition] = {
     firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
   }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d2e696dc2..e079ca3b1e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -739,6 +739,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("zipWithIndex chained with other RDDs (SPARK-4433)") {
+    val count = sc.parallelize(0 until 10, 2).zipWithIndex().repartition(4).count()
+    assert(count === 10)
+  }
+
   test("zipWithUniqueId") {
     val n = 10
     val data = sc.parallelize(0 until n, 3)
-- 
GitLab