Skip to content
Snippets Groups Projects
Commit cbdc01ee authored by Ankur Dave's avatar Ankur Dave
Browse files

Update WikipediaPageRank to reflect Bagel API changes

parent 6d707f6b
No related branches found
No related tags found
No related merge requests found
package spark.bagel.examples
import spark._
import spark.SparkContext._
import spark.bagel._
import spark.bagel.Bagel._
import scala.collection.mutable.ArrayBuffer
import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
import com.esotericsoftware.kryo._
/**
 * Vertex programs used to run PageRank on top of Bagel.
 */
class PageRankUtils extends Serializable {
  /**
   * Runs one PageRank superstep for a vertex whose incoming messages have
   * already been summed.
   *
   * @param numVertices total vertex count, used in the `0.15 / numVertices` term
   * @param epsilon     part of the signature but not read by this method
   * @param self        current state of the vertex
   * @param messageSum  sum of the incoming message values, if any arrived
   * @param superstep   number of the superstep being executed
   * @return the updated vertex together with the messages it sends out
   */
  def computeWithCombiner(numVertices: Long, epsilon: Double)(
    self: PRVertex, messageSum: Option[Double], superstep: Int
  ): (PRVertex, Array[PRMessage]) = {
    // A missing or zero sum leaves the rank untouched.
    val rank = messageSum
      .filter(_ != 0)
      .map(sum => 0.15 / numVertices + 0.85 * sum)
      .getOrElse(self.value)

    // The vertex keeps running only during the first ten supersteps (0 to 9).
    val keepRunning = superstep < 10

    val outgoing: Array[PRMessage] =
      if (keepRunning) {
        // Every out-edge receives an equal share of the new rank.
        val edges = self.outEdges
        edges.map(targetId => new PRMessage(targetId, rank / edges.size))
      } else {
        Array.empty[PRMessage]
      }

    (new PRVertex(rank, self.outEdges, keepRunning), outgoing)
  }

  /**
   * Same superstep as [[computeWithCombiner]], for callers that receive the
   * raw incoming messages: their values are summed here and then delegated.
   *
   * @param numVertices total vertex count
   * @param epsilon     forwarded unchanged to `computeWithCombiner`
   * @param self        current state of the vertex
   * @param messages    incoming messages, if any arrived
   * @param superstep   number of the superstep being executed
   * @return the updated vertex together with the messages it sends out
   */
  def computeNoCombiner(numVertices: Long, epsilon: Double)(
    self: PRVertex, messages: Option[Array[PRMessage]], superstep: Int
  ): (PRVertex, Array[PRMessage]) = {
    val messageSum: Option[Double] = messages.map(msgs => msgs.map(_.value).sum)
    computeWithCombiner(numVertices, epsilon)(self, messageSum, superstep)
  }
}
/**
 * Combiner that collapses the PageRank messages addressed to one vertex into
 * the sum of their values.
 */
class PRCombiner extends Combiner[PRMessage, Double] with Serializable {
  /** Starts a running sum from the value of the first message. */
  def createCombiner(msg: PRMessage): Double = {
    msg.value
  }

  /** Adds the value of one more message to the running sum. */
  def mergeMsg(combiner: Double, msg: PRMessage): Double = {
    combiner + msg.value
  }

  /** Adds two partial sums together. */
  def mergeCombiners(a: Double, b: Double): Double = {
    a + b
  }
}
/**
 * PageRank vertex state: the current rank, the target ids of the outgoing
 * links, and a flag saying whether the vertex is still active.
 *
 * The no-argument primary constructor leaves the fields at their defaults
 * (`value` 0.0, `outEdges` null, `active` false) until they are assigned.
 */
class PRVertex() extends Vertex with Serializable {
  var value: Double = _
  var outEdges: Array[String] = _
  var active: Boolean = _

  /**
   * Builds a fully initialised vertex.
   *
   * @param value    initial rank
   * @param outEdges target ids of the outgoing links
   * @param active   whether the vertex starts active (defaults to true)
   */
  def this(value: Double, outEdges: Array[String], active: Boolean = true) {
    this()
    this.value = value
    this.outEdges = outEdges
    this.active = active
  }

