From ab3889f62753beb51a550b500cd0830adb03d4cc Mon Sep 17 00:00:00 2001 From: Ankur Dave <ankurdave@gmail.com> Date: Sun, 9 Oct 2011 16:15:15 -0700 Subject: [PATCH] Implement standalone WikipediaPageRank with custom serializer --- .../WikipediaPageRankStandalone.scala | 198 ++++++++++++++++++ 1 file changed, 198 insertions(+) create mode 100644 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala diff --git a/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala new file mode 100644 index 0000000000..2e38376499 --- /dev/null +++ b/bagel/src/main/scala/spark/bagel/examples/WikipediaPageRankStandalone.scala @@ -0,0 +1,198 @@ +package spark.bagel.examples + +import spark._ +import spark.SparkContext._ + +import spark.bagel._ +import spark.bagel.Bagel._ + +import scala.xml.{XML,NodeSeq} + +import scala.collection.mutable.ArrayBuffer + +import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} + +object WikipediaPageRankStandalone { + def main(args: Array[String]) { + if (args.length < 5) { + System.err.println("Usage: WikipediaPageRankStandalone <inputFile> <threshold> <numIterations> <host> <usePartitioner>") + System.exit(-1) + } + + System.setProperty("spark.serializer", "spark.bagel.examples.WPRSerializer") + + val inputFile = args(0) + val threshold = args(1).toDouble + val numIterations = args(2).toInt + val host = args(3) + val usePartitioner = args(4).toBoolean + val sc = new SparkContext(host, "WikipediaPageRankStandalone") + + val input = sc.textFile(inputFile) + val partitioner = new HashPartitioner(sc.defaultParallelism) + val links = + if (usePartitioner) + input.map(parseArticle _).partitionBy(partitioner).cache + else + input.map(parseArticle _).cache + val n = links.count + val defaultRank = 1.0 / n + val a = 0.15 + + // Do the computation + val startTime = System.currentTimeMillis + val ranks = + pageRank(links, numIterations, defaultRank, a, n, partitioner, usePartitioner, sc.defaultParallelism) + + // Print the result + System.err.println("Articles with PageRank >= "+threshold+":") + val top = + (ranks + .filter { case (id, rank) => rank >= threshold } + .map { case (id, rank) => "%s\t%s\n".format(id, rank) } + .collect.mkString) + println(top) + + val time = (System.currentTimeMillis - startTime) / 1000.0 + println("Completed %d iterations in %f seconds: %f seconds per iteration" + .format(numIterations, time, time / numIterations)) + System.exit(0) + } + + def parseArticle(line: String): (String, Array[String]) = { + val fields = line.split("\t") + val (title, body) = (fields(1), fields(3).replace("\\n", "\n")) + val id = new String(title) + val links = + if (body == "\\N") + NodeSeq.Empty + else + try { + XML.loadString(body) \\ "link" \ "target" + } catch { + case e: org.xml.sax.SAXParseException => + System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body) + NodeSeq.Empty + } + val outEdges = links.map(link => new String(link.text)).toArray + (id, outEdges) + } + + def pageRank( + links: RDD[(String, Array[String])], + numIterations: Int, + defaultRank: Double, + a: Double, + n: Long, + partitioner: Partitioner, + usePartitioner: Boolean, + numSplits: Int + ): RDD[(String, Double)] = { + var ranks = links.mapValues { edges => defaultRank } + for (i <- 1 to numIterations) { + val contribs = links.groupWith(ranks).flatMap { + case (id, (linksWrapper, rankWrapper)) => + if (linksWrapper.length > 0) { + if (rankWrapper.length > 0) { + linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) + } else { + linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) + } + } else { + Array[(String, Double)]() + } + } + ranks = (contribs.combineByKey((x: Double) => x, + (x: Double, y: Double) => x + y, + (x: Double, y: Double) => x + y, + numSplits, + partitioner) + .mapValues(sum => a/n + (1-a)*sum)) + } + ranks + } +} + +class WPRSerializer extends spark.Serializer { + def newInstance(): SerializerInstance = new WPRSerializerInstance() +} + +class WPRSerializerInstance extends SerializerInstance { + def serialize[T](t: T): Array[Byte] = { + throw new UnsupportedOperationException() + } + + def deserialize[T](bytes: Array[Byte]): T = { + throw new UnsupportedOperationException() + } + + def outputStream(s: OutputStream): SerializationStream = { + new WPRSerializationStream(s) + } + + def inputStream(s: InputStream): DeserializationStream = { + new WPRDeserializationStream(s) + } +} + +class WPRSerializationStream(os: OutputStream) extends SerializationStream { + val dos = new DataOutputStream(os) + + def writeObject[T](t: T): Unit = t match { + case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match { + case links: Array[String] => { + dos.writeInt(0) // links + dos.writeUTF(id) + dos.writeInt(links.length) + for (link <- links) { + dos.writeUTF(link) + } + } + case rank: Double => { + dos.writeInt(1) // rank + dos.writeUTF(id) + dos.writeDouble(rank) + } + } + case (id: String, rank: Double) => { + dos.writeInt(2) // rank without wrapper + dos.writeUTF(id) + dos.writeDouble(rank) + } + } + + def flush() { dos.flush() } + def close() { dos.close() } +} + +class WPRDeserializationStream(is: InputStream) extends DeserializationStream { + val dis = new DataInputStream(is) + + def readObject[T](): T = { + val typeId = dis.readInt() + typeId match { + case 0 => { + val id = dis.readUTF() + val numLinks = dis.readInt() + val links = new Array[String](numLinks) + for (i <- 0 until numLinks) { + val link = dis.readUTF() + links(i) = link + } + (id, ArrayBuffer(links)).asInstanceOf[T] + } + case 1 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, ArrayBuffer(rank)).asInstanceOf[T] + } + case 2 => { + val id = dis.readUTF() + val rank = dis.readDouble() + (id, rank).asInstanceOf[T] + } + } + } + + def close() { dis.close() } +} -- GitLab