Skip to content
Snippets Groups Projects
Commit d6a902f3 authored by Joseph E. Gonzalez's avatar Joseph E. Gonzalez
Browse files

Finished updating connected components to use a Pregel-like abstraction and...

Finished updating connected components to use a Pregel-like abstraction and created a series of tests in the AnalyticsSuite.
parent a2287ae1
No related branches found
No related tags found
No related merge requests found
......@@ -167,8 +167,17 @@ object Analytics extends Logging {
* and return a graph with the vertex value containing the
* lowest vertex id in the connected component containing
* that vertex.
*
* @tparam VD the vertex attribute type (discarded in the computation)
* @tparam ED the edge attribute type (preserved in the computation)
*
* @param graph the graph for which to compute the connected components
*
* @return a graph with vertex attributes containing the smallest vertex
* id in each connected component
*/
def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]) = {
def connectedComponents[VD: Manifest, ED: Manifest](graph: Graph[VD, ED]):
Graph[Vid, ED] = {
val ccGraph = graph.mapVertices { case (vid, _) => vid }
def sendMessage(id: Vid, edge: EdgeTriplet[Vid, ED]): Option[Vid] = {
......@@ -179,21 +188,27 @@ object Analytics extends Logging {
}
val initialMessage = Long.MaxValue
Pregel(ccGraph, initialMessage)(
Pregel(ccGraph, initialMessage, EdgeDirection.Both)(
(id, attr, msg) => math.min(attr, msg),
sendMessage,
(a,b) => math.min(a,b)
)
/**
* Originally this was implemented using the GraphLab abstraction, but now
* that message computation is supported along all edge directions, the
* Pregel abstraction is sufficient.
*/
// GraphLab(ccGraph, gatherDirection = EdgeDirection.Both, scatterDirection = EdgeDirection.Both)(
// (me_id, edge) => edge.otherVertexAttr(me_id), // gather
// (a: Vid, b: Vid) => math.min(a, b), // merge
// (id, data, a: Option[Vid]) => math.min(data, a.getOrElse(Long.MaxValue)), // apply
// (me_id, edge) => (edge.vertexAttr(me_id) < edge.otherVertexAttr(me_id))
// )
} // end of connectedComponents
}
def main(args: Array[String]) = {
val host = args(0)
val taskType = args(1)
......
......@@ -79,6 +79,7 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
} // end of test Star PageRank
test("Grid PageRank") {
withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
......@@ -104,4 +105,68 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
} // end of Grid PageRank
test("Grid Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    // Build a 10 x 10 grid graph and label every vertex with its component id.
    val grid = GraphGenerators.gridGraph(sc, 10, 10)
    val components = Analytics.connectedComponents(grid).cache()
    // The test expects the component ids to sum to zero; presumably the grid's
    // vertex ids are non-negative and start at 0, so this means every vertex
    // ended up in component 0 (a single connected component).
    val componentIds = components.vertices.map { case (_, componentId) => componentId }
    val componentIdSum = componentIds.sum
    assert(componentIdSum === 0)
  }
} // end of Grid connected components
test("Reverse Grid Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    // Same check as the grid test, but run on the `reverse` of the grid graph
    // (presumably the graph with every edge direction flipped), so the result
    // must not depend on edge orientation.
    val reversedGrid = GraphGenerators.gridGraph(sc, 10, 10).reverse
    val components = Analytics.connectedComponents(reversedGrid).cache()
    // The component ids are expected to sum to zero, i.e. (assuming
    // non-negative vertex ids) every vertex is labelled with component 0.
    val componentIds = components.vertices.map { case (_, componentId) => componentId }
    val componentIdSum = componentIds.sum
    assert(componentIdSum === 0)
  }
} // end of Reverse Grid connected components
test("Chain Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    // Two disjoint chains of edges:
    //   chain1: (0,1), (1,2), ..., (8,9)       -> vertices 0..9
    //   chain2: (10,11), (11,12), ..., (19,20) -> vertices 10..20
    val chain1 = (0 until 9).map(x => (x, x + 1))
    val chain2 = (10 until 20).map(x => (x, x + 1))
    val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) }
    val twoChains = Graph(rawEdges)
    val ccGraph = Analytics.connectedComponents(twoChains).cache()
    val vertices = ccGraph.vertices.collect
    // Every vertex must be labelled with the smallest id of its own chain.
    for ((id, cc) <- vertices) {
      if (id < 10) { assert(cc === 0) }
      else { assert(cc === 10) }
    }
    // Look each expected vertex up by id as well: ccMap(id) throws on a
    // missing key, so this also verifies that no vertex was dropped. The range
    // is inclusive of 20 because chain2's last edge is (19, 20).
    val ccMap = vertices.toMap
    for (id <- 0 to 20) {
      if (id < 10) { assert(ccMap(id) === 0) }
      else { assert(ccMap(id) === 10) }
    }
  }
} // end of chain connected components
test("Reverse Chain Connected Components") {
  withSpark(new SparkContext("local", "test")) { sc =>
    // Two disjoint chains of edges, built in the forward direction and then
    // passed through `reverse` (presumably flipping every edge direction):
    //   chain1: (0,1), (1,2), ..., (8,9)       -> vertices 0..9
    //   chain2: (10,11), (11,12), ..., (19,20) -> vertices 10..20
    val chain1 = (0 until 9).map(x => (x, x + 1))
    val chain2 = (10 until 20).map(x => (x, x + 1))
    val rawEdges = sc.parallelize(chain1 ++ chain2, 3).map { case (s, d) => (s.toLong, d.toLong) }
    val twoChains = Graph(rawEdges).reverse
    val ccGraph = Analytics.connectedComponents(twoChains).cache()
    val vertices = ccGraph.vertices.collect
    // Every vertex must still be labelled with the smallest id of its chain,
    // independent of the edge orientation.
    for ((id, cc) <- vertices) {
      if (id < 10) { assert(cc === 0) }
      else { assert(cc === 10) }
    }
    // Look each expected vertex up by id as well: ccMap(id) throws on a
    // missing key, so this also verifies that no vertex was dropped. The range
    // is inclusive of 20 because chain2's last edge is (19, 20).
    val ccMap = vertices.toMap
    for (id <- 0 to 20) {
      if (id < 10) { assert(ccMap(id) === 0) }
      else { assert(ccMap(id) === 10) }
    }
  }
} // end of reverse chain connected components
} // end of AnalyticsSuite
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment