Skip to content
Snippets Groups Projects
Commit c30624dc authored by Joseph E. Gonzalez's avatar Joseph E. Gonzalez
Browse files

Adding dynamic pregel, fixing bugs in PageRank, and adding basic analytics unit tests.

parent 0bd92ed8
No related branches found
No related tags found
No related merge requests found
......@@ -35,10 +35,10 @@ object Analytics extends Logging {
def vertexProgram(id: Vid, attr: Double, msgSum: Double): Double =
resetProb + (1.0 - resetProb) * msgSum
def sendMessage(id: Vid, edge: EdgeTriplet[Double, Double]): Option[Double] =
Some(edge.srcAttr / edge.attr)
Some(edge.srcAttr * edge.attr)
def messageCombiner(a: Double, b: Double): Double = a + b
// The initial message received by all vertices in PageRank
val initialMessage = 1.0
val initialMessage = 0.0
// Execute pregel for a fixed number of iterations.
Pregel(pagerankGraph, initialMessage, numIter)(
......@@ -49,8 +49,8 @@ object Analytics extends Logging {
/**
* Compute the PageRank of a graph returning the pagerank of each vertex as an RDD
*/
def dynamicPagerank[VD: Manifest, ED: Manifest](
graph: Graph[VD, ED], tol: Float, resetProb: Double = 0.15): Graph[Double, Double] = {
def deltaPagerank[VD: Manifest, ED: Manifest](
graph: Graph[VD, ED], tol: Double, resetProb: Double = 0.15): Graph[Double, Double] = {
/**
* Initialize the pagerankGraph with each edge attribute
......@@ -64,7 +64,7 @@ object Analytics extends Logging {
// Set the weight on the edges based on the degree
.mapTriplets( e => 1.0 / e.srcAttr )
// Set the vertex attributes to (initialPR, delta = 0)
.mapVertices( (id, attr) => (resetProb, 0.0) )
.mapVertices( (id, attr) => (0.0, 0.0) )
// Display statistics about pagerank
println(pagerankGraph.statistics)
......@@ -78,12 +78,12 @@ object Analytics extends Logging {
}
def sendMessage(id: Vid, edge: EdgeTriplet[(Double, Double), Double]): Option[Double] = {
if (edge.srcAttr._2 > tol) {
Some(edge.srcAttr._2 / edge.attr)
Some(edge.srcAttr._2 * edge.attr)
} else { None }
}
def messageCombiner(a: Double, b: Double): Double = a + b
// The initial message received by all vertices in PageRank
val initialMessage = 1.0 / (1.0 - resetProb)
val initialMessage = resetProb / (1.0 - resetProb)
// Execute a dynamic version of Pregel.
Pregel(pagerankGraph, initialMessage)(
......
......@@ -43,15 +43,12 @@ object Pregel {
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {
var g = graph
//var g = graph.cache()
var i = 0
def mapF(vid: Vid, edge: EdgeTriplet[VD,ED]) = sendMsg(edge.otherVertexId(vid), edge)
// Receive the first set of messages
g = g.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg))
var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg))
var i = 0
while (i < numIter) {
// compute the messages
val messages = g.aggregateNeighbors(mapF, mergeMsg, EdgeDirection.In)
......@@ -96,27 +93,45 @@ object Pregel {
mergeMsg: (A, A) => A)
: Graph[VD, ED] = {
var g = graph
//var g = graph.cache()
var i = 0
def mapF(vid: Vid, edge: EdgeTriplet[VD,ED]) = sendMsg(edge.otherVertexId(vid), edge)
def vprogFun(id: Vid, attr: (VD, Boolean), msgOpt: Option[A]): (VD, Boolean) = {
msgOpt match {
case Some(msg) => (vprog(id, attr._1, msg), true)
case None => (attr._1, false)
}
}
// Receive the first set of messages
g.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg))
def sendMsgFun(vid: Vid, edge: EdgeTriplet[(VD,Boolean), ED]): Option[A] = {
if(edge.srcAttr._2) {
val et = new EdgeTriplet[VD, ED]
et.srcId = edge.srcId
et.srcAttr = edge.srcAttr._1
et.dstId = edge.dstId
et.dstAttr = edge.dstAttr._1
et.attr = edge.attr
sendMsg(edge.otherVertexId(vid), et)
} else { None }
}
var activeMessages = g.numEdges
var g = graph.mapVertices( (vid, vdata) => (vprog(vid, vdata, initialMsg), true) )
// compute the messages
var messages = g.aggregateNeighbors(sendMsgFun, mergeMsg, EdgeDirection.In).cache
var activeMessages = messages.count
// Loop
var i = 0
while (activeMessages > 0) {
// receive the messages
g = g.outerJoinVertices(messages)(vprogFun)
val oldMessages = messages
// compute the messages
val messages = g.aggregateNeighbors(mapF, mergeMsg, EdgeDirection.In).cache
messages = g.aggregateNeighbors(sendMsgFun, mergeMsg, EdgeDirection.In).cache
activeMessages = messages.count
// receive the messages
g = g.joinVertices(messages)(vprog)
// after counting we can unpersist the old messages
oldMessages.unpersist(blocking=false)
// count the iteration
i += 1
}
// Return the final graph
g
g.mapVertices((id, attr) => attr._1)
} // end of apply
} // end of class Pregel
......@@ -3,9 +3,45 @@ package org.apache.spark.graph
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.graph.LocalSparkContext._
import org.apache.spark.graph.util.GraphGenerators
import org.apache.spark.graph.Analytics
/**
 * Single-machine reference implementation of PageRank on a directed grid graph.
 *
 * Vertices are the cells of an `nRows` x `nCols` grid, numbered in row-major order
 * (`id = r * nCols + c`). Every cell has an edge to the cell directly below it
 * (`r + 1`) and to the cell directly to its right (`c + 1`), whenever that cell exists.
 */
object GridPageRank {
  /**
   * Runs `nIter` synchronous PageRank iterations over the grid graph.
   *
   * All ranks start at `resetProb`; each iteration recomputes every vertex as
   * `resetProb + (1 - resetProb) * sum(rank(nbr) / outDegree(nbr))` over its in-neighbors,
   * reading only the ranks produced by the previous iteration.
   *
   * @param nRows     number of grid rows (must be non-negative)
   * @param nCols     number of grid columns (must be non-negative)
   * @param nIter     number of iterations to run; values <= 0 return the initial ranks
   * @param resetProb the reset (teleport) term of the PageRank update
   * @return one `(vertexId, rank)` pair per vertex, ordered by vertex id
   * @throws IllegalArgumentException if `nRows` or `nCols` is negative
   */
  def apply(nRows: Int, nCols: Int, nIter: Int, resetProb: Double)
    : collection.immutable.IndexedSeq[(Long, Double)] = {
    require(nRows >= 0 && nCols >= 0,
      "grid dimensions must be non-negative: " + nRows + " x " + nCols)
    val nVertices = nRows * nCols

    // Convert a (row, column) address into a vertex id (row-major order).
    def sub2ind(r: Int, c: Int): Int = r * nCols + c

    // Destinations of the edges leaving cell (r, c): the cell below, then the cell to the right.
    def outNeighbors(r: Int, c: Int): List[Int] = {
      val below = if (r + 1 < nRows) List(sub2ind(r + 1, c)) else Nil
      val right = if (c + 1 < nCols) List(sub2ind(r, c + 1)) else Nil
      below ++ right
    }

    // Every directed edge of the grid as (src, dst), generated in row-major order of src.
    val edges = for {
      r <- 0 until nRows
      c <- 0 until nCols
      dst <- outNeighbors(r, c)
    } yield (sub2ind(r, c), dst)

    val outDegreeBySrc = edges.groupBy(_._1).map { case (src, out) => (src, out.size) }
    // groupBy keeps the relative order of edges inside each group, so the in-neighbors of a
    // vertex stay in generation order and the floating point sum below is deterministic.
    val sourcesByDst = edges.groupBy(_._2).map { case (dst, in) => (dst, in.map(_._1)) }

    // Dense per-vertex lookups so the iteration loop only does array indexing.
    val noNbrs = Vector.empty[Int]
    val inNbrs = Array.tabulate(nVertices)(v => sourcesByDst.getOrElse(v, noNbrs))
    val outDegree = Array.tabulate(nVertices)(v => outDegreeBySrc.getOrElse(v, 0))

    // Each step builds a fresh array from the previous one (synchronous update).
    // outDegree(nbr) is never zero here: nbr is an in-neighbor, so it has at least one out-edge.
    val ranks = (0 until nIter).foldLeft(Array.fill(nVertices)(resetProb)) { (prev, _) =>
      Array.tabulate(nVertices) { v =>
        val inflow = inNbrs(v).map(nbr => prev(nbr) / outDegree(nbr)).sum
        resetProb + (1.0 - resetProb) * inflow
      }
    }

    (0L until nVertices).zip(ranks)
  }
}
class AnalyticsSuite extends FunSuite with LocalSparkContext {
......@@ -13,18 +49,59 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
val sc = new Sparkcontext("local", "test")
test("Fixed Iterations PageRank") {
val starGraph = GraphGenerators.starGraph(sc, 1000)
val resetProb = 0.15
val prGraph1 = Analytics.pagerank(graph, 1, resetProb)
val prGraph2 = Analytics.pagerank(grpah, 2, resetProb)
val errors = prGraph1.vertices.zipJoin(prGraph2.vertices)
.map{ case (vid, (pr1, pr2)) => if (pr1 != pr2) { 1 } else { 0 } }.sum
// Checks fixed-iteration and delta PageRank on a 100-vertex star graph produced by
// GraphGenerators.starGraph, using a fresh local SparkContext managed by withSpark.
test("Star PageRank") {
withSpark(new SparkContext("local", "test")) { sc =>
val nVertices = 100
val starGraph = GraphGenerators.starGraph(sc, nVertices)
val resetProb = 0.15
// Run the fixed-iteration PageRank for 1 and for 2 iterations.
val prGraph1 = Analytics.pagerank(starGraph, 1, resetProb)
val prGraph2 = Analytics.pagerank(starGraph, 2, resetProb)
// Count the vertices whose rank differs (exact comparison) between the two runs.
val notMatching = prGraph1.vertices.zipJoin(prGraph2.vertices)
.map{ case (vid, (pr1, pr2)) => if (pr1 != pr2) { 1 } else { 0 } }.sum
// The test expects the 1-iteration and 2-iteration results to be identical on this graph.
assert(notMatching === 0)
// Debug output: print every (vertex, rank) pair of the 2-iteration result.
prGraph2.vertices.foreach(println(_))
// Count vertices whose rank deviates from the expected closed-form value:
//  - every vertex with id > 0 must hold exactly resetProb;
//  - vertex 0 must be within 1e-5 of resetProb + (1 - resetProb) * resetProb * (nVertices - 1).
// NOTE(review): this presumes vertex 0 is the hub that all other vertices point to —
// confirm against GraphGenerators.starGraph.
val errors = prGraph2.vertices.map{ case (vid, pr) =>
val correct = (vid > 0 && pr == resetProb) ||
(vid == 0 && math.abs(pr - (resetProb + (1.0 - resetProb) * (resetProb * (nVertices - 1)) )) < 1.0E-5)
if ( !correct ) { 1 } else { 0 }
}
assert(errors.sum === 0)
// Run the delta-based PageRank with a tolerance of 0 and compare it with the 2-iteration result.
val prGraph3 = Analytics.deltaPagerank(starGraph, 0, resetProb)
// A vertex counts as an error when it has no partner in prGraph3 or when the two ranks
// are not exactly equal.
val errors2 = prGraph2.vertices.leftJoin(prGraph3.vertices).map{
case (_, (pr1, Some(pr2))) if(pr1 == pr2) => 0
case _ => 1
}.sum
assert(errors2 === 0)
}
} // end of test Star PageRank
}
// Checks PageRank on a 10 x 10 grid graph: the fixed-iteration result is compared with the
// delta-based result and with the local GridPageRank reference implementation.
test("Grid PageRank") {
withSpark(new SparkContext("local", "test")) { sc =>
val gridGraph = GraphGenerators.gridGraph(sc, 10, 10)
val resetProb = 0.15
// 50 fixed iterations versus the delta variant with tolerance 1e-4; both results are cached
// because each is used more than once below.
val prGraph1 = Analytics.pagerank(gridGraph, 50, resetProb).cache()
val prGraph2 = Analytics.deltaPagerank(gridGraph, 0.0001, resetProb).cache()
// Sum of squared per-vertex differences between the two results.
val error = prGraph1.vertices.zipJoin(prGraph2.vertices).map {
case (id, (a, b)) => (a - b) * (a - b)
}.sum
// Debug output: print (id, (fixed, delta, difference)) for every vertex, then the total error.
prGraph1.vertices.zipJoin(prGraph2.vertices)
.map{ case (id, (a,b)) => (id, (a,b, a-b))}.foreach(println(_))
println(error)
assert(error < 1.0e-5)
// Reference ranks computed locally by GridPageRank with the same grid size, iteration count
// and reset probability, distributed as an RDD of (vertex id, rank) pairs.
val pr3 = sc.parallelize(GridPageRank(10,10, 50, resetProb))
// Sum of squared differences against the reference.
// NOTE(review): a vertex with no partner in pr3 contributes 0 here, so this check would still
// pass if the vertex ids of gridGraph and GridPageRank did not line up — confirm the id
// scheme of GraphGenerators.gridGraph, or consider counting a missing partner as an error.
val error2 = prGraph1.vertices.leftJoin(pr3).map {
case (id, (a, Some(b))) => (a - b) * (a - b)
case _ => 0
}.sum
// Debug output: print the joined pairs and the total error against the reference.
prGraph1.vertices.leftJoin(pr3).foreach(println( _ ))
println(error2)
assert(error2 < 1.0e-5)
}
} // end of Grid PageRank
} // end of AnalyticsSuite
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment