Skip to content
Snippets Groups Projects
Commit 502c5117 authored by Ankur Dave's avatar Ankur Dave
Browse files

Use pid2vid for creating VTableReplicatedValues

parent 53d24a97
No related branches found
No related tags found
No related merge requests found
......@@ -18,6 +18,7 @@ class GraphKryoRegistrator extends KryoRegistrator {
kryo.register(classOf[EdgePartition[Object]])
kryo.register(classOf[BitSet])
kryo.register(classOf[VertexIdToIndexMap])
kryo.register(classOf[VertexAttributeBlock[Object]])
// This avoids a large number of hash table lookups.
kryo.setReferences(false)
}
......
package org.apache.spark.graph.impl
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.collection.{OpenHashSet, PrimitiveKeyOpenHashMap}
import org.apache.spark.graph._
import org.apache.spark.graph.impl.MsgRDDFunctions._
......@@ -34,7 +37,7 @@ class VTableReplicatedValues[VD: ClassManifest](
}
}
/**
 * Pairs an array of vertex ids with an array of vertex attributes.
 *
 * Consumers in this file read `vids(i)` and `attrs(i)` together, so the two arrays are
 * expected to be index-aligned and of equal length.
 *
 * @tparam VD the vertex attribute type
 * @param vids the vertex ids carried by this block
 * @param attrs the attribute for each vertex id, at the same index as its id in `vids`
 */
class VertexAttributeBlock[VD: ClassManifest](val vids: Array[Vid], val attrs: Array[VD])
object VTableReplicatedValues {
protected def createVTableReplicated[VD: ClassManifest](
......@@ -44,13 +47,30 @@ object VTableReplicatedValues {
includeSrcAttr: Boolean,
includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
// Join vid2pid and vTable, generate a shuffle dependency on the joined
// result, and get the shuffle id so we can use it on the slave.
val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) {
// TODO(rxin): reuse VertexBroadcastMessage
(vid, vdata, pids) => pids.iterator.map { pid =>
new VertexBroadcastMsg[VD](pid, vid, vdata)
// Within each partition of vid2pid, construct a pid2vid mapping
val numPartitions = vTable.partitions.size
val pid2vid = vid2pid.get(includeSrcAttr, includeDstAttr).mapPartitions { iter =>
val pid2vidLocal = Array.fill[ArrayBuilder[Vid]](numPartitions)(ArrayBuilder.make[Vid])
for ((vid, pids) <- iter) {
pids.foreach { pid => pid2vidLocal(pid) += vid }
}
Iterator(pid2vidLocal.map(_.result))
}
val msgsByPartition = pid2vid.zipPartitions(vTable.index.rdd, vTable.valuesRDD) {
(pid2vidIter, indexIter, valuesIter) =>
val pid2vid = pid2vidIter.next()
val index = indexIter.next()
val values = valuesIter.next()
val vmap = new PrimitiveKeyOpenHashMap(index, values._1)
// Send each partition the vertex attributes it wants
val output = new Array[(Pid, VertexAttributeBlock[VD])](pid2vid.size)
for (pid <- 0 until pid2vid.size) {
val block = new VertexAttributeBlock(pid2vid(pid), pid2vid(pid).map(vid => vmap(vid)))
output(pid) = (pid, block)
}
output.iterator
}.partitionBy(localVidMap.partitioner.get).cache()
localVidMap.zipPartitions(msgsByPartition){
......@@ -59,14 +79,16 @@ object VTableReplicatedValues {
assert(!mapIter.hasNext)
// Populate the vertex array using the vidToIndex map
val vertexArray = new Array[VD](vidToIndex.capacity)
for (msg <- msgsIter) {
val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
vertexArray(ind) = msg.data
for ((_, block) <- msgsIter) {
for (i <- 0 until block.vids.size) {
val vid = block.vids(i)
val attr = block.attrs(i)
val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
vertexArray(ind) = attr
}
}
Iterator((pid, vertexArray))
}.cache()
// @todo assert edge table has partitioner
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment