diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index b6758b0501e7cab80c1fa1243d878c7400a3b9f2..87fb9dcd2e05f0f6410baa459c9723da98443a7b 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -4,6 +4,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkContext._ import org.apache.spark.Partitioner @@ -24,7 +25,8 @@ import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._ * The Iterator type returned when constructing edge triplets */ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest]( - val vmap: VertexHashMap[VD], + val vidToIndex: VertexIdToIndexMap, + val vertexArray: Array[VD], val edgePartition: EdgePartition[ED]) extends Iterator[EdgeTriplet[VD, ED]] { private var pos = 0 @@ -34,10 +36,10 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest]( override def next() = { et.srcId = edgePartition.srcIds(pos) // assert(vmap.containsKey(e.src.id)) - et.srcAttr = vmap.get(et.srcId) + et.srcAttr = vertexArray(vidToIndex(et.srcId)) et.dstId = edgePartition.dstIds(pos) // assert(vmap.containsKey(e.dst.id)) - et.dstAttr = vmap.get(et.dstId) + et.dstAttr = vertexArray(vidToIndex(et.dstId)) //println("Iter called: " + pos) et.attr = edgePartition.data(pos) pos += 1 @@ -50,10 +52,10 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest]( for (i <- (0 until edgePartition.size)) { currentEdge.srcId = edgePartition.srcIds(i) // assert(vmap.containsKey(e.src.id)) - currentEdge.srcAttr = vmap.get(currentEdge.srcId) + currentEdge.srcAttr = vertexArray(vidToIndex(currentEdge.srcId)) currentEdge.dstId = edgePartition.dstIds(i) // assert(vmap.containsKey(e.dst.id)) - currentEdge.dstAttr = vmap.get(currentEdge.dstId) + currentEdge.dstAttr = vertexArray(vidToIndex(currentEdge.dstId)) currentEdge.attr = edgePartition.data(i) lb += currentEdge } @@ -63,17 +65,18 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest]( object EdgeTripletBuilder { def makeTriplets[VD: ClassManifest, ED: ClassManifest]( - vTableReplicated: IndexedRDD[Pid, VertexHashMap[VD]], + vTableReplicationMap: IndexedRDD[Pid, VertexIdToIndexMap], + vTableReplicatedValues: IndexedRDD[Pid, Array[VD]], eTable: IndexedRDD[Pid, EdgePartition[ED]]): RDD[EdgeTriplet[VD, ED]] = { - val iterFun = (iter: Iterator[(Pid, (VertexHashMap[VD], EdgePartition[ED]))]) => { - val (pid, (vmap, edgePartition)) = iter.next() + val iterFun = (iter: Iterator[(Pid, ((VertexIdToIndexMap, Array[VD]), EdgePartition[ED]))]) => { + val (pid, ((vidToIndex, vertexArray), edgePartition)) = iter.next() //assert(iter.hasNext == false) // Return an iterator that looks up the hash map to find matching // vertices for each edge. - new EdgeTripletIterator(vmap, edgePartition) + new EdgeTripletIterator(vidToIndex, vertexArray, edgePartition) } ClosureCleaner.clean(iterFun) - vTableReplicated.zipJoinRDD(eTable) + vTableReplicationMap.zipJoin(vTableReplicatedValues).zipJoinRDD(eTable) .mapPartitions( iterFun ) // end of map partition } @@ -93,13 +96,16 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( /** - * The vTableReplicated is a version of the vertex data after it is - * replicated. + * (vTableReplicationMap: IndexedRDD[Pid, VertexIdToIndexMap]) is a version of the + * vertex data after it is replicated. Within each partition, it holds a map + * from vertex ID to the index where that vertex's attribute is stored. This + * index refers to an array in the same partition in vTableReplicatedValues. + * + * (vTableReplicatedValues: IndexedRDD[Pid, Array[VD]]) holds the vertex data + * and is arranged as described above. */ - @transient val vTableReplicated: IndexedRDD[Pid, VertexHashMap[VD]] = - createVTableReplicated(vTable, vid2pid, eTable) - - + @transient val (vTableReplicationMap, vTableReplicatedValues) = + createVTableReplicated(vTable, vid2pid, eTable) /** Return a RDD of vertices. */ @@ -114,7 +120,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( /** Return a RDD that brings edges with its source and destination vertices together. */ @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = - EdgeTripletBuilder.makeTriplets(vTableReplicated, eTable) + EdgeTripletBuilder.makeTriplets(vTableReplicationMap, vTableReplicatedValues, eTable) // { @@ -136,8 +142,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( eTable.cache() vid2pid.cache() vTable.cache() - /** @todo should we cache the replicated data? */ - vTableReplicated.cache() this } @@ -179,15 +183,15 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - val newETable = eTable.join(vTableReplicated).mapValues{ - case (edgePartition, vmap) => - val et = new EdgeTriplet[VD, ED] - edgePartition.map{e => - et.set(e) - et.srcAttr = vmap(e.srcId) - et.dstAttr = vmap(e.dstId) - f(et) - } + val newETable = eTable.zipJoin(vTableReplicationMap).zipJoin(vTableReplicatedValues).mapValues{ + case ((edgePartition, vidToIndex), vertexArray) => + val et = new EdgeTriplet[VD, ED] + edgePartition.map{e => + et.set(e) + et.srcAttr = vertexArray(vidToIndex(e.srcId)) + et.dstAttr = vertexArray(vidToIndex(e.dstId)) + f(et) + } }.asInstanceOf[IndexedRDD[Pid, EdgePartition[ED2]]] new GraphImpl(vTable, vid2pid, newETable) } @@ -344,20 +348,20 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( ClosureCleaner.clean(reduceFunc) // Map and preaggregate - val preAgg = vTableReplicated.zipJoinRDD(eTable).flatMap{ - case (pid, (vmap, edgePartition)) => + val preAgg = vTableReplicationMap.zipJoin(vTableReplicatedValues).zipJoinRDD(eTable).flatMap{ + case (pid, ((vidToIndex, vertexArray), edgePartition)) => val aggMap = new VertexHashMap[A] val et = new EdgeTriplet[VD, ED] edgePartition.foreach{e => et.set(e) - et.srcAttr = vmap(e.srcId) - et.dstAttr = vmap(e.dstId) + et.srcAttr = vertexArray(vidToIndex(e.srcId)) + et.dstAttr = vertexArray(vidToIndex(e.dstId)) mapFunc(et).foreach{case (vid, a) => if(aggMap.containsKey(vid)) { - aggMap.put(vid, reduceFunc(aggMap.get(vid), a)) - } else { aggMap.put(vid, a) } - } + aggMap.put(vid, reduceFunc(aggMap.get(vid), a)) + } else { aggMap.put(vid, a) } } + } // Return the aggregate map aggMap.long2ObjectEntrySet().fastIterator().map{ entry => (entry.getLongKey(), entry.getValue()) @@ -475,21 +479,37 @@ object GraphImpl { protected def createVTableReplicated[VD: ClassManifest, ED: ClassManifest]( vTable: IndexedRDD[Vid, VD], vid2pid: IndexedRDD[Vid, Array[Pid]], eTable: IndexedRDD[Pid, EdgePartition[ED]]): - IndexedRDD[Pid, VertexHashMap[VD]] = { + (IndexedRDD[Pid, VertexIdToIndexMap], IndexedRDD[Pid, Array[VD]]) = { // Join vid2pid and vTable, generate a shuffle dependency on the joined // result, and get the shuffle id so we can use it on the slave. - vTable.zipJoinRDD(vid2pid) - .flatMap { case (vid, (vdata, pids)) => - pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) } - } - .partitionBy(eTable.partitioner.get) //@todo assert edge table has partitioner - .mapPartitionsWithIndex( (pid, iter) => { - // Build the hashmap for each partition - val vmap = new VertexHashMap[VD] - for( msg <- iter ) { vmap.put(msg.data._1, msg.data._2) } - Array((pid, vmap)).iterator - }, preservesPartitioning = true) - .indexed(eTable.index) + val msgsByPartition = + vTable.zipJoinRDD(vid2pid) + .flatMap { case (vid, (vdata, pids)) => + pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) } + } + .partitionBy(eTable.partitioner.get).cache() + // @todo assert edge table has partitioner + + val vTableReplicationMap: IndexedRDD[Pid, VertexIdToIndexMap] = + msgsByPartition.mapPartitionsWithIndex( (pid, iter) => { + val vidToIndex = new VertexIdToIndexMap + var i = 0 + for (msg <- iter) { + vidToIndex.put(msg.data._1, i) + } + Array((pid, vidToIndex)).iterator + }, preservesPartitioning = true).indexed(eTable.index) + + val vTableReplicatedValues: IndexedRDD[Pid, Array[VD]] = + msgsByPartition.mapPartitionsWithIndex( (pid, iter) => { + val vertexArray = ArrayBuilder.make[VD] + for (msg <- iter) { + vertexArray += msg.data._2 + } + Array((pid, vertexArray.result)).iterator + }, preservesPartitioning = true).indexed(eTable.index) + + (vTableReplicationMap, vTableReplicatedValues) } diff --git a/graph/src/main/scala/org/apache/spark/graph/package.scala b/graph/src/main/scala/org/apache/spark/graph/package.scala index 474ace520f44052642a97846bb74dad4c088b0d5..47d5acb9e76ba55a9bac9972c2f382a4e575cec0 100644 --- a/graph/src/main/scala/org/apache/spark/graph/package.scala +++ b/graph/src/main/scala/org/apache/spark/graph/package.scala @@ -8,6 +8,8 @@ package object graph { type VertexHashMap[T] = it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap[T] type VertexSet = it.unimi.dsi.fastutil.longs.LongOpenHashSet type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList + // @todo replace with rxin's fast hashmap + type VertexIdToIndexMap = scala.collection.mutable.HashMap[Vid, Int] /** * Return the default null-like value for a data type T.