Commit 099977fd authored by Joey

Merge pull request #26 from ankurdave/split-vTableReplicated

Great work!
parents af8e4618 bf19aac2
@@ -4,6 +4,7 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.ArrayBuilder
 import org.apache.spark.SparkContext._
 import org.apache.spark.Partitioner
@@ -24,7 +25,8 @@ import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._
  * The Iterator type returned when constructing edge triplets
  */
 class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
-  val vmap: VertexHashMap[VD],
+  val vidToIndex: VertexIdToIndexMap,
+  val vertexArray: Array[VD],
   val edgePartition: EdgePartition[ED]) extends Iterator[EdgeTriplet[VD, ED]] {

   private var pos = 0
@@ -34,10 +36,10 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
   override def next() = {
     et.srcId = edgePartition.srcIds(pos)
     // assert(vmap.containsKey(e.src.id))
-    et.srcAttr = vmap.get(et.srcId)
+    et.srcAttr = vertexArray(vidToIndex(et.srcId))
     et.dstId = edgePartition.dstIds(pos)
     // assert(vmap.containsKey(e.dst.id))
-    et.dstAttr = vmap.get(et.dstId)
+    et.dstAttr = vertexArray(vidToIndex(et.dstId))
     //println("Iter called: " + pos)
     et.attr = edgePartition.data(pos)
     pos += 1
@@ -50,10 +52,10 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
     for (i <- (0 until edgePartition.size)) {
       currentEdge.srcId = edgePartition.srcIds(i)
       // assert(vmap.containsKey(e.src.id))
-      currentEdge.srcAttr = vmap.get(currentEdge.srcId)
+      currentEdge.srcAttr = vertexArray(vidToIndex(currentEdge.srcId))
      currentEdge.dstId = edgePartition.dstIds(i)
       // assert(vmap.containsKey(e.dst.id))
-      currentEdge.dstAttr = vmap.get(currentEdge.dstId)
+      currentEdge.dstAttr = vertexArray(vidToIndex(currentEdge.dstId))
       currentEdge.attr = edgePartition.data(i)
       lb += currentEdge
     }
@@ -63,17 +65,18 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
 object EdgeTripletBuilder {
   def makeTriplets[VD: ClassManifest, ED: ClassManifest](
-    vTableReplicated: IndexedRDD[Pid, VertexHashMap[VD]],
+    vTableReplicationMap: IndexedRDD[Pid, VertexIdToIndexMap],
+    vTableReplicatedValues: IndexedRDD[Pid, Array[VD]],
     eTable: IndexedRDD[Pid, EdgePartition[ED]]): RDD[EdgeTriplet[VD, ED]] = {
-    val iterFun = (iter: Iterator[(Pid, (VertexHashMap[VD], EdgePartition[ED]))]) => {
-      val (pid, (vmap, edgePartition)) = iter.next()
+    val iterFun = (iter: Iterator[(Pid, ((VertexIdToIndexMap, Array[VD]), EdgePartition[ED]))]) => {
+      val (pid, ((vidToIndex, vertexArray), edgePartition)) = iter.next()
       //assert(iter.hasNext == false)
       // Return an iterator that looks up the hash map to find matching
       // vertices for each edge.
-      new EdgeTripletIterator(vmap, edgePartition)
+      new EdgeTripletIterator(vidToIndex, vertexArray, edgePartition)
     }
     ClosureCleaner.clean(iterFun)
-    vTableReplicated.zipJoinRDD(eTable)
+    vTableReplicationMap.zipJoin(vTableReplicatedValues).zipJoinRDD(eTable)
       .mapPartitions( iterFun ) // end of map partition
   }
@@ -93,13 +96,16 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   /**
-   * The vTableReplicated is a version of the vertex data after it is
-   * replicated.
+   * (vTableReplicationMap: IndexedRDD[Pid, VertexIdToIndexMap]) is a version of the
+   * vertex data after it is replicated. Within each partition, it holds a map
+   * from vertex ID to the index where that vertex's attribute is stored. This
+   * index refers to an array in the same partition in vTableReplicatedValues.
+   *
+   * (vTableReplicatedValues: IndexedRDD[Pid, Array[VD]]) holds the vertex data
+   * and is arranged as described above.
    */
-  @transient val vTableReplicated: IndexedRDD[Pid, VertexHashMap[VD]] =
+  @transient val (vTableReplicationMap, vTableReplicatedValues) =
     createVTableReplicated(vTable, vid2pid, eTable)

   /** Return a RDD of vertices. */
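
After this split, a vertex attribute is found in two steps: the per-partition map gives an array index, and the parallel array at that index gives the attribute. A minimal sketch of that lookup, assuming the Vid and VertexIdToIndexMap aliases from the graph package object (the attrOf helper is hypothetical and not part of this commit):

    // Hypothetical helper illustrating the two-step lookup that replaces the
    // old vmap.get(vid); vidToIndex and vertexArray are the per-partition
    // halves of the former vTableReplicated hash map.
    def attrOf[VD](vid: Vid, vidToIndex: VertexIdToIndexMap, vertexArray: Array[VD]): VD =
      vertexArray(vidToIndex(vid))
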
@@ -114,7 +120,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   /** Return a RDD that brings edges with its source and destination vertices together. */
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
-    EdgeTripletBuilder.makeTriplets(vTableReplicated, eTable)
+    EdgeTripletBuilder.makeTriplets(vTableReplicationMap, vTableReplicatedValues, eTable)

   // {
@@ -136,8 +142,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     eTable.cache()
     vid2pid.cache()
     vTable.cache()
-    /** @todo should we cache the replicated data? */
-    vTableReplicated.cache()
     this
   }
@@ -179,15 +183,15 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2):
     Graph[VD, ED2] = {
-    val newETable = eTable.join(vTableReplicated).mapValues{
-      case (edgePartition, vmap) =>
+    val newETable = eTable.zipJoin(vTableReplicationMap).zipJoin(vTableReplicatedValues).mapValues{
+      case ((edgePartition, vidToIndex), vertexArray) =>
         val et = new EdgeTriplet[VD, ED]
         edgePartition.map{e =>
           et.set(e)
-          et.srcAttr = vmap(e.srcId)
-          et.dstAttr = vmap(e.dstId)
+          et.srcAttr = vertexArray(vidToIndex(e.srcId))
+          et.dstAttr = vertexArray(vidToIndex(e.dstId))
           f(et)
         }
     }.asInstanceOf[IndexedRDD[Pid, EdgePartition[ED2]]]
     new GraphImpl(vTable, vid2pid, newETable)
   }
@@ -344,20 +348,20 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     ClosureCleaner.clean(reduceFunc)
     // Map and preaggregate
-    val preAgg = vTableReplicated.zipJoinRDD(eTable).flatMap{
-      case (pid, (vmap, edgePartition)) =>
+    val preAgg = vTableReplicationMap.zipJoin(vTableReplicatedValues).zipJoinRDD(eTable).flatMap{
+      case (pid, ((vidToIndex, vertexArray), edgePartition)) =>
         val aggMap = new VertexHashMap[A]
         val et = new EdgeTriplet[VD, ED]
         edgePartition.foreach{e =>
           et.set(e)
-          et.srcAttr = vmap(e.srcId)
-          et.dstAttr = vmap(e.dstId)
+          et.srcAttr = vertexArray(vidToIndex(e.srcId))
+          et.dstAttr = vertexArray(vidToIndex(e.dstId))
           mapFunc(et).foreach{case (vid, a) =>
             if(aggMap.containsKey(vid)) {
               aggMap.put(vid, reduceFunc(aggMap.get(vid), a))
             } else { aggMap.put(vid, a) }
           }
         }
         // Return the aggregate map
         aggMap.long2ObjectEntrySet().fastIterator().map{
           entry => (entry.getLongKey(), entry.getValue())
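
The hunk above keeps the per-partition pre-aggregation pattern: messages destined for the same vertex are combined locally with reduceFunc before anything is shuffled. A small standalone sketch of that pattern, using a plain mutable.HashMap in place of the fastutil-backed VertexHashMap (preAggregate is a hypothetical helper, not part of this commit):

    import scala.collection.mutable

    // Combine all (vid, value) messages within one partition before shuffling.
    def preAggregate[A](msgs: Iterator[(Long, A)], reduceFunc: (A, A) => A): Iterator[(Long, A)] = {
      val aggMap = new mutable.HashMap[Long, A]
      for ((vid, a) <- msgs) {
        aggMap(vid) = if (aggMap.contains(vid)) reduceFunc(aggMap(vid), a) else a
      }
      aggMap.iterator
    }
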
@@ -475,21 +479,37 @@ object GraphImpl {
   protected def createVTableReplicated[VD: ClassManifest, ED: ClassManifest](
       vTable: IndexedRDD[Vid, VD], vid2pid: IndexedRDD[Vid, Array[Pid]],
       eTable: IndexedRDD[Pid, EdgePartition[ED]]):
-    IndexedRDD[Pid, VertexHashMap[VD]] = {
+    (IndexedRDD[Pid, VertexIdToIndexMap], IndexedRDD[Pid, Array[VD]]) = {
     // Join vid2pid and vTable, generate a shuffle dependency on the joined
     // result, and get the shuffle id so we can use it on the slave.
-    vTable.zipJoinRDD(vid2pid)
-      .flatMap { case (vid, (vdata, pids)) =>
-        pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) }
-      }
-      .partitionBy(eTable.partitioner.get) //@todo assert edge table has partitioner
-      .mapPartitionsWithIndex( (pid, iter) => {
-        // Build the hashmap for each partition
-        val vmap = new VertexHashMap[VD]
-        for( msg <- iter ) { vmap.put(msg.data._1, msg.data._2) }
-        Array((pid, vmap)).iterator
-      }, preservesPartitioning = true)
-      .indexed(eTable.index)
+    val msgsByPartition =
+      vTable.zipJoinRDD(vid2pid)
+        .flatMap { case (vid, (vdata, pids)) =>
+          pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) }
+        }
+        .partitionBy(eTable.partitioner.get).cache()
+    // @todo assert edge table has partitioner
+
+    val vTableReplicationMap: IndexedRDD[Pid, VertexIdToIndexMap] =
+      msgsByPartition.mapPartitionsWithIndex( (pid, iter) => {
+        val vidToIndex = new VertexIdToIndexMap
+        var i = 0
+        for (msg <- iter) {
+          vidToIndex.put(msg.data._1, i)
+        }
+        Array((pid, vidToIndex)).iterator
+      }, preservesPartitioning = true).indexed(eTable.index)
+
+    val vTableReplicatedValues: IndexedRDD[Pid, Array[VD]] =
+      msgsByPartition.mapPartitionsWithIndex( (pid, iter) => {
+        val vertexArray = ArrayBuilder.make[VD]
+        for (msg <- iter) {
+          vertexArray += msg.data._2
+        }
+        Array((pid, vertexArray.result)).iterator
+      }, preservesPartitioning = true).indexed(eTable.index)
+
+    (vTableReplicationMap, vTableReplicatedValues)
   }
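
For the two structures to stay consistent, both passes over msgsByPartition must see the messages in the same order, and the index recorded in vidToIndex must advance in step with the values appended to vertexArray. A standalone sketch of that construction for a single partition's messages (buildLocalTables is a hypothetical helper, not part of this commit):

    import scala.collection.mutable
    import scala.collection.mutable.ArrayBuilder

    // Build the per-partition index map and the parallel attribute array
    // from the replicated (vid, attr) messages of one partition.
    def buildLocalTables[VD: ClassManifest](msgs: Seq[(Long, VD)]): (mutable.HashMap[Long, Int], Array[VD]) = {
      val vidToIndex = new mutable.HashMap[Long, Int]
      val vertexArray = ArrayBuilder.make[VD]
      var i = 0
      for ((vid, attr) <- msgs) {
        vidToIndex.put(vid, i)   // record where this vertex's attribute will live
        vertexArray += attr      // append the attribute at that same position
        i += 1
      }
      (vidToIndex, vertexArray.result)
    }
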
@@ -8,6 +8,8 @@ package object graph {
   type VertexHashMap[T] = it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap[T]
   type VertexSet = it.unimi.dsi.fastutil.longs.LongOpenHashSet
   type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList
+  // @todo replace with rxin's fast hashmap
+  type VertexIdToIndexMap = scala.collection.mutable.HashMap[Vid, Int]

   /**
    * Return the default null-like value for a data type T.