Skip to content
Snippets Groups Projects
Commit 0028caf3 authored by Ankur Dave's avatar Ankur Dave
Browse files

Simplify and genericize type parameters in Bagel

parent 2d7057bf
No related branches found
No related tags found
No related merge requests found
...@@ -6,54 +6,110 @@ import spark.SparkContext._ ...@@ -6,54 +6,110 @@ import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.ArrayBuffer
object Bagel extends Logging { object Bagel extends Logging {
def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest]( def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
C : Manifest, A : Manifest](
sc: SparkContext, sc: SparkContext,
verts: RDD[(String, V)], vertices: RDD[(K, V)],
msgs: RDD[(String, M)] messages: RDD[(K, M)],
combiner: Combiner[M, C],
aggregator: Option[Aggregator[V, A]],
partitioner: Partitioner,
numSplits: Int
)( )(
combiner: Combiner[M, C] = new DefaultCombiner[M], compute: (V, Option[C], Option[A], Int) => (V, Array[M])
aggregator: Aggregator[V, A] = new NullAggregator[V], ): RDD[(K, V)] = {
superstep: Int = 0, val splits = if (numSplits != 0) numSplits else sc.defaultParallelism
numSplits: Int = sc.defaultParallelism
)( var superstep = 0
compute: (V, Option[C], A, Int) => (V, Iterable[M]) var verts = vertices
): RDD[V] = { var msgs = messages
var noActivity = false
logInfo("Starting superstep "+superstep+".") do {
val startTime = System.currentTimeMillis logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
val aggregated = agg(verts, aggregator)
val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits) val aggregated = agg(verts, aggregator)
val grouped = verts.groupWith(combinedMsgs) val combinedMsgs = msgs.combineByKey(
val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep)) combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners,
splits, partitioner)
val timeTaken = System.currentTimeMillis - startTime val grouped = combinedMsgs.groupWith(verts)
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) val (processed, numMsgs, numActiveVerts) =
comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
// Check stopping condition and iterate
val noActivity = numMsgs == 0 && numActiveVerts == 0 val timeTaken = System.currentTimeMillis - startTime
if (noActivity) { logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
processed.map { case (id, (vert, msgs)) => vert }
} else { verts = processed.mapValues { case (vert, msgs) => vert }
val newVerts = processed.mapValues { case (vert, msgs) => vert } msgs = processed.flatMap {
val newMsgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
} }
run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute) superstep += 1
}
noActivity = numMsgs == 0 && numActiveVerts == 0
} while (!noActivity)
verts
}
def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
C : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
numSplits: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, partitioner, numSplits)(
addAggregatorArg[K, V, M, C](compute))
}
def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
C : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
numSplits: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numSplits)
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, part, numSplits)(
addAggregatorArg[K, V, M, C](compute))
}
def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
numSplits: Int
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numSplits)
run[K, V, M, Array[M], Nothing](
sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)(
addAggregatorArg[K, V, M, Array[M]](compute))
} }
/** /**
* Aggregates the given vertices using the given aggregator, or does * Aggregates the given vertices using the given aggregator, if it
* nothing if it is a NullAggregator. * is specified.
*/ */
def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match { private def agg[K, V <: Vertex, A : Manifest](
case _: NullAggregator[_] => verts: RDD[(K, V)],
None aggregator: Option[Aggregator[V, A]]
case _ => ): Option[A] = aggregator match {
verts.map { case Some(a) =>
case (id, vert) => aggregator.createAggregator(vert) Some(verts.map {
}.reduce(aggregator.mergeAggregators(_, _)) case (id, vert) => a.createAggregator(vert)
}.reduce(a.mergeAggregators(_, _)))
case None => None
} }
/** /**
...@@ -61,23 +117,27 @@ object Bagel extends Logging { ...@@ -61,23 +117,27 @@ object Bagel extends Logging {
* function. Returns the processed RDD, the number of messages * function. Returns the processed RDD, the number of messages
* created, and the number of active vertices. * created, and the number of active vertices.
*/ */
def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = { private def comp[K : Manifest, V <: Vertex, M <: Message[K], C](
sc: SparkContext,
grouped: RDD[(K, (Seq[C], Seq[V]))],
compute: (V, Option[C]) => (V, Array[M])
): (RDD[(K, (V, Array[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0) var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0) var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues { val processed = grouped.flatMapValues {
case (Seq(), _) => None case (_, vs) if vs.size == 0 => None
case (Seq(v), c) => case (c, vs) =>
val (newVert, newMsgs) = val (newVert, newMsgs) =
compute(v, c match { compute(vs(0), c match {
case Seq(comb) => Some(comb) case Seq(comb) => Some(comb)
case Seq() => None case Seq() => None
}) })
numMsgs += newMsgs.size numMsgs += newMsgs.size
if (newVert.active) if (newVert.active)
numActiveVerts += 1 numActiveVerts += 1
Some((newVert, newMsgs)) Some((newVert, newMsgs))
}.cache }.cache
// Force evaluation of processed RDD for accurate performance measurements // Force evaluation of processed RDD for accurate performance measurements
...@@ -90,16 +150,16 @@ object Bagel extends Logging { ...@@ -90,16 +150,16 @@ object Bagel extends Logging {
* Converts a compute function that doesn't take an aggregator to * Converts a compute function that doesn't take an aggregator to
* one that does, so it can be passed to Bagel.run. * one that does, so it can be passed to Bagel.run.
*/ */
implicit def addAggregatorArg[ private def addAggregatorArg[
V <: Vertex : Manifest, M <: Message : Manifest, C K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C
]( ](
compute: (V, Option[C], Int) => (V, Iterable[M]) compute: (V, Option[C], Int) => (V, Array[M])
): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = { ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
(vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep) (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
compute(vert, msgs, superstep)
} }
} }
// TODO: Simplify Combiner interface and make it more OO.
trait Combiner[M, C] { trait Combiner[M, C] {
def createCombiner(msg: M): C def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C def mergeMsg(combiner: C, msg: M): C
...@@ -111,18 +171,13 @@ trait Aggregator[V, A] { ...@@ -111,18 +171,13 @@ trait Aggregator[V, A] {
def mergeAggregators(a: A, b: A): A def mergeAggregators(a: A, b: A): A
} }
class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] with Serializable { class DefaultCombiner[M : Manifest] extends Combiner[M, Array[M]] with Serializable {
def createCombiner(msg: M): ArrayBuffer[M] = def createCombiner(msg: M): Array[M] =
ArrayBuffer(msg) Array(msg)
def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] = def mergeMsg(combiner: Array[M], msg: M): Array[M] =
combiner += msg combiner :+ msg
def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] = def mergeCombiners(a: Array[M], b: Array[M]): Array[M] =
a ++= b a ++ b
}
class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable {
def createAggregator(vert: V): Option[Nothing] = None
def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
} }
/** /**
...@@ -132,7 +187,6 @@ class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable ...@@ -132,7 +187,6 @@ class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable
* inherit from java.io.Serializable or scala.Serializable. * inherit from java.io.Serializable or scala.Serializable.
*/ */
trait Vertex { trait Vertex {
def id: String
def active: Boolean def active: Boolean
} }
...@@ -142,16 +196,6 @@ trait Vertex { ...@@ -142,16 +196,6 @@ trait Vertex {
* Subclasses may contain a payload to deliver to the target vertex * Subclasses may contain a payload to deliver to the target vertex
* and must inherit from java.io.Serializable or scala.Serializable. * and must inherit from java.io.Serializable or scala.Serializable.
*/ */
trait Message { trait Message[K] {
def targetId: String def targetId: K
}
/**
* Represents a directed edge between two vertices.
*
* Subclasses may store state along each edge and must inherit from
* java.io.Serializable or scala.Serializable.
*/
trait Edge {
def targetId: String
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment