From e2047d3927e0032cc1d6de3fd09d00f96ce837d0 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Fri, 11 Oct 2013 13:04:45 -0700
Subject: [PATCH] Making takeAsync and collectAsync deterministic.

---
 .../scala/org/apache/spark/FutureAction.scala |  4 ----
 .../apache/spark/rdd/AsyncRDDActions.scala    | 20 ++++++++++---------
 .../spark/rdd/AsyncRDDActionsSuite.scala      | 10 ++++------
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 9f41912d6c..eab2957632 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -177,10 +177,6 @@ class CancellablePromise[T] extends FutureAction[T] with Promise[T] {
   def run(func: => T)(implicit executor: ExecutionContext): Unit = scala.concurrent.future {
     thread = Thread.currentThread
     try {
-      if (cancelled) {
-        // This action has been cancelled before this thread even started running.
-        this.failure(new SparkException("action cancelled"))
-      }
       this.success(func)
     } catch {
       case e: Exception => this.failure(e)
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 579832427e..32af795d4c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -54,9 +54,9 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
    * Return a future for retrieving all elements of this RDD.
    */
   def collectAsync(): FutureAction[Seq[T]] = {
-    val results = new ArrayBuffer[T]
+    val results = new Array[Array[T]](self.partitions.size)
     self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
-      (index, data) => results ++= data, results)
+      (index, data) => results(index) = data, results.flatten.toSeq)
   }
 
   /**
@@ -66,10 +66,10 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
     val promise = new CancellablePromise[Seq[T]]
 
     promise.run {
-      val buf = new ArrayBuffer[T](num)
+      val results = new ArrayBuffer[T](num)
       val totalParts = self.partitions.length
       var partsScanned = 0
-      while (buf.size < num && partsScanned < totalParts) {
+      while (results.size < num && partsScanned < totalParts) {
         // The number of partitions to try in this iteration. It is ok for this number to be
         // greater than totalParts because we actually cap it at totalParts in runJob.
         var numPartsToTry = 1
@@ -77,26 +77,28 @@ class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with
           // If we didn't find any rows after the first iteration, just try all partitions next.
           // Otherwise, interpolate the number of partitions we need to try, but overestimate it
           // by 50%.
-          if (buf.size == 0) {
+          if (results.size == 0) {
             numPartsToTry = totalParts - 1
           } else {
-            numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+            numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
           }
         }
         numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
 
-        val left = num - buf.size
+        val left = num - results.size
         val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
 
+        val buf = new Array[Array[T]](p.size)
         promise.runJob(self,
           (it: Iterator[T]) => it.take(left).toArray,
           p,
-          (index: Int, data: Array[T]) => buf ++= data.take(num - buf.size),
+          (index: Int, data: Array[T]) => buf(index) = data,
           Unit)
+        buf.foreach(results ++= _.take(num - results.size))
 
         partsScanned += numPartsToTry
       }
-      buf.toSeq
+      results.toSeq
     }
 
     promise.future
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 131e2466ac..3ef000da4a 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -53,8 +53,7 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
 
   test("collectAsync") {
     assert(zeroPartRdd.collectAsync().get() === Seq.empty)
-    // Note that we sort the collected output because the order is indeterministic.
-    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get().sorted
+    val collected = sc.parallelize(1 to 1000, 3).collectAsync().get()
     assert(collected === (1 to 1000))
   }
 
@@ -80,10 +79,9 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
 
   test("takeAsync") {
     def testTake(rdd: RDD[Int], input: Seq[Int], num: Int) {
-      // Note that we sort the collected output because the order is indeterministic.
-      val expected = input.take(num).size
-      val saw = rdd.takeAsync(num).get().size
-      assert(saw == expected, "incorrect result for rdd with %d partitions (expected %d, saw %d)"
+      val expected = input.take(num)
+      val saw = rdd.takeAsync(num).get()
+      assert(saw == expected, "incorrect result for rdd with %d partitions (expected %s, saw %s)"
         .format(rdd.partitions.size, expected, saw))
     }
     val input = Range(1, 1000)
--
GitLab
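
For context on the pattern this patch applies: the old code appended each finished partition's rows to a shared buffer in task-completion order, so the result depended on which task finished first; the fix writes every task's output into a slot keyed by its partition index and flattens at the end. Below is a minimal, self-contained Scala sketch of that idea, not Spark code; the object name, the simulated partition data, and the sleep-based task timing are illustrative assumptions.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Random

// Hypothetical stand-in for Spark's (index, data) result handler in submitJob:
// each simulated "task" computes one partition's data and reports it together
// with its partition index.
object DeterministicCollectSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    // Partition i owns the elements i*10 until (i+1)*10.
    val partitionData = Array.tabulate(numPartitions)(i => (i * 10 until (i + 1) * 10).toArray)

    // One slot per partition: a task that finishes late cannot displace an
    // earlier partition's data, because it writes only to its own index.
    val slots = new Array[Array[Int]](numPartitions)

    val tasks = partitionData.indices.map { index =>
      Future {
        Thread.sleep(Random.nextInt(50)) // tasks complete in arbitrary order
        slots(index) = partitionData(index)
      }
    }
    Await.result(Future.sequence(tasks), 5.seconds)

    // Flattening slot by slot yields partition order regardless of timing.
    println(slots.flatten.toSeq) // always 0 through 39, in order
  }
}
```

The same idea appears twice in the patch: `collectAsync` writes each task's rows directly into `results(index)` and flattens once at the end, while `takeAsync` stages each scan round's partitions in an index-keyed `buf` and only then appends them to `results` in partition order, so the taken prefix is the same on every run.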