diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
index 8feb42490da8002b2cb6eaa08e7d1731dbce8c50..acb9e3753f8e29194551c1b32f7c9d69683d3af4 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
@@ -167,8 +167,17 @@ object Analytics extends Logging {
-   * and return an RDD with the vertex value containing the
+   * and return a graph with the vertex value containing the
    * lowest vertex id in the connected component containing
    * that vertex.
+   *
+   * @tparam VD the vertex attribute type (discarded in the computation)
+   * @tparam ED the edge attribute type (preserved in the computation)
+   *
+   * @param graph the graph for which to compute the connected components
+   *
+   * @return a graph with vertex attributes containing the lowest vertex id
+   * in each connected component
    */
-  def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]) = {
+  def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]):
+    Graph[Vid, ED] = {
     val ccGraph = graph.mapVertices { case (vid, _) => vid }
 
     def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = {
@@ -179,21 +188,27 @@ object Analytics extends Logging {
     }
 
     val initialMessage = Long.MaxValue
-    Pregel(ccGraph, initialMessage)(
+    Pregel(ccGraph, initialMessage, EdgeDirection.Both)(
       (id, attr, msg) => math.min(attr, msg),
       sendMessage, 
       (a,b) => math.min(a,b)
       )
 
+    /*
+     * Originally this was implemented using the GraphLab abstraction, but
+     * with support for message computation along all edge directions the
+     * Pregel abstraction is sufficient.
+     */
     // GraphLab(ccGraph, gatherDirection = EdgeDirection.Both, scatterDirection = EdgeDirection.Both)(
     //   (me_id, edge) => edge.otherVertexAttr(me_id), // gather
     //   (a: Vid, b: Vid) => math.min(a, b), // merge
     //   (id, data, a: Option[Vid]) => math.min(data, a.getOrElse(Long.MaxValue)), // apply
     //   (me_id, edge) => (edge.vertexAttr(me_id) < edge.otherVertexAttr(me_id))
     // )
+  } // end of connectedComponents
 
-  }
   
+
   def main(args: Array[String]) = {
     val host = args(0)
     val taskType = args(1)
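
Reviewer note: a minimal usage sketch for the newly annotated return type, assuming a local SparkContext; the GraphGenerators import path and the example object name are assumptions for illustration, not part of this patch.

    import org.apache.spark.SparkContext
    import org.apache.spark.graph._
    import org.apache.spark.graph.util.GraphGenerators // assumed package for gridGraph

    object CCExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "cc-example")
        // Build a connected 10x10 grid, as in the tests below.
        val grid = GraphGenerators.gridGraph(sc, 10, 10)
        // connectedComponents now advertises its result type: Graph[Vid, ED].
        val cc = Analytics.connectedComponents(grid)
        // Vertex 0 is in the grid, so every vertex should end up labeled 0.
        assert(cc.vertices.map { case (_, ccId) => ccId }.collect.forall(_ == 0L))
        sc.stop()
      }
    }
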
diff --git a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
index f4a8c6b4c9f4d81da8f11394cbe030cd1fdb17fc..8d0b2e0b02b75475fa2b3cbe434784417f742359 100644
--- a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
@@ -79,6 +79,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
   } // end of test Star PageRank
 
 
+
   test("Grid PageRank") {
     withSpark(new SparkContext("local", "test")) { sc =>
       val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
@@ -104,4 +105,66 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
   } // end of Grid PageRank
 
 
+  test("Grid Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
+      val ccGraph = Analytics.connectedComponents(gridGraph).cache()
+      val ccIdSum = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
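+      // Component ids are non-negative and should all be 0, so their sum must be 0.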
+      assert(ccIdSum === 0)
+    }
+  } // end of Grid connected components
+
+
+  test("Reverse Grid Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse
+      val ccGraph = Analytics.connectedComponents(gridGraph).cache()
+      val ccIdSum = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum
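+      // Component ids are non-negative and should all be 0, so their sum must be 0.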
+      assert(ccIdSum === 0)
+    }
+  } // end of Reverse Grid connected components
+
+
+  test("Chain Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val chain1 = (0 until 9).map(x => (x, x + 1))
+      val chain2 = (10 until 19).map(x => (x, x + 1))
+      val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) }
+      val twoChains = Graph(rawEdges)
+      val ccGraph = Analytics.connectedComponents(twoChains).cache()
+      val vertices = ccGraph.vertices.collect
+      for ((id, cc) <- vertices) {
+        if (id < 10) { assert(cc === 0) }
+        else { assert(cc === 10) }
+      }
+      val ccMap = vertices.toMap
+      for (id <- 0 until 20) {
+        if (id < 10) { assert(ccMap(id) === 0) }
+        else { assert(ccMap(id) === 10) }
+      }
+    }
+  } // end of chain connected components
+
+  test("Reverse Chain Connected Components") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val chain1 = (0 until 9).map(x => (x, x + 1))
+      val chain2 = (10 until 19).map(x => (x, x + 1))
+      val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) }
+      val twoChains = Graph(rawEdges).reverse
+      val ccGraph = Analytics.connectedComponents(twoChains).cache()
+      val vertices = ccGraph.vertices.collect
+      for ((id, cc) <- vertices) {
+        if (id < 10) { assert(cc === 0) }
+        else { assert(cc === 10) }
+      }
+      val ccMap = vertices.toMap
+      for (id <- 0 until 20) {
+        if (id < 10) { assert(ccMap(id) === 0) }
+        else { assert(ccMap(id) === 10) }
+      }
+    }
+  } // end of reverse chain connected components
+
 } // end of AnalyticsSuite
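
Reviewer note: the expected labels in the chain tests follow from min-id propagation to a fixed point. The self-contained sketch below (plain Scala, no Spark; all names are illustrative) reproduces that computation for the two chains built above:

    object ChainCCSketch extends App {
      // Same two disjoint chains as the tests: vertices 0-9 and 10-19.
      val edges = (0 until 9).map(x => (x.toLong, x + 1L)) ++
                  (10 until 19).map(x => (x.toLong, x + 1L))
      val vertices = edges.flatMap { case (s, d) => Seq(s, d) }.distinct
      // Every vertex starts labeled with its own id, as in mapVertices above.
      var label = vertices.map(v => (v, v)).toMap
      var changed = true
      while (changed) { // iterate to a fixed point, like the Pregel loop
        changed = false
        for ((s, d) <- edges) {
          val m = math.min(label(s), label(d)) // min over both edge directions
          if (label(s) != m || label(d) != m) {
            label = label + (s -> m) + (d -> m)
            changed = true
          }
        }
      }
      assert((0L until 10L).forall(label(_) == 0L))   // first chain collapses to 0
      assert((10L until 20L).forall(label(_) == 10L)) // second chain collapses to 10
    }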