  /**
   * Renders the rank, the number of out-edges and the active flag.
   * A vertex whose `outEdges` has not been assigned yet reports a length of 0.
   */
  override def toString(): String = {
    // outEdges is null after the no-arg constructor; avoid a NullPointerException here.
    val edgeCount = if (outEdges == null) 0 else outEdges.length
    "PRVertex(value=%f, outEdges.length=%d, active=%s)".format(value, edgeCount, active.toString)
  }
}
/**
 * PageRank message: a rank contribution addressed to the vertex `targetId`.
 *
 * The no-argument primary constructor leaves both fields at their defaults
 * until they are assigned.
 */
class PRMessage() extends Message[String] with Serializable {
  var targetId: String = _
  var value: Double = _

  /**
   * Builds a message carrying `value` for the vertex identified by `targetId`.
   */
  def this(targetId: String, value: Double) = {
    this()
    this.targetId = targetId
    this.value = value
  }
}
/**
 * Registers the PageRank vertex and message classes with Kryo.
 */
class PRKryoRegistrator extends KryoRegistrator {
  /** Registers [[PRVertex]] first, then [[PRMessage]], on the given Kryo instance. */
  def registerClasses(kryo: Kryo): Unit = {
    kryo.register(classOf[PRVertex])
    kryo.register(classOf[PRMessage])
  }
}
/**
 * Partitioner that maps a key to `hash mod partitions`, always yielding a
 * non-negative partition index.
 *
 * For `Long` keys the hash is the low 32 bits of the value; for every other
 * key it is the key's `hashCode`. A null key is sent to partition 0.
 *
 * @param partitions number of partitions to spread keys over
 */
class CustomPartitioner(partitions: Int) extends Partitioner {
  def numPartitions: Int = partitions

  /**
   * Returns the partition index for `key`, in the range `[0, partitions)`.
   */
  def getPartition(key: Any): Int = {
    val hash = key match {
      case null => 0
      case k: Long => (k & 0x00000000FFFFFFFFL).toInt
      case _ => key.hashCode
    }
    // Use the hash chosen above (previously computed and then ignored in
    // favour of key.hashCode, which made the Long case dead code).
    val mod = hash % partitions
    // The JVM remainder keeps the sign of the dividend; shift negatives into range.
    if (mod < 0) mod + partitions else mod
  }

  /** Two CustomPartitioners are equal when they have the same partition count. */
  override def equals(other: Any): Boolean = other match {
    case c: CustomPartitioner =>
      c.numPartitions == numPartitions
    case _ => false
  }

