diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala index 8feb42490da8002b2cb6eaa08e7d1731dbce8c50..acb9e3753f8e29194551c1b32f7c9d69683d3af4 100644 --- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala +++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala @@ -167,8 +167,17 @@ object Analytics extends Logging { * and return an RDD with the vertex value containing the * lowest vertex id in the connected component containing * that vertex. + * + * @tparam VD the vertex attribute type (discarded in the computation) + * @tparam ED the edge attribute type (preserved in the computation) + * + * @param graph the graph for which to compute the connected components + * + * @return a graph with vertex attributes containing the smallest vertex id + * in each connected component */ - def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]) = { + def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]): + Graph[Vid, ED] = { val ccGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = { @@ -179,21 +188,27 @@ } val initialMessage = Long.MaxValue - Pregel(ccGraph, initialMessage)( + Pregel(ccGraph, initialMessage, EdgeDirection.Both)( (id, attr, msg) => math.min(attr, msg), sendMessage, (a,b) => math.min(a,b) ) + /** + * Originally this was implemented using the GraphLab abstraction, but with + * support for message computation along all edge directions, the Pregel + * abstraction is sufficient. + */ // GraphLab(ccGraph, gatherDirection = EdgeDirection.Both, scatterDirection = EdgeDirection.Both)( // (me_id, edge) => edge.otherVertexAttr(me_id), // gather // (a: Vid, b: Vid) => math.min(a, b), // merge // (id, data, a: Option[Vid]) => math.min(data, a.getOrElse(Long.MaxValue)), // apply // (me_id, edge) => (edge.vertexAttr(me_id) < 
edge.otherVertexAttr(me_id)) // ) + } // end of connectedComponents - } + def main(args: Array[String]) = { val host = args(0) val taskType = args(1) diff --git a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala index f4a8c6b4c9f4d81da8f11394cbe030cd1fdb17fc..8d0b2e0b02b75475fa2b3cbe434784417f742359 100644 --- a/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala +++ b/graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala @@ -79,6 +79,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext { } // end of test Star PageRank + test("Grid PageRank") { withSpark(new SparkContext("local", "test")) { sc => val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) @@ -104,4 +105,68 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext { } // end of Grid PageRank + test("Grid Connected Components") { + withSpark(new SparkContext("local", "test")) { sc => + val gridGraph = GraphGenerators.gridGraph(sc, 10, 10) + val ccGraph = Analytics.connectedComponents(gridGraph).cache() + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + assert(maxCCid === 0) + } + } // end of Grid connected components + + + test("Reverse Grid Connected Components") { + withSpark(new SparkContext("local", "test")) { sc => + val gridGraph = GraphGenerators.gridGraph(sc, 10, 10).reverse + val ccGraph = Analytics.connectedComponents(gridGraph).cache() + val maxCCid = ccGraph.vertices.map { case (vid, ccId) => ccId }.sum + assert(maxCCid === 0) + } + } // end of Grid connected components + + + test("Chain Connected Components") { + withSpark(new SparkContext("local", "test")) { sc => + val chain1 = (0 until 9).map(x => (x, x+1) ) + val chain2 = (10 until 20).map(x => (x, x+1) ) + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val twoChains = Graph(rawEdges) + val ccGraph = 
Analytics.connectedComponents(twoChains).cache() + val vertices = ccGraph.vertices.collect + for ( (id, cc) <- vertices ) { + if(id < 10) { assert(cc === 0) } + else { assert(cc === 10) } + } + val ccMap = vertices.toMap + println(ccMap) + for( id <- 0 until 20 ) { + if(id < 10) { assert(ccMap(id) === 0) } + else { assert(ccMap(id) === 10) } + } + } + } // end of chain connected components + + test("Reverse Chain Connected Components") { + withSpark(new SparkContext("local", "test")) { sc => + val chain1 = (0 until 9).map(x => (x, x+1) ) + val chain2 = (10 until 20).map(x => (x, x+1) ) + val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s,d) => (s.toLong, d.toLong) } + val twoChains = Graph(rawEdges).reverse + val ccGraph = Analytics.connectedComponents(twoChains).cache() + val vertices = ccGraph.vertices.collect + for ( (id, cc) <- vertices ) { + if(id < 10) { assert(cc === 0) } + else { assert(cc === 10) } + } + val ccMap = vertices.toMap + println(ccMap) + for( id <- 0 until 20 ) { + if(id < 10) { assert(ccMap(id) === 0) } + else { assert(ccMap(id) === 10) } + } + } + } // end of chain connected components + + + } // end of AnalyticsSuite