From d1ff1b722274de8e03938452d8155f2a26c55f96 Mon Sep 17 00:00:00 2001
From: Ankur Dave <ankurdave@gmail.com>
Date: Sun, 10 Nov 2013 01:51:42 -0800
Subject: [PATCH] Build pid2vid structures only once, in Vid2Pid

---
 .../graph/impl/VTableReplicatedValues.scala   | 12 +-------
 .../org/apache/spark/graph/impl/Vid2Pid.scala | 29 +++++++++++++++++++
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
index 25cd1b8054..fee2d40ee4 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
@@ -1,7 +1,5 @@
 package org.apache.spark.graph.impl
 
-import scala.collection.mutable.ArrayBuilder
-
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.collection.{OpenHashSet, PrimitiveKeyOpenHashMap}
@@ -47,15 +45,7 @@ object VTableReplicatedValues {
       includeSrcAttr: Boolean,
       includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
 
-    // Within each partition of vid2pid, construct a pid2vid mapping
-    val numPartitions = vTable.partitions.size
-    val pid2vid = vid2pid.get(includeSrcAttr, includeDstAttr).mapPartitions { iter =>
-      val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
-      for ((vid, pids) <- iter) {
-        pids.foreach { pid => pid2vidLocal(pid) += vid }
-      }
-      Iterator(pid2vidLocal.map(_.result))
-    }
+    val pid2vid = vid2pid.getPid2Vid(includeSrcAttr, includeDstAttr)
 
     val msgsByPartition = pid2vid.zipPartitions(vTable.index.rdd, vTable.valuesRDD) {
       (pid2vidIter, indexIter, valuesIter) =>
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
index 9bdca7f407..363adbbce9 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
@@ -2,6 +2,7 @@ package org.apache.spark.graph.impl
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.ArrayBuilder
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -20,6 +21,11 @@ class Vid2Pid(
   val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true)
   val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false)
 
+  val pid2VidBothAttrs: RDD[Array[Array[Vid]]] = createPid2Vid(bothAttrs)
+  val pid2VidSrcAttrOnly: RDD[Array[Array[Vid]]] = createPid2Vid(srcAttrOnly)
+  val pid2VidDstAttrOnly: RDD[Array[Array[Vid]]] = createPid2Vid(dstAttrOnly)
+  val pid2VidNoAttrs: RDD[Array[Array[Vid]]] = createPid2Vid(noAttrs)
+
   def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] =
     (includeSrcAttr, includeDstAttr) match {
       case (true, true) => bothAttrs
@@ -28,6 +34,14 @@ class Vid2Pid(
       case (false, false) => noAttrs
     }
 
+  def getPid2Vid(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[Vid]]] =
+    (includeSrcAttr, includeDstAttr) match {
+      case (true, true) => pid2VidBothAttrs
+      case (true, false) => pid2VidSrcAttrOnly
+      case (false, true) => pid2VidDstAttrOnly
+      case (false, false) => pid2VidNoAttrs
+    }
+
   def persist(newLevel: StorageLevel) {
     bothAttrs.persist(newLevel)
     srcAttrOnly.persist(newLevel)
@@ -55,4 +69,19 @@ class Vid2Pid(
       (a: ArrayBuffer[Pid], b: ArrayBuffer[Pid]) => a ++ b)
       .mapValues(a => a.toArray).cache()
   }
+
+  /**
+   * Creates an intermediate pid2vid structure that tells each partition of the
+   * vertex data where it should go.
+   */
+  private def createPid2Vid(vid2pid: VertexSetRDD[Array[Pid]]): RDD[Array[Array[Vid]]] = {
+    val numPartitions = vid2pid.partitions.size
+    vid2pid.mapPartitions { iter =>
+      val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
+      for ((vid, pids) <- iter) {
+        pids.foreach { pid => pid2vidLocal(pid) += vid }
+      }
+      Iterator(pid2vidLocal.map(_.result))
+    }
+  }
 }
-- 
GitLab