From 80abc2807825d69b0f7a5e374eb6e6442332f400 Mon Sep 17 00:00:00 2001
From: Ankur Dave <ankurdave@gmail.com>
Date: Wed, 6 Nov 2013 22:50:30 -0800
Subject: [PATCH] Optimize mrTriplets for source-attr-only mapF using bytecode
 inspection
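
mapReduceTriplets previously replicated every vertex attribute to all
partitions containing any of the vertex's edges, and populated both
srcAttr and dstAttr on each triplet whether or not the map function
read them. This patch uses BytecodeUtils.invokedMethod to inspect the
map function's bytecode for calls to EdgeTriplet.srcAttr and
EdgeTriplet.dstAttr. When only srcAttr is accessed, a narrower vid2pid
mapping (createVid2PidSourceAttrOnly) replicates each vertex attribute
only to the partitions where the vertex appears as an edge source, and
dstAttr is left unset.

For example, a map function like the following sketch (written in the
(Vid, EdgeTriplet) => Option[A] shape that aggregateNeighbors uses in
the new test; not itself part of this patch) reads only srcAttr, so the
bytecode check detects an invocation of EdgeTriplet.srcAttr but not
EdgeTriplet.dstAttr and takes the optimized replication path:

    (vid: Vid, edge: EdgeTriplet[Int, Int]) => Some(edge.srcAttr)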

---
 .../apache/spark/graph/impl/GraphImpl.scala   | 41 +++++++++++++++++--
 .../org/apache/spark/graph/GraphSuite.scala   | 24 ++++++++++-
 2 files changed, 60 insertions(+), 5 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 0d7546b575..64fdb10831 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -12,6 +12,7 @@ import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.graph._
 import org.apache.spark.graph.impl.GraphImpl._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
+import org.apache.spark.graph.util.BytecodeUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
@@ -384,6 +385,22 @@ object GraphImpl {
       .mapValues(a => a.toArray).cache()
   }
 
+  protected def createVid2PidSourceAttrOnly[ED: ClassManifest](
+    eTable: RDD[(Pid, EdgePartition[ED])],
+    vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
+    val preAgg = eTable.mapPartitions { iter =>
+      val (pid, edgePartition) = iter.next()
+      val vSet = new VertexSet
+      edgePartition.foreach(e => vSet.add(e.srcId))
+      vSet.iterator.map { vid => (vid.toLong, pid) }
+    }
+    VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
+      (p: Pid) => ArrayBuffer(p),
+      (ab: ArrayBuffer[Pid], p: Pid) => { ab.append(p); ab },
+      (a: ArrayBuffer[Pid], b: ArrayBuffer[Pid]) => a ++ b)
+      .mapValues(a => a.toArray).cache()
+  }
+
   protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
     RDD[(Pid, VertexIdToIndexMap)] = {
     eTable.mapPartitions( _.map{ case (pid, epart) =>
@@ -468,8 +485,22 @@ object GraphImpl {
     ClosureCleaner.clean(mapFunc)
     ClosureCleaner.clean(reduceFunc)
 
+    // Inspect mapFunc's bytecode to see which triplet attributes it reads, and
+    // replicate each vertex attribute only to the partitions that will read it.
+    val mapFuncUsesSrcAttr =
+      BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "srcAttr")
+    val mapFuncUsesDstAttr =
+      BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "dstAttr")
+    val vTableReplicatedValues =
+      if (mapFuncUsesSrcAttr && !mapFuncUsesDstAttr) {
+        val vid2pidSourceAttrOnly = createVid2PidSourceAttrOnly(g.eTable, g.vTable.index)
+        createVTableReplicated(g.vTable, vid2pidSourceAttrOnly, g.localVidMap)
+      } else {
+        g.vTableReplicatedValues
+      }
+
     // Map and preaggregate
-    val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){
+    val preAgg = g.eTable.zipPartitions(g.localVidMap, vTableReplicatedValues){
       (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
       val (_, edgePartition) = edgePartitionIter.next()
       val (_, vidToIndex) = vidToIndexIter.next()
@@ -488,8 +519,12 @@ object GraphImpl {
 
       edgePartition.foreach { e =>
         et.set(e)
-        et.srcAttr = vmap(e.srcId)
-        et.dstAttr = vmap(e.dstId)
+        if (mapFuncUsesSrcAttr) {
+          et.srcAttr = vmap(e.srcId)
+        }
+        if (mapFuncUsesDstAttr) {
+          et.dstAttr = vmap(e.dstId)
+        }
         // TODO(rxin): rewrite the foreach using a simple while loop to speed things up.
         // Also given we are only allowing zero, one, or two messages, we can completely unroll
         // the for loop.
diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index ec548bda16..37fb60c4cc 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -58,6 +58,26 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("aggregateNeighborsSourceAttrOnly") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val n = 3
+      // Create a star graph centered at vertex n + 1, where each vertex's attribute is its degree
+      val star = Graph(sc.parallelize((1 to n).map(x => ((n + 1): Vid, x: Vid))))
+
+      val totalOfInNeighborDegrees = star.aggregateNeighbors(
+        (vid, edge) => {
+          // All edges have the center vertex as the source, which has degree n
+          if (edge.srcAttr != n) {
+            throw new Exception("edge.srcAttr is %d, expected %d".format(edge.srcAttr, n))
+          }
+          Some(edge.srcAttr)
+        },
+        (a: Int, b: Int) => a + b,
+        EdgeDirection.In)
+      assert(totalOfInNeighborDegrees.collect().toSet === (1 to n).map(x => (x, n)).toSet)
+    }
+  }
+
   test("joinVertices") {
     withSpark(new SparkContext("local", "test")) { sc =>
       val vertices = sc.parallelize(Seq[(Vid, String)]((1, "one"), (2, "two"), (3, "three")), 2)
@@ -87,6 +107,6 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       assert(b.zipJoin(c)((id, b, c) => b + c).map(x => x._2).reduce(_+_) === 0)
 
     }
-  } 
-  
+  }
+
 }
-- 
GitLab