Skip to content
Snippets Groups Projects
Commit c18fa3eb authored by Ankur Dave's avatar Ankur Dave
Browse files

Package combiner functions into a trait

parent 1c8ca0eb
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ package bagel ...@@ -2,7 +2,7 @@ package bagel
import spark._ import spark._
import spark.SparkContext._ import spark.SparkContext._
import scala.collection.mutable.HashMap
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging { object Pregel extends Logging {
...@@ -24,9 +24,7 @@ object Pregel extends Logging { ...@@ -24,9 +24,7 @@ object Pregel extends Logging {
sc: SparkContext, sc: SparkContext,
verts: RDD[(String, V)], verts: RDD[(String, V)],
msgs: RDD[(String, M)], msgs: RDD[(String, M)],
createCombiner: M => C, combiner: Combiner[M, C],
mergeMsg: (C, M) => C,
mergeCombiners: (C, C) => C,
numSplits: Int, numSplits: Int,
superstep: Int = 0 superstep: Int = 0
)(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = { )(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
...@@ -35,7 +33,7 @@ object Pregel extends Logging { ...@@ -35,7 +33,7 @@ object Pregel extends Logging {
val startTime = System.currentTimeMillis val startTime = System.currentTimeMillis
// Bring together vertices and messages // Bring together vertices and messages
val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits) val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
val grouped = verts.groupWith(combinedMsgs) val grouped = verts.groupWith(combinedMsgs)
// Run compute on each vertex // Run compute on each vertex
...@@ -72,17 +70,24 @@ object Pregel extends Logging { ...@@ -72,17 +70,24 @@ object Pregel extends Logging {
val newMsgs = processed.flatMap { val newMsgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
} }
run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute) run(sc, newVerts, newMsgs, combiner, numSplits, superstep + 1)(compute)
} }
} }
}
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
def mergeCombiners(a: C, b: C): C
}
def defaultCreateCombiner[M <: Message](msg: M): ArrayBuffer[M] = ArrayBuffer(msg) @serializable class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
def defaultMergeMsg[M <: Message](combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = def createCombiner(msg: M): ArrayBuffer[M] =
ArrayBuffer(msg)
def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
combiner += msg combiner += msg
def defaultMergeCombiners[M <: Message](a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
a ++= b a ++= b
def defaultCompute[V <: Vertex, M <: Message](self: V, msgs: Option[ArrayBuffer[M]], superstep: Int): (V, Iterable[M]) =
(self, List())
} }
/** /**
......
...@@ -49,12 +49,7 @@ object ShortestPath { ...@@ -49,12 +49,7 @@ object ShortestPath {
messages.count()+" messages.") messages.count()+" messages.")
// Do the computation // Do the computation
def createCombiner(message: SPMessage): Int = message.value val result = Pregel.run(sc, vertices, messages, MinCombiner, numSplits) {
def mergeMsg(combiner: Int, message: SPMessage): Int =
min(combiner, message.value)
def mergeCombiners(a: Int, b: Int): Int = min(a, b)
val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) => (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
val newValue = messageMinValue match { val newValue = messageMinValue match {
case Some(minVal) => min(self.value, minVal) case Some(minVal) => min(self.value, minVal)
...@@ -82,6 +77,15 @@ object ShortestPath { ...@@ -82,6 +77,15 @@ object ShortestPath {
} }
} }
object MinCombiner extends Combiner[SPMessage, Int] {
def createCombiner(msg: SPMessage): Int =
msg.value
def mergeMsg(combiner: Int, msg: SPMessage): Int =
min(combiner, msg.value)
def mergeCombiners(a: Int, b: Int): Int =
min(a, b)
}
@serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex @serializable class SPVertex(val id: String, val value: Int, val outEdges: Seq[SPEdge], val active: Boolean) extends Vertex
@serializable class SPEdge(val targetId: String, val value: Int) extends Edge @serializable class SPEdge(val targetId: String, val value: Int) extends Edge
@serializable class SPMessage(val targetId: String, val value: Int) extends Message @serializable class SPMessage(val targetId: String, val value: Int) extends Message
...@@ -60,9 +60,9 @@ object WikipediaPageRank { ...@@ -60,9 +60,9 @@ object WikipediaPageRank {
val messages = sc.parallelize(List[(String, PRMessage)]()) val messages = sc.parallelize(List[(String, PRMessage)]())
val result = val result =
if (noCombiner) { if (noCombiner) {
Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon)) Pregel.run(sc, vertices, messages, PRNoCombiner, numSplits)(PRNoCombiner.compute(numVertices, epsilon))
} else { } else {
Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon)) Pregel.run(sc, vertices, messages, PRCombiner, numSplits)(PRCombiner.compute(numVertices, epsilon))
} }
// Print the result // Print the result
...@@ -71,53 +71,44 @@ object WikipediaPageRank { ...@@ -71,53 +71,44 @@ object WikipediaPageRank {
"%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
println(top) println(top)
} }
}
object Combiner { object PRCombiner extends Combiner[PRMessage, Double] {
def createCombiner(message: PRMessage): Double = message.value def createCombiner(msg: PRMessage): Double =
msg.value
def mergeMsg(combiner: Double, message: PRMessage): Double = def mergeMsg(combiner: Double, msg: PRMessage): Double =
combiner + message.value combiner + msg.value
def mergeCombiners(a: Double, b: Double): Double =
def mergeCombiners(a: Double, b: Double) = a + b a + b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = { def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
val newValue = messageSum match { val newValue = messageSum match {
case Some(msgSum) if msgSum != 0 => case Some(msgSum) if msgSum != 0 =>
0.15 / numVertices + 0.85 * msgSum 0.15 / numVertices + 0.85 * msgSum
case _ => self.value case _ => self.value
}
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
val outbox =
if (!terminate)
self.outEdges.map(edge =>
new PRMessage(edge.targetId, newValue / self.outEdges.size))
else
ArrayBuffer[PRMessage]()
(new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
} }
}
object NoCombiner { val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] =
ArrayBuffer(message)
def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] = val outbox =
combiner += message if (!terminate)
self.outEdges.map(edge =>
new PRMessage(edge.targetId, newValue / self.outEdges.size))
else
ArrayBuffer[PRMessage]()
def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] = (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
a ++= b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
Combiner.compute(numVertices, epsilon)(self, messages match {
case Some(msgs) => Some(msgs.map(_.value).sum)
case None => None
}, superstep)
} }
} }
object PRNoCombiner extends DefaultCombiner[PRMessage] {
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
PRCombiner.compute(numVertices, epsilon)(self, messages match {
case Some(msgs) => Some(msgs.map(_.value).sum)
case None => None
}, superstep)
}
@serializable class PRVertex() extends Vertex { @serializable class PRVertex() extends Vertex {
var id: String = _ var id: String = _
var value: Double = _ var value: Double = _
......
...@@ -20,10 +20,7 @@ class BagelSuite extends FunSuite with Assertions { ...@@ -20,10 +20,7 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array[(String, TestMessage)]()) val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5 val numSupersteps = 5
val result = val result =
Pregel.run(sc, verts, msgs, Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
Pregel.defaultCreateCombiner[TestMessage],
Pregel.defaultMergeMsg[TestMessage],
Pregel.defaultMergeCombiners[TestMessage], 1) {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
(new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]()) (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
} }
...@@ -37,10 +34,7 @@ class BagelSuite extends FunSuite with Assertions { ...@@ -37,10 +34,7 @@ class BagelSuite extends FunSuite with Assertions {
val msgs = sc.parallelize(Array("a" -> new TestMessage("a"))) val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5 val numSupersteps = 5
val result = val result =
Pregel.run(sc, verts, msgs, Pregel.run(sc, verts, msgs, new DefaultCombiner[TestMessage], 1) {
Pregel.defaultCreateCombiner[TestMessage],
Pregel.defaultMergeMsg[TestMessage],
Pregel.defaultMergeCombiners[TestMessage], 1) {
(self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) => (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
val msgsOut = val msgsOut =
msgs match { msgs match {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment