Skip to content
Snippets Groups Projects
Commit ea085371 authored by Matei Zaharia
Browse files

Fixed an exponential recursion that could happen with doCheckpoint due to lack of memoization
parent da8afbc7
No related branches found
No related tags found
No related merge requests found
package spark.bagel package spark.bagel
import org.scalatest.{FunSuite, Assertions, BeforeAndAfter} import org.scalatest.{FunSuite, Assertions, BeforeAndAfter}
import org.scalatest.prop.Checkers import org.scalatest.concurrent.Timeouts
import org.scalacheck.Arbitrary._ import org.scalatest.time.SpanSugar._
import org.scalacheck.Gen
import org.scalacheck.Prop._
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
...@@ -13,7 +11,7 @@ import spark._ ...@@ -13,7 +11,7 @@ import spark._
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable class TestMessage(val targetId: String) extends Message[String] with Serializable
class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _ var sc: SparkContext = _
...@@ -25,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { ...@@ -25,7 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port") System.clearProperty("spark.driver.port")
} }
test("halting by voting") { test("halting by voting") {
sc = new SparkContext("local", "test") sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0)))) val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
...@@ -36,8 +34,9 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { ...@@ -36,8 +34,9 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
} }
for ((id, vert) <- result.collect) for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps) assert(vert.age === numSupersteps)
}
} }
test("halting by message silence") { test("halting by message silence") {
...@@ -57,7 +56,27 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter { ...@@ -57,7 +56,27 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
} }
(new TestVertex(self.active, self.age + 1), msgsOut) (new TestVertex(self.active, self.age + 1), msgsOut)
} }
for ((id, vert) <- result.collect) for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps) assert(vert.age === numSupersteps)
}
}
test("large number of iterations") {
// This tests whether jobs with a large number of iterations finish in a reasonable time,
// because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
failAfter(10 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 50
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
} }
} }
...@@ -636,16 +636,22 @@ abstract class RDD[T: ClassManifest]( ...@@ -636,16 +636,22 @@ abstract class RDD[T: ClassManifest](
/** The [[spark.SparkContext]] that this RDD was created on. */ /** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc def context = sc
// Avoid handling doCheckpoint multiple times to prevent excessive recursion
private var doCheckpointCalled = false
/** /**
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and * after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs. * potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/ */
private[spark] def doCheckpoint() { private[spark] def doCheckpoint() {
if (checkpointData.isDefined) { if (!doCheckpointCalled) {
checkpointData.get.doCheckpoint() doCheckpointCalled = true
} else { if (checkpointData.isDefined) {
dependencies.foreach(_.rdd.doCheckpoint()) checkpointData.get.doCheckpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment