Skip to content
Snippets Groups Projects
Commit c5b3ea75 authored by Ankur Dave
Browse files

Clean up Bagel source and interface

parent 19122af7
No related branches found
No related tags found
No related merge requests found
......@@ -7,75 +7,81 @@ import scala.collection.mutable.ArrayBuffer
object Pregel extends Logging {
/**
* Runs a Pregel job on the given vertices, running the specified
* compute function on each vertex in every superstep. Before
* beginning the first superstep, sends the given messages to their
* destination vertices. In the join stage, launches splits
* separate tasks (where splits is manually specified to work
* around a bug in Spark).
* Runs a Pregel job on the given vertices consisting of the
* specified compute function.
*
* Halts when no more messages are being sent between vertices, and
* all vertices have voted to halt by setting their state to
* Inactive.
* Before beginning the first superstep, the given messages are sent
* to their destination vertices.
*
* During the job, the specified combiner functions are applied to
* messages as they travel between vertices.
*
* The job halts and returns the resulting set of vertices when no
* messages are being sent between vertices and all vertices have
* voted to halt by setting their state to inactive.
*/
def run[V <: Vertex : Manifest, M <: Message : Manifest, C](sc: SparkContext, verts: RDD[(String, V)], msgs: RDD[(String, M)], splits: Int, messageCombiner: (C, M) => C, defaultCombined: () => C, mergeCombined: (C, C) => C, maxSupersteps: Option[Int] = None, superstep: Int = 0)(compute: (V, C, Int) => (V, Iterable[M])): RDD[V] = {
def run[V <: Vertex : Manifest, M <: Message : Manifest, C](
sc: SparkContext,
verts: RDD[(String, V)],
msgs: RDD[(String, M)],
createCombiner: M => C,
mergeMsg: (C, M) => C,
mergeCombiners: (C, C) => C,
numSplits: Int,
superstep: Int = 0
)(compute: (V, Option[C], Int) => (V, Iterable[M])): RDD[V] = {
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
// Bring together vertices and messages
val combinedMsgs = msgs.combineByKey({x => messageCombiner(defaultCombined(), x)}, messageCombiner, mergeCombined, splits)
logDebug("verts.splits.size = " + verts.splits.size)
logDebug("combinedMsgs.splits.size = " + combinedMsgs.splits.size)
logDebug("verts.partitioner = " + verts.partitioner)
logDebug("combinedMsgs.partitioner = " + combinedMsgs.partitioner)
val joined = verts.groupWith(combinedMsgs)
logDebug("joined.splits.size = " + joined.splits.size)
logDebug("joined.partitioner = " + joined.partitioner)
val combinedMsgs = msgs.combineByKey(createCombiner, mergeMsg, mergeCombiners, numSplits)
val grouped = verts.groupWith(combinedMsgs)
// Run compute on each vertex
var messageCount = sc.accumulator(0)
var activeVertexCount = sc.accumulator(0)
val processed = joined.flatMapValues {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues {
case (Seq(), _) => None
case (Seq(v), Seq(comb)) =>
val (newVertex, newMessages) = compute(v, comb, superstep)
case (Seq(v), c) =>
val (newVert, newMsgs) =
compute(v, c match {
case Seq(comb) => Some(comb)
case Seq() => None
}, superstep)
messageCount += newMessages.size
if (newVertex.active)
activeVertexCount += 1
numMsgs += newMsgs.size
if (newVert.active)
numActiveVerts += 1
Some((newVertex, newMessages))
case (Seq(v), Seq()) =>
val (newVertex, newMessages) = compute(v, defaultCombined(), superstep)
messageCount += newMessages.size
if (newVertex.active)
activeVertexCount += 1
Some((newVertex, newMessages))
Some((newVert, newMsgs))
}.cache
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
// Check stopping condition and recurse
val stop = messageCount.value == 0 && activeVertexCount.value == 0
if (stop || (maxSupersteps.isDefined && superstep >= maxSupersteps.get)) {
processed.map { _._2._1 }
// Check stopping condition and iterate
val noActivity = numMsgs.value == 0 && numActiveVerts.value == 0
if (noActivity) {
processed.map { case (id, (vert, msgs)) => vert }
} else {
val newVerts = processed.mapValues(_._1)
val newMsgs = processed.flatMap(x => x._2._2.map(m => (m.targetId, m)))
run(sc, newVerts, newMsgs, splits, messageCombiner, defaultCombined, mergeCombined, maxSupersteps, superstep + 1)(compute)
val newVerts = processed.mapValues { case (vert, msgs) => vert }
val newMsgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
}
run(sc, newVerts, newMsgs, createCombiner, mergeMsg, mergeCombiners, numSplits, superstep + 1)(compute)
}
}
}
/**
* Represents a Pregel vertex. Must be subclassed to store state
* along with each vertex. Must be annotated with @serializable.
* Represents a Pregel vertex.
*
* Subclasses may store state along with each vertex and must be
* annotated with @serializable.
*/
trait Vertex {
def id: String
......@@ -83,17 +89,20 @@ trait Vertex {
}
/**
* Represents a Pregel message to a target vertex. Must be
* subclassed to contain a payload. Must be annotated with @serializable.
* Represents a Pregel message to a target vertex.
*
* Subclasses may contain a payload to deliver to the target vertex
* and must be annotated with @serializable.
*/
trait Message {
// ID of the vertex this message is addressed to.
def targetId: String
}
/**
* Represents a directed edge between two vertices. Owned by the
* source vertex, and contains the ID of the target vertex. Must
* be subclassed to store state along with each edge. Must be annotated with @serializable.
* Represents a directed edge between two vertices.
*
* Subclasses may store state along with each edge and must be annotated
* with @serializable.
*/
trait Edge {
def targetId: String
......
......@@ -49,12 +49,17 @@ object ShortestPath {
messages.count()+" messages.")
// Do the computation
def messageCombiner(minSoFar: Int, message: SPMessage): Int =
min(minSoFar, message.value)
def createCombiner(message: SPMessage): Int = message.value
def mergeMsg(combiner: Int, message: SPMessage): Int =
min(combiner, message.value)
def mergeCombiners(a: Int, b: Int): Int = min(a, b)
val result = Pregel.run(sc, vertices, messages, numSplits, messageCombiner, () => Int.MaxValue, min _) {
(self: SPVertex, messageMinValue: Int, superstep: Int) =>
val newValue = min(self.value, messageMinValue)
val result = Pregel.run(sc, vertices, messages, createCombiner, mergeMsg, mergeCombiners, numSplits) {
(self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
val newValue = messageMinValue match {
case Some(minVal) => min(self.value, minVal)
case None => self.value
}
val outbox =
if (newValue != self.value)
......
......@@ -4,7 +4,6 @@ import spark._
import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
import scala.xml.{XML,NodeSeq}
import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
......@@ -14,7 +13,7 @@ import com.esotericsoftware.kryo._
object WikipediaPageRank {
def main(args: Array[String]) {
if (args.length < 4) {
System.err.println("Usage: PageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
System.exit(-1)
}
......@@ -52,22 +51,18 @@ object WikipediaPageRank {
}
val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
val id = new String(title)
(id, (new PRVertex(id, 1.0 / numVertices, outEdges, true)))
})
val graph = vertices.groupByKey(numSplits).mapValues(_.head).cache
(id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
}).cache
println("Done parsing input file.")
println("Input file had "+graph.count+" vertices.")
// Do the computation
val epsilon = 0.01 / numVertices
val messages = sc.parallelize(List[(String, PRMessage)]())
val result =
if (noCombiner) {
val messages = sc.parallelize(List[(String, PRMessage)]())
Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, graph, messages, numSplits, NoCombiner.messageCombiner, NoCombiner.defaultCombined, NoCombiner.mergeCombined)(NoCombiner.compute(numVertices, epsilon))
Pregel.run[PRVertex, PRMessage, ArrayBuffer[PRMessage]](sc, vertices, messages, NoCombiner.createCombiner, NoCombiner.mergeMsg, NoCombiner.mergeCombiners, numSplits)(NoCombiner.compute(numVertices, epsilon))
} else {
val messages = sc.parallelize(List[(String, PRMessage)]())
Pregel.run[PRVertex, PRMessage, Double](sc, graph, messages, numSplits, Combiner.messageCombiner, Combiner.defaultCombined, Combiner.mergeCombined)(Combiner.compute(numVertices, epsilon))
Pregel.run[PRVertex, PRMessage, Double](sc, vertices, messages, Combiner.createCombiner, Combiner.mergeMsg, Combiner.mergeCombiners, numSplits)(Combiner.compute(numVertices, epsilon))
}
// Print the result
......@@ -78,19 +73,19 @@ object WikipediaPageRank {
}
object Combiner {
def messageCombiner(minSoFar: Double, message: PRMessage): Double =
minSoFar + message.value
def createCombiner(message: PRMessage): Double = message.value
def mergeCombined(a: Double, b: Double) = a + b
def mergeMsg(combiner: Double, message: PRMessage): Double =
combiner + message.value
def defaultCombined(): Double = 0.0
def mergeCombiners(a: Double, b: Double) = a + b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Double, superstep: Int): (PRVertex, Iterable[PRMessage]) = {
val newValue =
if (messageSum != 0)
0.15 / numVertices + 0.85 * messageSum
else
self.value
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
val newValue = messageSum match {
case Some(msgSum) if msgSum != 0 =>
0.15 / numVertices + 0.85 * msgSum
case _ => self.value
}
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
......@@ -106,20 +101,24 @@ object WikipediaPageRank {
}
object NoCombiner {
def messageCombiner(messagesSoFar: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
messagesSoFar += message
def createCombiner(message: PRMessage): ArrayBuffer[PRMessage] =
ArrayBuffer(message)
def mergeCombined(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
a ++= b
def mergeMsg(combiner: ArrayBuffer[PRMessage], message: PRMessage): ArrayBuffer[PRMessage] =
combiner += message
def defaultCombined(): ArrayBuffer[PRMessage] = ArrayBuffer[PRMessage]()
def mergeCombiners(a: ArrayBuffer[PRMessage], b: ArrayBuffer[PRMessage]): ArrayBuffer[PRMessage] =
a ++= b
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Seq[PRMessage], superstep: Int): (PRVertex, Iterable[PRMessage]) =
Combiner.compute(numVertices, epsilon)(self, messages.map(_.value).sum, superstep)
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
Combiner.compute(numVertices, epsilon)(self, messages match {
case Some(msgs) => Some(msgs.map(_.value).sum)
case None => None
}, superstep)
}
}
@serializable class PRVertex() extends Vertex with Externalizable {
@serializable class PRVertex() extends Vertex {
var id: String = _
var value: Double = _
var outEdges: ArrayBuffer[PREdge] = _
......@@ -132,29 +131,9 @@ object WikipediaPageRank {
this.outEdges = outEdges
this.active = active
}
def writeExternal(out: ObjectOutput) {
out.writeUTF(id)
out.writeDouble(value)
out.writeInt(outEdges.length)
for (e <- outEdges)
out.writeUTF(e.targetId)
out.writeBoolean(active)
}
def readExternal(in: ObjectInput) {
id = in.readUTF()
value = in.readDouble()
val numEdges = in.readInt()
outEdges = new ArrayBuffer[PREdge](numEdges)
for (i <- 0 until numEdges) {
outEdges += new PREdge(in.readUTF())
}
active = in.readBoolean()
}
}
@serializable class PRMessage() extends Message with Externalizable {
@serializable class PRMessage() extends Message {
var targetId: String = _
var value: Double = _
......@@ -163,33 +142,15 @@ object WikipediaPageRank {
this.targetId = targetId
this.value = value
}
def writeExternal(out: ObjectOutput) {
out.writeUTF(targetId)
out.writeDouble(value)
}
def readExternal(in: ObjectInput) {
targetId = in.readUTF()
value = in.readDouble()
}
}
@serializable class PREdge() extends Edge with Externalizable {
@serializable class PREdge() extends Edge {
var targetId: String = _
def this(targetId: String) {
this()
this.targetId = targetId
}
def writeExternal(out: ObjectOutput) {
out.writeUTF(targetId)
}
def readExternal(in: ObjectInput) {
targetId = in.readUTF()
}
}
class PRKryoRegistrator extends KryoRegistrator {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment