Skip to content
Snippets Groups Projects
Commit c5be7d2b authored by Ankur Dave's avatar Ankur Dave
Browse files

Update Bagel unit tests to reflect API change

parent 9e4c79a4
No related branches found
No related tags found
No related merge requests found
......@@ -10,45 +10,43 @@ import scala.collection.mutable.ArrayBuffer
import spark._
import spark.bagel.Bagel._
class TestVertex(val id: String, val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message with Serializable
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
class BagelSuite extends FunSuite with Assertions {
test("halting by voting") {
val sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, true, 0))))
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
val result =
Bagel.run(sc, verts, msgs)()(addAggregatorArg {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
(new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
})
for (vert <- result.collect)
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect)
assert(vert.age === numSupersteps)
sc.stop()
}
test("halting by message silence") {
val sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, false, 0))))
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
val result =
Bagel.run(sc, verts, msgs)()(addAggregatorArg {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
case Some(ms) if (superstep < numSupersteps - 1) =>
ms
case _ =>
new ArrayBuffer[TestMessage]()
}
(new TestVertex(self.id, self.active, self.age + 1), msgsOut)
})
for (vert <- result.collect)
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
case Some(ms) if (superstep < numSupersteps - 1) =>
ms
case _ =>
Array[TestMessage]()
}
(new TestVertex(self.active, self.age + 1), msgsOut)
}
for ((id, vert) <- result.collect)
assert(vert.age === numSupersteps)
sc.stop()
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment