diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index a0e1ad598944abc4ca6cea76e4052e562b20e629..485e49f95e0ca0027e5341d6ac67a9aab6c4269b 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -77,10 +77,10 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     @transient val vid2pid: Vid2Pid,
     @transient val localVidMap: RDD[(Pid, VertexIdToIndexMap)],
     @transient val eTable: RDD[(Pid, EdgePartition[ED])],
-    @transient val partitionStrategy: PartitionStrategy = RandomVertexCut)
+    @transient val partitioner: PartitionStrategy)
   extends Graph[VD, ED] {
 
-  def this() = this(null, null, null, null)
+  def this() = this(null, null, null, null, null)
 
   @transient val vTableReplicatedValues: VTableReplicatedValues[VD] =
     new VTableReplicatedValues(vTable, vid2pid, localVidMap)
@@ -97,7 +97,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)
 
-  @transient private val partitioner: PartitionStrategy = partitionStrategy
+  //@transient private val partitioner: PartitionStrategy = partitionStrategy
 
   override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
     eTable.persist(newLevel)
@@ -192,18 +192,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   override def reverse: Graph[VD, ED] = {
     val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) },
       preservesPartitioning = true)
-    new GraphImpl(vTable, vid2pid, localVidMap, newEtable)
+    new GraphImpl(vTable, vid2pid, localVidMap, newEtable, partitioner)
   }
 
   override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
     val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data))
-    new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
+    new GraphImpl(newVTable, vid2pid, localVidMap, eTable, partitioner)
   }
 
   override def mapEdges[ED2: ClassManifest](f: Edge[ED] => ED2): Graph[VD, ED2] = {
     val newETable = eTable.mapPartitions(_.map{ case (pid, epart) => (pid, epart.map(f)) },
       preservesPartitioning = true)
-    new GraphImpl(vTable, vid2pid, localVidMap, newETable)
+    new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
   }
 
   override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
@@ -237,7 +237,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       triplets.filter(
         t => vpred( t.srcId, t.srcAttr ) && vpred( t.dstId, t.dstAttr ) && epred(t)
         )
-        .map( t => Edge(t.srcId, t.dstId, t.attr) ))
+        .map( t => Edge(t.srcId, t.dstId, t.attr) ), partitioner)
 
     // Construct the Vid2Pid map. Here we assume that the filter operation
     // behaves deterministically.
@@ -245,7 +245,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     val newVid2Pid = new Vid2Pid(newETable, newVTable.index)
     val newVidMap = createLocalVidMap(newETable)
 
-    new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable)
+    new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable, partitioner)
   }
 
   override def groupEdgeTriplets[ED2: ClassManifest](
@@ -268,8 +268,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       }
 
       //TODO(crankshaw) eliminate the need to call createETable
-      val newETable = createETable(newEdges)
-      new GraphImpl(vTable, vid2pid, localVidMap, newETable)
+      val newETable = createETable(newEdges, partitioner)
+      new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
   }
 
   override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
@@ -284,9 +284,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
         .map { case ((src, dst), data) => Edge(src, dst, data) }
       }
       // TODO(crankshaw) eliminate the need to call createETable
-      val newETable = createETable(newEdges)
+      val newETable = createETable(newEdges, partitioner)
 
-      new GraphImpl(vTable, vid2pid, localVidMap, newETable)
+      new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -304,7 +304,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     : Graph[VD2, ED] = {
     ClosureCleaner.clean(updateF)
     val newVTable = vTable.leftJoin(updates)(updateF)
-    new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
+    new GraphImpl(newVTable, vid2pid, localVidMap, eTable, partitioner)
   }
 } // end of class GraphImpl
 
@@ -358,16 +358,17 @@ object GraphImpl {
 
     val vid2pid = new Vid2Pid(etable, vtable.index)
     val localVidMap = createLocalVidMap(etable)
-    new GraphImpl(vtable, vid2pid, localVidMap, etable)
+    new GraphImpl(vtable, vid2pid, localVidMap, etable, partitionStrategy)
   }
 
 
 
 
-  protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
-    : RDD[(Pid, EdgePartition[ED])] = {
-      createETable(edges, RandomVertexCut)
-  }
+  // TODO(crankshaw) - can I remove this
+  //protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
+  //  : RDD[(Pid, EdgePartition[ED])] = {
+  //    createETable(edges, RandomVertexCut)
+  //}
 
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
@@ -384,10 +385,6 @@ object GraphImpl {
     val numPartitions = edges.partitions.size
 
     edges.map { e =>
-      // Random partitioning based on the source vertex id.
-      // val part: Pid = edgePartitionFunction1D(e.srcId, e.dstId, numPartitions)
-      // val part: Pid = edgePartitionFunction2D(e.srcId, e.dstId, numPartitions, ceilSqrt)
-      //val part: Pid = randomVertexCut(e.srcId, e.dstId, numPartitions)
       val part: Pid = partitionStrategy.getPartition(e.srcId, e.dstId, numPartitions)
 
       // Should we be using 3-tuple or an optimized class
@@ -449,7 +446,7 @@ object GraphImpl {
       }
       Iterator((pid, newEdgePartition))
     }
-    new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
+    new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable, g.partitioner)
   }
 
   def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](