  /** Kept consistent with `equals`: equal partitioners share a hash code. */
  override def hashCode(): Int = numPartitions
}
......@@ -6,28 +6,23 @@ import spark.SparkContext._
import spark.bagel._
import spark.bagel.Bagel._
import scala.collection.mutable.ArrayBuffer
import scala.xml.{XML,NodeSeq}
import java.io.{Externalizable,ObjectInput,ObjectOutput,DataOutputStream,DataInputStream}
import com.esotericsoftware.kryo._
object WikipediaPageRank {
def main(args: Array[String]) {
if (args.length < 4) {
System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
if (args.length < 5) {
System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> <usePartitioner>")
System.exit(-1)
}
System.setProperty("spark.serialization", "spark.KryoSerialization")
System.setProperty("spark.serializer", "spark.KryoSerializer")
System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
val inputFile = args(0)
val threshold = args(1).toDouble
val numSplits = args(2).toInt
val host = args(3)
val noCombiner = args.length > 4 && args(4).nonEmpty
val usePartitioner = args(4).toBoolean
val sc = new SparkContext(host, "WikipediaPageRank")
// Parse the Wikipedia page data into a graph
......@@ -38,7 +33,7 @@ object WikipediaPageRank {
println("Done counting vertices.")
println("Parsing input file...")
val vertices: RDD[(String, PRVertex)] = input.map(line => {
var vertices = input.map(line => {
val fields = line.split("\t")
val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
val links =
......@@ -52,105 +47,33 @@ object WikipediaPageRank {
System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
NodeSeq.Empty
}
val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
val outEdges = links.map(link => new String(link.text)).toArray
val id = new String(title)
(id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
}).cache
(id, new PRVertex(1.0 / numVertices, outEdges))
})
if (usePartitioner)
vertices = vertices.partitionBy(new HashPartitioner(sc.defaultParallelism)).cache
else
vertices = vertices.cache
println("Done parsing input file.")
// Do the computation
val epsilon = 0.01 / numVertices
val messages = sc.parallelize(List[(String, PRMessage)]())
val messages = sc.parallelize(Array[(String, PRMessage)]())
val utils = new PageRankUtils
val result =
if (noCombiner) {
Bagel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
} else {
Bagel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
}
Bagel.run(
sc, vertices, messages, combiner = new PRCombiner(),
numSplits = numSplits)(
utils.computeWithCombiner(numVertices, epsilon))
// Print the result
System.err.println("Articles with PageRank >= "+threshold+":")
val top = result.filter(_.value >= threshold).map(vertex =>
"%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
val top =
(result
.filter { case (id, vertex) => vertex.value >= threshold }
.map { case (id, vertex) => "%s\t%s\n".format(id, vertex.value) }
.collect.mkString)
println(top)
}
}
// Combiner that sums PageRank message values; also hosts the per-vertex compute step.
object PRCombiner extends Combiner[PRMessage, Double] with Serializable {
// Starts a running sum from the first message's value.
def createCombiner(msg: PRMessage): Double =
msg.value
// Adds one more message's value to the running sum.
def mergeMsg(combiner: Double, msg: PRMessage): Double =
combiner + msg.value
// Adds two partial sums.
def mergeCombiners(a: Double, b: Double): Double =
a + b
// Computes the vertex's new rank from the summed incoming messages and builds
// the messages it sends out. Returns the updated vertex and its outbox.
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
// A missing or zero sum leaves the rank unchanged.
val newValue = messageSum match {
case Some(msgSum) if msgSum != 0 =>
0.15 / numVertices + 0.85 * msgSum
case _ => self.value
}
// Stop from superstep 10 on once the rank moved by less than epsilon, and
// unconditionally from superstep 30 on.
val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
// While running, each out-edge receives an equal share of the new rank.
val outbox =
if (!terminate)
self.outEdges.map(edge =>
new PRMessage(edge.targetId, newValue / self.outEdges.size))
else
ArrayBuffer[PRMessage]()
(new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
}
}
// Compute step for runs without a message combiner.
object PRNoCombiner extends DefaultCombiner[PRMessage] with Serializable {
// Sums the values of the raw incoming messages (None stays None) and
// delegates to PRCombiner.compute with that sum.
def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
PRCombiner.compute(numVertices, epsilon)(self, messages match {
case Some(msgs) => Some(msgs.map(_.value).sum)
case None => None
}, superstep)
}
// PageRank vertex state: id, current rank, outgoing edges and an active flag.
// The no-argument constructor leaves id and outEdges null and value 0.0;
// active starts as true.
class PRVertex() extends Vertex with Serializable {
var id: String = _
var value: Double = _
var outEdges: ArrayBuffer[PREdge] = _
var active: Boolean = true
// Builds a vertex with every field assigned.
def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
this()
this.id = id
this.value = value
this.outEdges = outEdges
this.active = active
}
}
// PageRank message: a rank contribution (value) addressed to the vertex targetId.
// The no-argument constructor leaves both fields at their defaults.
class PRMessage() extends Message with Serializable {
var targetId: String = _
var value: Double = _
// Builds a message with both fields assigned.
def this(targetId: String, value: Double) {
this()
this.targetId = targetId
this.value = value
}
}
// Outgoing edge of a PageRank vertex, identified only by its target vertex id.
// The no-argument constructor leaves targetId null.
class PREdge() extends Edge with Serializable {
var targetId: String = _
// Builds an edge pointing at targetId.
def this(targetId: String) {
this()
this.targetId = targetId
}
}
// Registers the PageRank vertex, message and edge classes with Kryo.
class PRKryoRegistrator extends KryoRegistrator {
// Registers PRVertex, PRMessage and PREdge, in that order, on the given Kryo instance.
def registerClasses(kryo: Kryo) {
kryo.register(classOf[PRVertex])
kryo.register(classOf[PRMessage])
kryo.register(classOf[PREdge])
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment