diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
index c1513a00453cf8ab2fdd92f19f74839678fc07be..27529573eaa1273f7c31a2a4cfe822d792da45e7 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala
@@ -45,17 +45,6 @@ object Analytics extends Logging {
     }
     val options = mutable.Map(optionsList: _*)
 
-    def pickPartitioner(v: String): PartitionStrategy = {
-      // TODO: Use reflection rather than listing all the partitioning strategies here.
-      v match {
-        case "RandomVertexCut" => RandomVertexCut
-        case "EdgePartition1D" => EdgePartition1D
-        case "EdgePartition2D" => EdgePartition2D
-        case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut
-        case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v)
-      }
-    }
-
     val conf = new SparkConf()
       .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
       .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
@@ -66,7 +55,7 @@ object Analytics extends Logging {
       sys.exit(1)
     }
     val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy")
-      .map(pickPartitioner(_))
+      .map(PartitionStrategy.fromString(_))
     val edgeStorageLevel = options.remove("edgeStorageLevel")
       .map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY)
     val vertexStorageLevel = options.remove("vertexStorageLevel")
@@ -106,7 +95,7 @@ object Analytics extends Logging {
 
         if (!outFname.isEmpty) {
           logWarning("Saving pageranks of pages to " + outFname)
-          pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname)
+          pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname)
         }
 
         sc.stop()
@@ -128,7 +117,7 @@ object Analytics extends Logging {
         val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
 
         val cc = ConnectedComponents.run(graph)
-        println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct())
+        println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct())
         sc.stop()
 
       case "triangles" =>
@@ -146,7 +135,7 @@ object Analytics extends Logging {
           minEdgePartitions = numEPart,
           edgeStorageLevel = edgeStorageLevel,
           vertexStorageLevel = vertexStorageLevel)
-        // TriangleCount requires the graph to be partitioned
+          // TriangleCount requires the graph to be partitioned
           .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache()
         val triangles = TriangleCount.run(graph)
         println("Triangles: " + triangles.vertices.map {