diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala index a0e1ad598944abc4ca6cea76e4052e562b20e629..485e49f95e0ca0027e5341d6ac67a9aab6c4269b 100644 --- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala +++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala @@ -77,10 +77,10 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( @transient val vid2pid: Vid2Pid, @transient val localVidMap: RDD[(Pid, VertexIdToIndexMap)], @transient val eTable: RDD[(Pid, EdgePartition[ED])], - @transient val partitionStrategy: PartitionStrategy = RandomVertexCut) + @transient val partitioner: PartitionStrategy) extends Graph[VD, ED] { - def this() = this(null, null, null, null) + def this() = this(null, null, null, null, null) @transient val vTableReplicatedValues: VTableReplicatedValues[VD] = new VTableReplicatedValues(vTable, vid2pid, localVidMap) @@ -97,7 +97,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable) - @transient private val partitioner: PartitionStrategy = partitionStrategy + //@transient private val partitioner: PartitionStrategy = partitionStrategy override def persist(newLevel: StorageLevel): Graph[VD, ED] = { eTable.persist(newLevel) @@ -192,18 +192,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( override def reverse: Graph[VD, ED] = { val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) }, preservesPartitioning = true) - new GraphImpl(vTable, vid2pid, localVidMap, newEtable) + new GraphImpl(vTable, vid2pid, localVidMap, newEtable, partitioner) } override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = { val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data)) - new GraphImpl(newVTable, vid2pid, localVidMap, eTable) + new GraphImpl(newVTable, vid2pid, localVidMap, eTable, partitioner) } override def mapEdges[ED2: ClassManifest](f: Edge[ED] => ED2): Graph[VD, ED2] = { val newETable = eTable.mapPartitions(_.map{ case (pid, epart) => (pid, epart.map(f)) }, preservesPartitioning = true) - new GraphImpl(vTable, vid2pid, localVidMap, newETable) + new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner) } override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = @@ -237,7 +237,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( triplets.filter( t => vpred( t.srcId, t.srcAttr ) && vpred( t.dstId, t.dstAttr ) && epred(t) ) - .map( t => Edge(t.srcId, t.dstId, t.attr) )) + .map( t => Edge(t.srcId, t.dstId, t.attr) ), partitioner) // Construct the Vid2Pid map. Here we assume that the filter operation // behaves deterministically. @@ -245,7 +245,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( val newVid2Pid = new Vid2Pid(newETable, newVTable.index) val newVidMap = createLocalVidMap(newETable) - new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable) + new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable, partitioner) } override def groupEdgeTriplets[ED2: ClassManifest]( @@ -268,8 +268,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( } //TODO(crankshaw) eliminate the need to call createETable - val newETable = createETable(newEdges) - new GraphImpl(vTable, vid2pid, localVidMap, newETable) + val newETable = createETable(newEdges, partitioner) + new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner) } override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ): @@ -284,9 +284,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( .map { case ((src, dst), data) => Edge(src, dst, data) } } // TODO(crankshaw) eliminate the need to call createETable - val newETable = createETable(newEdges) + val newETable = createETable(newEdges, partitioner) - new GraphImpl(vTable, vid2pid, localVidMap, newETable) + new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner) } ////////////////////////////////////////////////////////////////////////////////////////////////// @@ -304,7 +304,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected ( : Graph[VD2, ED] = { ClosureCleaner.clean(updateF) val newVTable = vTable.leftJoin(updates)(updateF) - new GraphImpl(newVTable, vid2pid, localVidMap, eTable) + new GraphImpl(newVTable, vid2pid, localVidMap, eTable, partitioner) } } // end of class GraphImpl @@ -358,16 +358,17 @@ object GraphImpl { val vid2pid = new Vid2Pid(etable, vtable.index) val localVidMap = createLocalVidMap(etable) - new GraphImpl(vtable, vid2pid, localVidMap, etable) + new GraphImpl(vtable, vid2pid, localVidMap, etable, partitionStrategy) } - protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]]) - : RDD[(Pid, EdgePartition[ED])] = { - createETable(edges, RandomVertexCut) - } + // TODO(crankshaw) - can I remove this + //protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]]) + // : RDD[(Pid, EdgePartition[ED])] = { + // createETable(edges, RandomVertexCut) + //} /** * Create the edge table RDD, which is much more efficient for Java heap storage than the @@ -384,10 +385,6 @@ object GraphImpl { val numPartitions = edges.partitions.size edges.map { e => - // Random partitioning based on the source vertex id. - // val part: Pid = edgePartitionFunction1D(e.srcId, e.dstId, numPartitions) - // val part: Pid = edgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt) - //val part: Pid = randomVertexCut(e.srcId, e.dstId, numPartitions) val part: Pid = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions) // Should we be using 3-tuple or an optimized class @@ -449,7 +446,7 @@ object GraphImpl { } Iterator((pid, newEdgePartition)) } - new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable) + new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable, g.partitioner) } def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](