Skip to content
Snippets Groups Projects
Commit 502c5117 authored by Ankur Dave
Browse files

Use pid2vid for creating VTableReplicatedValues

parent 53d24a97
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,7 @@ class GraphKryoRegistrator extends KryoRegistrator { ...@@ -18,6 +18,7 @@ class GraphKryoRegistrator extends KryoRegistrator {
kryo.register(classOf[EdgePartition[Object]]) kryo.register(classOf[EdgePartition[Object]])
kryo.register(classOf[BitSet]) kryo.register(classOf[BitSet])
kryo.register(classOf[VertexIdToIndexMap]) kryo.register(classOf[VertexIdToIndexMap])
kryo.register(classOf[VertexAttributeBlock[Object]])
// This avoids a large number of hash table lookups. // This avoids a large number of hash table lookups.
kryo.setReferences(false) kryo.setReferences(false)
} }
......
package org.apache.spark.graph.impl package org.apache.spark.graph.impl
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.collection.{OpenHashSet, PrimitiveKeyOpenHashMap}
import org.apache.spark.graph._ import org.apache.spark.graph._
import org.apache.spark.graph.impl.MsgRDDFunctions._ import org.apache.spark.graph.impl.MsgRDDFunctions._
...@@ -34,7 +37,7 @@ class VTableReplicatedValues[VD: ClassManifest]( ...@@ -34,7 +37,7 @@ class VTableReplicatedValues[VD: ClassManifest](
} }
} }
/**
 * A batch of vertex attributes stored as two parallel arrays.
 *
 * Where this class is built in `createVTableReplicated`, `attrs` is produced by
 * mapping over `vids`, so `attrs(i)` is the attribute of the vertex whose id is
 * `vids(i)`; the consumer walks both arrays by the same index. The constructor
 * itself does not check that the two arrays have equal length.
 *
 * @param vids  vertex ids contained in this block
 * @param attrs vertex attributes, index-aligned with `vids`
 * @tparam VD   the vertex attribute type
 */
class VertexAttributeBlock[VD: ClassManifest](val vids: Array[Vid], val attrs: Array[VD])
object VTableReplicatedValues { object VTableReplicatedValues {
protected def createVTableReplicated[VD: ClassManifest]( protected def createVTableReplicated[VD: ClassManifest](
...@@ -44,13 +47,30 @@ object VTableReplicatedValues { ...@@ -44,13 +47,30 @@ object VTableReplicatedValues {
includeSrcAttr: Boolean, includeSrcAttr: Boolean,
includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = { includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
// Join vid2pid and vTable, generate a shuffle dependency on the joined // Within each partition of vid2pid, construct a pid2vid mapping
// result, and get the shuffle id so we can use it on the slave. val numPartitions = vTable.partitions.size
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) { val pid2vid = vid2pid.get(includeSrcAttr, includeDstAttr).mapPartitions { iter =>
// TODO(rxin): reuse VertexBroadcastMessage val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
(vid, vdata, pids) => pids.iterator.map { pid => for ((vid, pids) <- iter) {
new VertexBroadcastMsg[VD](pid, vid, vdata) pids.foreach { pid => pid2vidLocal(pid) += vid }
} }
Iterator(pid2vidLocal.map(_.result))
}
val msgsByPartition = pid2vid.zipPartitions(vTable.index.rdd, vTable.valuesRDD) {
(pid2vidIter, indexIter, valuesIter) =>
val pid2vid = pid2vidIter.next()
val index = indexIter.next()
val values = valuesIter.next()
val vmap = new PrimitiveKeyOpenHashMap(index, values._1)
// Send each partition the vertex attributes it wants
val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size)
for (pid <- 0 until pid2vid.size) {
val block = new VertexAttributeBlock(pid2vid(pid), pid2vid(pid).map(vid => vmap(vid)))
output(pid) = (pid, block)
}
output.iterator
}.partitionBy(localVidMap.partitioner.get).cache() }.partitionBy(localVidMap.partitioner.get).cache()
localVidMap.zipPartitions(msgsByPartition){ localVidMap.zipPartitions(msgsByPartition){
...@@ -59,14 +79,16 @@ object VTableReplicatedValues { ...@@ -59,14 +79,16 @@ object VTableReplicatedValues {
assert(!mapIter.hasNext) assert(!mapIter.hasNext)
// Populate the vertex array using the vidToIndex map // Populate the vertex array using the vidToIndex map
val vertexArray = new Array[VD](vidToIndex.capacity) val vertexArray = new Array[VD](vidToIndex.capacity)
for (msg <- msgsIter) { for ((_, block) <- msgsIter) {
val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK for (i <- 0 until block.vids.size) {
vertexArray(ind) = msg.data val vid = block.vids(i)
val attr = block.attrs(i)
val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
vertexArray(ind) = attr
}
} }
Iterator((pid, vertexArray)) Iterator((pid, vertexArray))
}.cache() }.cache()
// @todo assert edge table has partitioner
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment