diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
index 8455a145ffffdcc8d9e1e288877e618ee4788a0e..f542ec60695afb0eb77f4422d903f431b062e0b4 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
@@ -241,7 +241,7 @@ object Analytics extends Logging {
          var outFname = ""
          var numVPart = 4
          var numEPart = 4
-         var partitionStrategy: PartitionStrategy = RandomVertexCut
+         var partitionStrategy: PartitionStrategy = RandomVertexCut()
 
          options.foreach{
            case ("numIter", v) => numIter = v.toInt
@@ -251,11 +251,11 @@ object Analytics extends Logging {
            case ("numVPart", v) => numVPart = v.toInt
            case ("numEPart", v) => numEPart = v.toInt
            case ("partStrategy", v) =>  {
-             v match {
-               case "RandomVertexCut" => partitionStrategy = RandomVertexCut
-               case "EdgePartition1D" => partitionStrategy = EdgePartition1D
-               case "EdgePartition2D" => partitionStrategy = EdgePartition2D
-               case "CanonicalRandomVertexCut" => partitionStrategy = CanonicalRandomVertexCut
+             partitionStrategy = v match {
+               case "RandomVertexCut" => RandomVertexCut()
+               case "EdgePartition1D" => EdgePartition1D()
+               case "EdgePartition2D" => EdgePartition2D()
+               case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut()
                case _ => throw new IllegalArgumentException("Invalid Partition Strategy: " + v)
              }
            }
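
The rewritten option handling above treats the `match` as an expression that yields the new `PartitionStrategy` value, rather than mutating the variable inside each branch. The same dispatch could be factored into a small helper; the sketch below is illustrative only (the `parsePartitionStrategy` name is not part of this patch):

```scala
// Hypothetical helper mirroring the option parsing above.
// The name and placement are illustrative, not part of the patch.
def parsePartitionStrategy(v: String): PartitionStrategy = v match {
  case "RandomVertexCut"          => RandomVertexCut()
  case "EdgePartition1D"          => EdgePartition1D()
  case "EdgePartition2D"          => EdgePartition2D()
  case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut()
  case _ => throw new IllegalArgumentException("Invalid Partition Strategy: " + v)
}
```
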
diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index 6ce3f5d2e75c05b21a88c9151d0b9c1f5b10c6de..87667f69586efa7c7b469fb0b3e00ded280b8572 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -1,7 +1,6 @@
 package org.apache.spark.graph
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -22,7 +21,7 @@ import org.apache.spark.storage.StorageLevel
  * @tparam VD the vertex attribute type
  * @tparam ED the edge attribute type
  */
-abstract class Graph[VD: ClassManifest, ED: ClassManifest] extends Logging {
+abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
 
   /**
    * Get the vertices and their data.
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
index 39dc33acf0f6e954290f6bda1e3f5dba4f3b96e9..b8503ab7fdb6c266161f31b3308bf7fcd554c1f8 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
@@ -2,12 +2,11 @@ package org.apache.spark.graph
 
 import scala.collection.JavaConversions._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 
 /**
 * This object implements the GraphLab gather-apply-scatter API.
  */
-object GraphLab extends Logging {
+object GraphLab {
 
   /**
    * Execute the GraphLab Gather-Apply-Scatter API
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
index 813f176313a80d27d3052be29f9a26120cec97c5..4dc33a02ceacce4c6d32900e0103956a47dd53bd 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
@@ -28,7 +28,7 @@ object GraphLoader {
       edgeParser: Array[String] => ED,
       minEdgePartitions: Int = 1,
       minVertexPartitions: Int = 1,
-      partitionStrategy: PartitionStrategy = RandomVertexCut): GraphImpl[Int, ED] = {
+      partitionStrategy: PartitionStrategy = RandomVertexCut()): GraphImpl[Int, ED] = {
 
     // Parse the edge data table
     val edges = sc.textFile(path, minEdgePartitions).flatMap { line =>
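
Callers of the loader now pass a strategy instance, e.g. `RandomVertexCut()`, instead of the old singleton object. A hedged usage sketch, assuming the enclosing method is `GraphLoader.textFile` (its name lies outside the hunk) and an illustrative edge file path:

```scala
// Assumes sc is an existing SparkContext; the path and the unit edge
// attribute are illustrative. minEdgePartitions/minVertexPartitions
// fall back to their defaults.
val graph = GraphLoader.textFile(
  sc,
  "hdfs:///data/edges.tsv",
  edgeParser = (fields: Array[String]) => 1,
  partitionStrategy = CanonicalRandomVertexCut())
```
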
diff --git a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
index f7db667e2ff0553c233d7e23c9704fe2267d4f76..cf65f5065786a8a6c26930f24088c8c406808350 100644
--- a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
@@ -50,7 +50,7 @@ sealed trait PartitionStrategy extends Serializable {
  *
  *
  */
-object EdgePartition2D extends PartitionStrategy {
+case class EdgePartition2D() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val ceilSqrtNumParts: Pid = math.ceil(math.sqrt(numParts)).toInt
     val mixingPrime: Vid = 1125899906842597L
@@ -61,7 +61,7 @@ object EdgePartition2D extends PartitionStrategy {
 }
 
 
-object EdgePartition1D extends PartitionStrategy {
+case class EdgePartition1D() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val mixingPrime: Vid = 1125899906842597L
     (math.abs(src) * mixingPrime).toInt % numParts
@@ -73,7 +73,7 @@ object EdgePartition1D extends PartitionStrategy {
  * Assign edges to an arbitrary machine corresponding to a
  * random vertex cut.
  */
-object RandomVertexCut extends PartitionStrategy {
+case class RandomVertexCut() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     math.abs((src, dst).hashCode()) % numParts
   }
@@ -85,7 +85,7 @@ object RandomVertexCut extends PartitionStrategy {
  * function ensures that edges of opposite direction between the same two vertices
  * will end up on the same partition.
  */
-object CanonicalRandomVertexCut extends PartitionStrategy {
+case class CanonicalRandomVertexCut() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val lower = math.min(src, dst)
     val higher = math.max(src, dst)
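
Turning these singletons into case classes means a strategy is now a value that can be instantiated, serialized, and, crucially, matched on by type at runtime (as `GraphImpl.groupEdgeTriplets` does below). Behavior is unchanged; a minimal sketch of the canonical property, assuming `Vid` is an alias for `Long`:

```scala
// Both orientations of an edge hash to the same partition under
// CanonicalRandomVertexCut, because it hashes the (min, max) vertex pair.
// Plain RandomVertexCut hashes (src, dst) as-is, so the two directions
// of an edge may land in different partitions.
val numParts = 8
val canonical = CanonicalRandomVertexCut()
assert(canonical.getPartition(1L, 7L, numParts) ==
       canonical.getPartition(7L, 1L, numParts))
```
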
diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
index f3016e6ad3ee6b164cd542bc93bf141f449a02f5..3b4d3c0df2a51ca178194f4623ff02e4baa960ec 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
@@ -1,7 +1,6 @@
 package org.apache.spark.graph
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 
 
 /**
@@ -42,7 +41,7 @@ import org.apache.spark.Logging
  * }}}
  *
  */
-object Pregel extends Logging {
+object Pregel {
 
 
   /**
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 7c3d40183248214df960121777c0eece658ca667..6ad0ce60a7adda07537431c3d4ba05a40cce795b 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.SparkContext._
 import org.apache.spark.HashPartitioner
 import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.SparkException
 
 import org.apache.spark.Partitioner
 import org.apache.spark.graph._
@@ -97,8 +98,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)
 
-  //@transient private val partitioner: PartitionStrategy = partitionStrategy
-
   override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
     eTable.persist(newLevel)
     vid2pid.persist(newLevel)
@@ -250,43 +249,55 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   override def groupEdgeTriplets[ED2: ClassManifest](
     f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
-      val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
-        partIter
-        // TODO(crankshaw) toList requires that the entire edge partition
-        // can fit in memory right now.
-        .toList
-        // groups all ETs in this partition that have the same src and dst
-        // Because all ETs with the same src and dst will live on the same
-        // partition due to the canonicalRandomVertexCut partitioner, this
-        // guarantees that these ET groups will be complete.
-        .groupBy { t: EdgeTriplet[VD, ED] =>  (t.srcId, t.dstId) }
-        .mapValues { ts: List[EdgeTriplet[VD, ED]] => f(ts.toIterator) }
-        .toList
-        .toIterator
-        .map { case ((src, dst), data) => Edge(src, dst, data) }
-        .toIterator
-      }
+      partitioner match {
+        case _: CanonicalRandomVertexCut => {
+          val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
+            partIter
+            // TODO(crankshaw) toList requires that the entire edge partition
+            // can fit in memory right now.
+            .toList
+            // groups all ETs in this partition that have the same src and dst
+            // Because all ETs with the same src and dst will live on the same
+            // partition due to the canonicalRandomVertexCut partitioner, this
+            // guarantees that these ET groups will be complete.
+            .groupBy { t: EdgeTriplet[VD, ED] =>  (t.srcId, t.dstId) }
+            .mapValues { ts: List[EdgeTriplet[VD, ED]] => f(ts.toIterator) }
+            .toList
+            .toIterator
+            .map { case ((src, dst), data) => Edge(src, dst, data) }
+            .toIterator
+          }
+          //TODO(crankshaw) eliminate the need to call createETable
+          val newETable = createETable(newEdges, partitioner)
+          new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        }
 
-      //TODO(crankshaw) eliminate the need to call createETable
-      val newETable = createETable(newEdges, partitioner)
-      new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        case _ => throw new SparkException(partitioner.getClass.getName
+          + " is incompatible with groupEdgeTriplets")
+      }
   }
 
   override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
     Graph[VD,ED2] = {
+      partitioner match {
+        case _: CanonicalRandomVertexCut => {
+          val newEdges: RDD[Edge[ED2]] = edges.mapPartitions { partIter =>
+            partIter.toList
+            .groupBy { e: Edge[ED] => (e.srcId, e.dstId) }
+            .mapValues { ts => f(ts.toIterator) }
+            .toList
+            .toIterator
+            .map { case ((src, dst), data) => Edge(src, dst, data) }
+          }
+          // TODO(crankshaw) eliminate the need to call createETable
+          val newETable = createETable(newEdges, partitioner)
 
-      val newEdges: RDD[Edge[ED2]] = edges.mapPartitions { partIter =>
-        partIter.toList
-        .groupBy { e: Edge[ED] => (e.srcId, e.dstId) }
-        .mapValues { ts => f(ts.toIterator) }
-        .toList
-        .toIterator
-        .map { case ((src, dst), data) => Edge(src, dst, data) }
-      }
-      // TODO(crankshaw) eliminate the need to call createETable
-      val newETable = createETable(newEdges, partitioner)
+          new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        }
 
-      new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        case _ => throw new SparkException(partitioner.getClass.getName
+          + " is incompatible with groupEdges")
+      }
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -315,7 +326,7 @@ object GraphImpl {
     vertices: RDD[(Vid, VD)],
     edges: RDD[Edge[ED]],
     defaultVertexAttr: VD): GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut)
+    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut())
   }
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -331,7 +342,7 @@ object GraphImpl {
     edges: RDD[Edge[ED]],
     defaultVertexAttr: VD,
     mergeFunc: (VD, VD) => VD): GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut)
+    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut())
   }
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -362,14 +373,6 @@ object GraphImpl {
   }
 
 
-
-
-  // TODO(crankshaw) - can I remove this
-  //protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
-  //  : RDD[(Pid, EdgePartition[ED])] = {
-  //    createETable(edges, RandomVertexCut)
-  //}
-
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
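
With the strategy carried as a value, `groupEdgeTriplets` and `groupEdges` can now verify their precondition instead of silently assuming it: merging duplicate edges is only complete when both copies of an edge are co-located, which `CanonicalRandomVertexCut` alone guarantees. A hedged end-to-end sketch, assuming `vertices: RDD[(Vid, Int)]` and `edges: RDD[Edge[Int]]` already exist:

```scala
// Build the graph with the one strategy groupEdges accepts; constructing
// it with any other PartitionStrategy makes groupEdges/groupEdgeTriplets
// throw a SparkException. Arguments are positional: vertices, edges,
// defaultVertexAttr, mergeFunc, partitionStrategy.
val graph = GraphImpl(vertices, edges, 0,
  (a: Int, b: Int) => a, CanonicalRandomVertexCut())

// Collapse parallel edges, keeping their multiplicity (illustrative combiner).
val deduped: Graph[Int, Int] = graph.groupEdges(iter => iter.size)
```
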