From d6a902f309e914560ec9a49ca539c47b928d107a Mon Sep 17 00:00:00 2001
From: "Joseph E. Gonzalez" <joseph.e.gonzalez@gmail.com>
Date: Mon, 28 Oct 2013 11:52:26 -0700
Subject: [PATCH] Finished updating connected components to use a Pregel-like
 abstraction and created a series of tests in the AnalyticsSuite.

---
 .../org/apache/spark/graph/Analytics.scala    | 21 +++++-
 .../apache/spark/graph/AnalyticsSuite.scala   | 65 +++++++++++++++++++
 2 files changed, 83 insertions(+), 3 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
index 8feb42490d..acb9e3753f 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
@@ -167,8 +167,17 @@ object Analytics extends Logging {
    * and return an RDD with the vertex value containing the
    * lowest vertex id in the connected component containing
    * that vertex.
+   *
+   * @tparam VD the vertex attribute type (discarded in the computation)
+   * @tparam ED the edge attribute type (preserved in the computation)
+   *
+   * @param graph the graph for which to compute the connected components
+   *
+   * @return a graph with vertex attributes containing the smallest vertex id
+   * in each connected component
    */
-  def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]) = {
+  def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]):
+    Graph[Vid, ED] = {
     val ccGraph = graph.mapVertices { case (vid, _) => vid }
 
     def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = {
@@ -179,21 +188,27 @@ object Analytics extends Logging {
     }
 
     val initialMessage = Long.MaxValue
-    Pregel(ccGraph, initialMessage)(
+    Pregel(ccGraph, initialMessage, EdgeDirection.Both)(
       (id, attr, msg) => math.min(attr, msg),
       sendMessage, 
       (a,b) => math.min(a,b)
       )
 
+    /*
+     * Originally this was implemented using the GraphLab abstraction, but with
+     * support for message computation along all edge directions the Pregel
+     * abstraction is sufficient.
+     */
     // GraphLab(ccGraph, gatherDirection = EdgeDirection.Both, scatterDirection = EdgeDirection.Both)(
     //   (me_id, edge) => edge.otherVertexAttr(me_id), // gather
     //   (a: Vid, b: Vid) => math.min(a, b), // merge
     //   (id, data, a: Option[Vid]) => math.min(data, a.getOrElse(Long.MaxValue)), // apply
     //   (me_id, edge) => (edge.vertexAttr(me_id) < edge.otherVertexAttr(me_id))
     // )
+  } // end of connectedComponents
 
-  }
   
+
   def main(args: Array[String]) = {
     val host = args(0)
     val taskType = args(1)
diff --git a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
index f4a8c6b4c9..8d0b2e0b02 100644
--- a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
@@ -79,6 +79,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
   } // end of test Star PageRank
 
 
+
   test("Grid PageRank") {
     withSpark(new SparkContext("local", "test")) { sc =>
       val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
@@ -104,4 +105,68 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
   } // end of Grid PageRank
 
 
+  test("Grid Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
+      val ccGraph = Analytics.connectedComponents(gridGraph).cache()
+      val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+      assert(maxCCid === 0)
+    }
+  } // end of Grid connected components
+
+
+  test("Reverse Grid Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
+      val ccGraph = Analytics.connectedComponents(gridGraph).cache()
+      val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
+      assert(maxCCid === 0)
+    }
+  } // end of Grid connected components
+
+
+  test("Chain Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val chain1 = (0 until 9).map(x => (x, x+1) )
+      val chain2 = (10 until 20).map(x => (x, x+1) )
+      val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+      val twoChains = Graph(rawEdges)
+      val ccGraph = Analytics.connectedComponents(twoChains).cache()
+      val vertices = ccGraph.vertices.collect
+      for ( (id, cc) <- vertices ) {
+        if(id < 10) { assert(cc === 0) }
+        else { assert(cc === 10) }
+      }
+      val ccMap = vertices.toMap
+      println(ccMap)
+      for( id <- 0 until 20 ) {
+        if(id < 10) { assert(ccMap(id) === 0) }
+        else { assert(ccMap(id) === 10) }
+      }
+    }
+  } // end of chain connected components
+
+  test("Reverse Chain Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val chain1 = (0 until 9).map(x => (x, x+1) )
+      val chain2 = (10 until 20).map(x => (x, x+1) )
+      val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) }
+      val twoChains = Graph(rawEdges).reverse
+      val ccGraph = Analytics.connectedComponents(twoChains).cache()
+      val vertices = ccGraph.vertices.collect
+      for ( (id, cc) <- vertices ) {
+        if(id < 10) { assert(cc === 0) }
+        else { assert(cc === 10) }
+      }
+      val ccMap = vertices.toMap
+      println(ccMap)
+      for( id <- 0 until 20 ) {
+        if(id < 10) { assert(ccMap(id) === 0) }
+        else { assert(ccMap(id) === 10) }
+      }
+    }
+  } // end of chain connected components
+
+
+
 } // end of AnalyticsSuite
-- 
GitLab