From d1ff1b722274de8e03938452d8155f2a26c55f96 Mon Sep 17 00:00:00 2001 From: Ankur Dave <ankurdave@gmail.com> Date: Sun, 10 Nov 2013 01:51:42 -0800 Subject: [PATCH] Build pid2vid structures only once, in Vid2Pid --- .../graph/impl/VTableReplicatedValues.scala | 12 +------- .../org/apache/spark/graph/impl/Vid2Pid.scala | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala index 25cd1b8054..fee2d40ee4 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala @@ -1,7 +1,5 @@ package org.apache.spark.graph.impl -import scala.collection.mutable.ArrayBuilder - import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.util.collection.{OpenHashSet, PrimitiveKeyOpenHashMap} @@ -47,15 +45,7 @@ object VTableReplicatedValues { includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = { - // Within each partition of vid2pid, construct a pid2vid mapping - val numPartitions = vTable.partitions.size - val pid2vid = vid2pid.get(includeSrcAttr, includeDstAttr).mapPartitions { iter => - val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid]) - for ((vid, pids) <- iter) { - pids.foreach { pid => pid2vidLocal(pid) += vid } - } - Iterator(pid2vidLocal.map(_.result)) - } + val pid2vid = vid2pid.getPid2Vid(includeSrcAttr, includeDstAttr) val msgsByPartition = pid2vid.zipPartitions(vTable.index.rdd, vTable.valuesRDD) { (pid2vidIter, indexIter, valuesIter) => diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala index 9bdca7f407..363adbbce9 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala @@ -2,6 +2,7 @@ package org.apache.spark.graph.impl import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -20,6 +21,11 @@ class Vid2Pid( val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true) val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false) + val pid2VidBothAttrs: RDD[Array[Array[Vid]]] = createPid2Vid(bothAttrs) + val pid2VidSrcAttrOnly: RDD[Array[Array[Vid]]] = createPid2Vid(srcAttrOnly) + val pid2VidDstAttrOnly: RDD[Array[Array[Vid]]] = createPid2Vid(dstAttrOnly) + val pid2VidNoAttrs: RDD[Array[Array[Vid]]] = createPid2Vid(noAttrs) + def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] = (includeSrcAttr, includeDstAttr) match { case (true, true) => bothAttrs @@ -28,6 +34,14 @@ class Vid2Pid( case (false, false) => noAttrs } + def getPid2Vid(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[Array[Array[Vid]]] = + (includeSrcAttr, includeDstAttr) match { + case (true, true) => pid2VidBothAttrs + case (true, false) => pid2VidSrcAttrOnly + case (false, true) => pid2VidDstAttrOnly + case (false, false) => pid2VidNoAttrs + } + def persist(newLevel: StorageLevel) { bothAttrs.persist(newLevel) srcAttrOnly.persist(newLevel) @@ -55,4 +69,19 @@ class Vid2Pid( (a: ArrayBuffer[Pid], b: ArrayBuffer[Pid]) => a ++ b) .mapValues(a => a.toArray).cache() } + + /** + * Creates an intermediate pid2vid structure that tells each partition of the + * vertex data where it should go. + */ + private def createPid2Vid(vid2pid: VertexSetRDD[Array[Pid]]): RDD[Array[Array[Vid]]] = { + val numPartitions = vid2pid.partitions.size + vid2pid.mapPartitions { iter => + val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid]) + for ((vid, pids) <- iter) { + pids.foreach { pid => pid2vidLocal(pid) += vid } + } + Iterator(pid2vidLocal.map(_.result)) + } + } } -- GitLab