From 0028caf3a4727623f70e23cd2f611f9797d0a3d3 Mon Sep 17 00:00:00 2001
From: Ankur Dave <ankurdave@gmail.com>
Date: Sun, 9 Oct 2011 15:58:39 -0700
Subject: [PATCH] Simplify and genericize type parameters in Bagel

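Bagel previously fixed the vertex ID type to String and passed
outgoing messages around as Iterable[M]. This patch:

- Introduces a key type parameter K for vertex IDs: Message is now
  Message[K] with targetId: K, and the id field is dropped from
  Vertex (IDs now live only in the RDD keys).
- Replaces the default-argument, implicit-conversion API with
  explicit overloads of run for the common cases (no aggregator,
  default HashPartitioner, default combiner) and accepts an explicit
  Partitioner.
- Replaces NullAggregator with Option[Aggregator[V, A]] and makes the
  agg, comp, and addAggregatorArg helpers private.
- Uses Array[M] for outgoing messages and for DefaultCombiner's
  combined message groups.
- Rewrites the recursive superstep loop in run as a do/while loop;
  run now returns the final vertices keyed by ID, RDD[(K, V)].
- Removes the Edge trait.
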
---
 bagel/src/main/scala/spark/bagel/Bagel.scala | 214 +++++++++++--------
 1 file changed, 129 insertions(+), 85 deletions(-)

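For reference, a minimal sketch of the simplest new run overload
(TestVertex and TestMessage are illustrative user-defined classes,
not part of this patch):

    import spark._
    import spark.bagel._

    class TestVertex(val active: Boolean, val age: Int)
      extends Vertex with Serializable
    class TestMessage(val targetId: String)
      extends Message[String] with Serializable

    val sc = new SparkContext("local", "example")
    val verts = sc.parallelize(Array("a", "b")).map(
      id => (id, new TestVertex(true, 0)))
    val msgs = sc.parallelize(Array[(String, TestMessage)]())

    // No messages are ever sent; each vertex stays active until
    // superstep 10, after which run detects no activity and returns
    // the final vertex RDD.
    val result: RDD[(String, TestVertex)] =
      Bagel.run(sc, verts, msgs, 2) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          (new TestVertex(superstep < 10, self.age + 1), Array[TestMessage]())
      }
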
diff --git a/bagel/src/main/scala/spark/bagel/Bagel.scala b/bagel/src/main/scala/spark/bagel/Bagel.scala
index c24c65be2a..2f57c9c0fd 100644
--- a/bagel/src/main/scala/spark/bagel/Bagel.scala
+++ b/bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -6,54 +6,110 @@ import spark.SparkContext._
 import scala.collection.mutable.ArrayBuffer
 
 object Bagel extends Logging {
-  def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
+  def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+          C : Manifest, A : Manifest](
     sc: SparkContext,
-    verts: RDD[(String, V)],
-    msgs: RDD[(String, M)]
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    aggregator: Option[Aggregator[V, A]],
+    partitioner: Partitioner,
+    numSplits: Int
   )(
-    combiner: Combiner[M, C] = new DefaultCombiner[M],
-    aggregator: Aggregator[V, A] = new NullAggregator[V],
-    superstep: Int = 0,
-    numSplits: Int = sc.defaultParallelism
-  )(
-    compute: (V, Option[C], A, Int) => (V, Iterable[M])
-  ): RDD[V] = {
-
-    logInfo("Starting superstep "+superstep+".")
-    val startTime = System.currentTimeMillis
-
-    val aggregated = agg(verts, aggregator)
-    val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
-    val grouped = verts.groupWith(combinedMsgs)
-    val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
-
-    val timeTaken = System.currentTimeMillis - startTime
-    logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
-
-    // Check stopping condition and iterate
-    val noActivity = numMsgs == 0 && numActiveVerts == 0
-    if (noActivity) {
-      processed.map { case (id, (vert, msgs)) => vert }
-    } else {
-      val newVerts = processed.mapValues { case (vert, msgs) => vert }
-      val newMsgs = processed.flatMap {
+    compute: (V, Option[C], Option[A], Int) => (V, Array[M])
+  ): RDD[(K, V)] = {
+    val splits = if (numSplits != 0) numSplits else sc.defaultParallelism
+
+    var superstep = 0
+    var verts = vertices
+    var msgs = messages
+    var noActivity = false
+    do {
+      logInfo("Starting superstep "+superstep+".")
+      val startTime = System.currentTimeMillis
+
+      val aggregated = agg(verts, aggregator)
+      val combinedMsgs = msgs.combineByKey(
+        combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners,
+        splits, partitioner)
+      val grouped = combinedMsgs.groupWith(verts)
+      val (processed, numMsgs, numActiveVerts) =
+        comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+
+      val timeTaken = System.currentTimeMillis - startTime
+      logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+      verts = processed.mapValues { case (vert, msgs) => vert }
+      msgs = processed.flatMap {
         case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
       }
-      run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
-    }
+      superstep += 1
+
+      noActivity = numMsgs == 0 && numActiveVerts == 0
+    } while (!noActivity)
+
+    verts
+  }
+
+  def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+          C : Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    partitioner: Partitioner,
+    numSplits: Int
+  )(
+    compute: (V, Option[C], Int) => (V, Array[M])
+  ): RDD[(K, V)] = {
+    run[K, V, M, C, Nothing](
+      sc, vertices, messages, combiner, None, partitioner, numSplits)(
+      addAggregatorArg[K, V, M, C](compute))
+  }
+
+  def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+          C : Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    combiner: Combiner[M, C],
+    numSplits: Int
+  )(
+    compute: (V, Option[C], Int) => (V, Array[M])
+  ): RDD[(K, V)] = {
+    val part = new HashPartitioner(numSplits)
+    run[K, V, M, C, Nothing](
+      sc, vertices, messages, combiner, None, part, numSplits)(
+      addAggregatorArg[K, V, M, C](compute))
+  }
+
+  def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
+    sc: SparkContext,
+    vertices: RDD[(K, V)],
+    messages: RDD[(K, M)],
+    numSplits: Int
+  )(
+    compute: (V, Option[Array[M]], Int) => (V, Array[M])
+  ): RDD[(K, V)] = {
+    val part = new HashPartitioner(numSplits)
+    run[K, V, M, Array[M], Nothing](
+      sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)(
+      addAggregatorArg[K, V, M, Array[M]](compute))
   }
 
   /**
-   * Aggregates the given vertices using the given aggregator, or does
-   * nothing if it is a NullAggregator.
+   * Aggregates the given vertices using the given aggregator, if it
+   * is specified.
    */
-  def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
-    case _: NullAggregator[_] =>
-      None
-    case _ =>
-      verts.map {
-        case (id, vert) => aggregator.createAggregator(vert)
-      }.reduce(aggregator.mergeAggregators(_, _))
+  private def agg[K, V <: Vertex, A : Manifest](
+    verts: RDD[(K, V)],
+    aggregator: Option[Aggregator[V, A]]
+  ): Option[A] = aggregator match {
+    case Some(a) =>
+      Some(verts.map {
+        case (id, vert) => a.createAggregator(vert)
+      }.reduce(a.mergeAggregators(_, _)))
+    case None => None
   }
 
   /**
@@ -61,23 +117,27 @@ object Bagel extends Logging {
    * function. Returns the processed RDD, the number of messages
    * created, and the number of active vertices.
    */
-  def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
+  private def comp[K : Manifest, V <: Vertex, M <: Message[K], C](
+    sc: SparkContext,
+    grouped: RDD[(K, (Seq[C], Seq[V]))],
+    compute: (V, Option[C]) => (V, Array[M])
+  ): (RDD[(K, (V, Array[M]))], Int, Int) = {
     var numMsgs = sc.accumulator(0)
     var numActiveVerts = sc.accumulator(0)
     val processed = grouped.flatMapValues {
-      case (Seq(), _) => None
-      case (Seq(v), c) =>
-          val (newVert, newMsgs) =
-            compute(v, c match {
-              case Seq(comb) => Some(comb)
-              case Seq() => None
-            })
-
-          numMsgs += newMsgs.size
-          if (newVert.active)
-            numActiveVerts += 1
-
-          Some((newVert, newMsgs))
+      case (_, vs) if vs.size == 0 => None
+      case (c, vs) =>
+        val (newVert, newMsgs) =
+          compute(vs(0), c match {
+            case Seq(comb) => Some(comb)
+            case Seq() => None
+          })
+
+        numMsgs += newMsgs.size
+        if (newVert.active)
+          numActiveVerts += 1
+
+        Some((newVert, newMsgs))
     }.cache
 
     // Force evaluation of processed RDD for accurate performance measurements
@@ -90,16 +150,16 @@ object Bagel extends Logging {
    * Converts a compute function that doesn't take an aggregator to
    * one that does, so it can be passed to Bagel.run.
    */
-  implicit def addAggregatorArg[
-    V <: Vertex : Manifest, M <: Message : Manifest, C
+  private def addAggregatorArg[
+    K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C
   ](
-    compute: (V, Option[C], Int) => (V, Iterable[M])
-  ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
-    (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
+    compute: (V, Option[C], Int) => (V, Array[M])
+  ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
+    (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
+      compute(vert, msgs, superstep)
   }
 }
 
-// TODO: Simplify Combiner interface and make it more OO.
 trait Combiner[M, C] {
   def createCombiner(msg: M): C
   def mergeMsg(combiner: C, msg: M): C
@@ -111,18 +171,13 @@ trait Aggregator[V, A] {
   def mergeAggregators(a: A, b: A): A
 }
 
-class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] with Serializable {
-  def createCombiner(msg: M): ArrayBuffer[M] =
-    ArrayBuffer(msg)
-  def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
-    combiner += msg
-  def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
-    a ++= b
-}
-
-class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable {
-  def createAggregator(vert: V): Option[Nothing] = None
-  def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
+class DefaultCombiner[M : Manifest] extends Combiner[M, Array[M]] with Serializable {
+  def createCombiner(msg: M): Array[M] =
+    Array(msg)
+  def mergeMsg(combiner: Array[M], msg: M): Array[M] =
+    combiner :+ msg
+  def mergeCombiners(a: Array[M], b: Array[M]): Array[M] =
+    a ++ b
 }
 
 /**
@@ -132,7 +187,6 @@ class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable
  * inherit from java.io.Serializable or scala.Serializable.
  */
 trait Vertex {
-  def id: String
   def active: Boolean
 }
 
@@ -142,16 +196,6 @@ trait Vertex {
  * Subclasses may contain a payload to deliver to the target vertex
  * and must inherit from java.io.Serializable or scala.Serializable.
  */
-trait Message {
-  def targetId: String
-}
-
-/**
- * Represents a directed edge between two vertices.
- *
- * Subclasses may store state along each edge and must inherit from
- * java.io.Serializable or scala.Serializable.
- */
-trait Edge {
-  def targetId: String
+trait Message[K] {
+  def targetId: K
 }
-- 
GitLab