From 80abc2807825d69b0f7a5e374eb6e6442332f400 Mon Sep 17 00:00:00 2001
From: Ankur Dave <ankurdave@gmail.com>
Date: Wed, 6 Nov 2013 22:50:30 -0800
Subject: [PATCH] Optimize mrTriplets for source-attr-only mapF using bytecode
 inspection

---
 .../apache/spark/graph/impl/GraphImpl.scala | 41 +++++++++++++++++--
 .../org/apache/spark/graph/GraphSuite.scala | 24 ++++++++++-
 2 files changed, 60 insertions(+), 5 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 0d7546b575..64fdb10831 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -12,6 +12,7 @@ import org.apache.spark.util.ClosureCleaner
 import org.apache.spark.graph._
 import org.apache.spark.graph.impl.GraphImpl._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
+import org.apache.spark.graph.util.BytecodeUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
@@ -384,6 +385,22 @@ object GraphImpl {
       .mapValues(a => a.toArray).cache()
   }
 
+  protected def createVid2PidSourceAttrOnly[ED: ClassManifest](
+      eTable: RDD[(Pid, EdgePartition[ED])],
+      vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
+    val preAgg = eTable.mapPartitions { iter =>
+      val (pid, edgePartition) = iter.next()
+      val vSet = new VertexSet
+      edgePartition.foreach(e => {vSet.add(e.srcId)})
+      vSet.iterator.map { vid => (vid.toLong, pid) }
+    }
+    VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
+      (p: Pid) => ArrayBuffer(p),
+      (ab: ArrayBuffer[Pid], p:Pid) => {ab.append(p); ab},
+      (a: ArrayBuffer[Pid], b: ArrayBuffer[Pid]) => a ++ b)
+      .mapValues(a => a.toArray).cache()
+  }
+
   protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
     RDD[(Pid, VertexIdToIndexMap)] = {
     eTable.mapPartitions( _.map{ case (pid, epart) =>
@@ -468,8 +485,22 @@ object GraphImpl {
     ClosureCleaner.clean(mapFunc)
     ClosureCleaner.clean(reduceFunc)
 
+    // For each vertex, replicate its attribute only to partitions where it is
+    // in the relevant position in an edge.
+    val mapFuncUsesSrcAttr =
+      BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "srcAttr")
+    val mapFuncUsesDstAttr =
+      BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "dstAttr")
+    val vTableReplicatedValues =
+      if (mapFuncUsesSrcAttr && !mapFuncUsesDstAttr) {
+        val vid2pidSourceAttrOnly = createVid2PidSourceAttrOnly(g.eTable, g.vTable.index)
+        createVTableReplicated(g.vTable, vid2pidSourceAttrOnly, g.localVidMap)
+      } else {
+        g.vTableReplicatedValues
+      }
+
     // Map and preaggregate
-    val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){
+    val preAgg = g.eTable.zipPartitions(g.localVidMap, vTableReplicatedValues){
       (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
       val (_, edgePartition) = edgePartitionIter.next()
       val (_, vidToIndex) = vidToIndexIter.next()
@@ -488,8 +519,12 @@ object GraphImpl {
 
       edgePartition.foreach { e =>
         et.set(e)
-        et.srcAttr = vmap(e.srcId)
-        et.dstAttr = vmap(e.dstId)
+        if (mapFuncUsesSrcAttr) {
+          et.srcAttr = vmap(e.srcId)
+        }
+        if (mapFuncUsesDstAttr) {
+          et.dstAttr = vmap(e.dstId)
+        }
         // TODO(rxin): rewrite the foreach using a simple while loop to speed things up.
         // Also given we are only allowing zero, one, or two messages, we can completely unroll
         // the for loop.
diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index ec548bda16..37fb60c4cc 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -58,6 +58,26 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("aggregateNeighborsSourceAttrOnly") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val n = 3
+      // Create a star graph where the degree of each vertex is its attribute
+      val star = Graph(sc.parallelize((1 to n).map(x => ((n + 1): Vid, x: Vid))))
+
+      val totalOfInNeighborDegrees = star.aggregateNeighbors(
+        (vid, edge) => {
+          // All edges have the center vertex as the source, which has degree n
+          if (edge.srcAttr != n) {
+            throw new Exception("edge.srcAttr is %d, expected %d".format(edge.srcAttr, n))
+          }
+          Some(edge.srcAttr)
+        },
+        (a: Int, b: Int) => a + b,
+        EdgeDirection.In)
+      assert(totalOfInNeighborDegrees.collect().toSet === (1 to n).map(x => (x, n)).toSet)
+    }
+  }
+
   test("joinVertices") {
     withSpark(new SparkContext("local", "test")) { sc =>
       val vertices = sc.parallelize(Seq[(Vid, String)]((1, "one"), (2, "two"), (3, "three")), 2)
@@ -87,6 +107,6 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       assert(b.zipJoin(c)((id, b, c) => b + c).map(x => x._2).reduce(_+_) === 0)
     }
 
-  }
-
+  }
+
 }
-- 
GitLab
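A note for readers (not part of the patch): the fast path above applies whenever the user's map function provably never reads edge.dstAttr. Below is a minimal driver sketch showing a closure that qualifies, assuming the graph package API at this commit (as in the test above, Graph(rawEdges) assigns each vertex its degree as its attribute); SourceAttrOnlyExample is a hypothetical name, not part of the patch.

    import org.apache.spark.SparkContext
    import org.apache.spark.graph._

    object SourceAttrOnlyExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "example")
        val n = 3
        // Star graph: vertex n + 1 points at vertices 1..n.
        val star = Graph(sc.parallelize((1 to n).map(x => ((n + 1): Vid, x: Vid))))

        // This closure reads only edge.srcAttr, never edge.dstAttr, so the
        // bytecode check routes replication through createVid2PidSourceAttrOnly
        // and destination attributes are never shipped to edge partitions.
        val sums = star.aggregateNeighbors(
          (vid, edge) => Some(edge.srcAttr),
          (a: Int, b: Int) => a + b,
          EdgeDirection.In)

        sums.collect().foreach(println)
        sc.stop()
      }
    }

The detection itself is BytecodeUtils.invokedMethod(closure, targetClass, methodName), which inspects the closure's compiled class file rather than running it. The following is a sketch of that idea, assuming the ASM bytecode library is on the classpath; ClosureInspector is a hypothetical name, and unlike the real BytecodeUtils it scans only the closure's own class without following transitively invoked methods.

    import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes}
    import scala.collection.mutable

    object ClosureInspector {
      // Returns true if any method of the closure's class directly invokes
      // `methodName` on `targetClass`.
      def invokesMethod(closure: AnyRef, targetClass: Class[_], methodName: String): Boolean = {
        val invoked = mutable.Set.empty[(String, String)]
        val resource = closure.getClass.getName.replace('.', '/') + ".class"
        val in = closure.getClass.getClassLoader.getResourceAsStream(resource)
        require(in != null, "could not load bytecode for " + resource)
        try {
          new ClassReader(in).accept(new ClassVisitor(Opcodes.ASM4) {
            override def visitMethod(access: Int, name: String, desc: String,
                signature: String, exceptions: Array[String]): MethodVisitor =
              new MethodVisitor(Opcodes.ASM4) {
                // Record the (owner, name) of every INVOKE* instruction.
                override def visitMethodInsn(
                    opcode: Int, owner: String, name: String, desc: String) {
                  invoked += ((owner, name))
                }
              }
          }, 0)
        } finally {
          in.close()
        }
        invoked.contains((targetClass.getName.replace('.', '/'), methodName))
      }
    }

Because the inspection is static, a closure that merely mentions dstAttr on a dead branch still disables the optimization; that conservatism is what makes it safe to skip replicating destination attributes.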