From 37a524d91c2eb03dd9e4a24d8af33769a89a78e3 Mon Sep 17 00:00:00 2001
From: Dan Crankshaw <dscrankshaw@gmail.com>
Date: Tue, 19 Nov 2013 16:39:39 -0800
Subject: [PATCH] Addressed code review comments.
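
Converted the PartitionStrategy implementations (RandomVertexCut,
CanonicalRandomVertexCut, EdgePartition1D, EdgePartition2D) from
singleton objects to case classes, so call sites construct instances
(e.g. RandomVertexCut()) and GraphImpl can pattern match on the
concrete strategy type. groupEdgeTriplets and groupEdges now check the
graph's partitioner and throw a SparkException for any strategy other
than CanonicalRandomVertexCut, since their grouping guarantee only
holds when all edges between a pair of vertices land in the same
partition. Also removed the unused Logging mixin from Graph, GraphLab,
and Pregel, and dropped commented-out dead code from GraphImpl.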

---
 .../org/apache/spark/graph/Analytics.scala    | 12 +--
 .../scala/org/apache/spark/graph/Graph.scala  |  3 +-
 .../org/apache/spark/graph/GraphLab.scala     |  3 +-
 .../org/apache/spark/graph/GraphLoader.scala  |  2 +-
 .../spark/graph/PartitionStrategy.scala       |  8 +-
 .../scala/org/apache/spark/graph/Pregel.scala |  3 +-
 .../apache/spark/graph/impl/GraphImpl.scala   | 87 ++++++++++---------
 7 files changed, 59 insertions(+), 59 deletions(-)
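
Reviewer note (not applied by git am): a minimal sketch of the call-site
pattern after this change, assuming only the four case classes introduced
below. The fromName and requireCanonical helpers are hypothetical and
written for illustration; fromName mirrors the partStrategy option parsing
in Analytics.scala, and requireCanonical mirrors the type check added to
GraphImpl (which throws SparkException; the sketch uses
IllegalArgumentException to stay dependency-light).

    import org.apache.spark.graph._

    object PartitionStrategySketch {
      // Hypothetical helper mirroring the partStrategy parsing in
      // Analytics.scala: each strategy is now constructed with ().
      def fromName(name: String): PartitionStrategy = name match {
        case "RandomVertexCut"          => RandomVertexCut()
        case "EdgePartition1D"          => EdgePartition1D()
        case "EdgePartition2D"          => EdgePartition2D()
        case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut()
        case _ =>
          throw new IllegalArgumentException("Invalid Partition Strategy: " + name)
      }

      // Because the strategies are case classes rather than singletons,
      // code can match on the concrete type, as groupEdgeTriplets now does.
      def requireCanonical(p: PartitionStrategy): Unit = p match {
        case _: CanonicalRandomVertexCut => // ok: grouping guarantees hold
        case other => throw new IllegalArgumentException(
          other.getClass.getName + " is incompatible with edge grouping")
      }
    }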

diff --git a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
index 8455a145ff..f542ec6069 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Analytics.scala
@@ -241,7 +241,7 @@ object Analytics extends Logging {
          var outFname = ""
          var numVPart = 4
          var numEPart = 4
-         var partitionStrategy: PartitionStrategy = RandomVertexCut
+         var partitionStrategy: PartitionStrategy = RandomVertexCut()
 
          options.foreach{
            case ("numIter", v) => numIter = v.toInt
@@ -251,11 +251,11 @@ object Analytics extends Logging {
            case ("numVPart", v) => numVPart = v.toInt
            case ("numEPart", v) => numEPart = v.toInt
            case ("partStrategy", v) =>  {
-             v match {
-               case "RandomVertexCut" => partitionStrategy = RandomVertexCut
-               case "EdgePartition1D" => partitionStrategy = EdgePartition1D
-               case "EdgePartition2D" => partitionStrategy = EdgePartition2D
-               case "CanonicalRandomVertexCut" => partitionStrategy = CanonicalRandomVertexCut
+             partitionStrategy = v match {
+               case "RandomVertexCut" => RandomVertexCut()
+               case "EdgePartition1D" => EdgePartition1D()
+               case "EdgePartition2D" => EdgePartition2D()
+               case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut()
                case _ => throw new IllegalArgumentException("Invalid Partition Strategy: " + v)
              }
            }
diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index 6ce3f5d2e7..87667f6958 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -1,7 +1,6 @@
 package org.apache.spark.graph
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -22,7 +21,7 @@ import org.apache.spark.storage.StorageLevel
  * @tparam VD the vertex attribute type
  * @tparam ED the edge attribute type
  */
-abstract class Graph[VD: ClassManifest, ED: ClassManifest] extends Logging {
+abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
 
   /**
    * Get the vertices and their data.
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
index 39dc33acf0..b8503ab7fd 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
@@ -2,12 +2,11 @@ package org.apache.spark.graph
 
 import scala.collection.JavaConversions._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 
 /**
 * This object implements the GraphLab gather-apply-scatter API.
  */
-object GraphLab extends Logging {
+object GraphLab {
 
   /**
    * Execute the GraphLab Gather-Apply-Scatter API
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
index 813f176313..4dc33a02ce 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLoader.scala
@@ -28,7 +28,7 @@ object GraphLoader {
       edgeParser: Array[String] => ED,
       minEdgePartitions: Int = 1,
       minVertexPartitions: Int = 1,
-      partitionStrategy: PartitionStrategy = RandomVertexCut): GraphImpl[Int, ED] = {
+      partitionStrategy: PartitionStrategy = RandomVertexCut()): GraphImpl[Int, ED] = {
 
     // Parse the edge data table
     val edges = sc.textFile(path, minEdgePartitions).flatMap { line =>
diff --git a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
index f7db667e2f..cf65f50657 100644
--- a/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/PartitionStrategy.scala
@@ -50,7 +50,7 @@ sealed trait PartitionStrategy extends Serializable {
  *
  *
  */
-object EdgePartition2D extends PartitionStrategy {
+case class EdgePartition2D() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val ceilSqrtNumParts: Pid = math.ceil(math.sqrt(numParts)).toInt
     val mixingPrime: Vid = 1125899906842597L
@@ -61,7 +61,7 @@ object EdgePartition2D extends PartitionStrategy {
 }
 
 
-object EdgePartition1D extends PartitionStrategy {
+case class EdgePartition1D() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val mixingPrime: Vid = 1125899906842597L
     (math.abs(src) * mixingPrime).toInt % numParts
@@ -73,7 +73,7 @@ object EdgePartition1D extends PartitionStrategy {
 * Assign edges to an arbitrary machine corresponding to a
  * random vertex cut.
  */
-object RandomVertexCut extends PartitionStrategy {
+case class RandomVertexCut() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     math.abs((src, dst).hashCode()) % numParts
   }
@@ -85,7 +85,7 @@ object RandomVertexCut extends PartitionStrategy {
  * function ensures that edges of opposite direction between the same two vertices
  * will end up on the same partition.
  */
-object CanonicalRandomVertexCut extends PartitionStrategy {
+case class CanonicalRandomVertexCut() extends PartitionStrategy {
   override def getPartition(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val lower = math.min(src, dst)
     val higher = math.max(src, dst)
diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
index f3016e6ad3..3b4d3c0df2 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
@@ -1,7 +1,6 @@
 package org.apache.spark.graph
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.Logging
 
 
 /**
@@ -42,7 +41,7 @@ import org.apache.spark.Logging
  * }}}
  *
  */
-object Pregel extends Logging {
+object Pregel {
 
 
   /**
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 7c3d401832..6ad0ce60a7 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -8,6 +8,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.SparkContext._
 import org.apache.spark.HashPartitioner
 import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.SparkException
 
 import org.apache.spark.Partitioner
 import org.apache.spark.graph._
@@ -97,8 +98,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)
 
-  //@transient private val partitioner: PartitionStrategy = partitionStrategy
-
   override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
     eTable.persist(newLevel)
     vid2pid.persist(newLevel)
@@ -250,43 +249,55 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   override def groupEdgeTriplets[ED2: ClassManifest](
     f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
-      val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
-        partIter
-        // TODO(crankshaw) toList requires that the entire edge partition
-        // can fit in memory right now.
-        .toList
-        // groups all ETs in this partition that have the same src and dst
-        // Because all ETs with the same src and dst will live on the same
-        // partition due to the canonicalRandomVertexCut partitioner, this
-        // guarantees that these ET groups will be complete.
-        .groupBy { t: EdgeTriplet[VD, ED] =>  (t.srcId, t.dstId) }
-        .mapValues { ts: List[EdgeTriplet[VD, ED]] => f(ts.toIterator) }
-        .toList
-        .toIterator
-        .map { case ((src, dst), data) => Edge(src, dst, data) }
-        .toIterator
-      }
+      partitioner match {
+        case _: CanonicalRandomVertexCut => {
+          val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
+            partIter
+            // TODO(crankshaw) toList requires that the entire edge partition
+            // can fit in memory right now.
+            .toList
+            // groups all ETs in this partition that have the same src and dst
+            // Because all ETs with the same src and dst will live on the same
+            // partition due to the CanonicalRandomVertexCut partitioner, this
+            // guarantees that these ET groups will be complete.
+            .groupBy { t: EdgeTriplet[VD, ED] =>  (t.srcId, t.dstId) }
+            .mapValues { ts: List[EdgeTriplet[VD, ED]] => f(ts.toIterator) }
+            .toList
+            .toIterator
+            .map { case ((src, dst), data) => Edge(src, dst, data) }
+            .toIterator
+          }
+          //TODO(crankshaw) eliminate the need to call createETable
+          val newETable = createETable(newEdges, partitioner)
+          new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        }
 
-      //TODO(crankshaw) eliminate the need to call createETable
-      val newETable = createETable(newEdges, partitioner)
-      new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        case _ => throw new SparkException(partitioner.getClass.getName
+          + " is incompatible with groupEdgeTriplets")
+      }
   }
 
   override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
     Graph[VD,ED2] = {
+      partitioner match {
+        case _: CanonicalRandomVertexCut => {
+          val newEdges: RDD[Edge[ED2]] = edges.mapPartitions { partIter =>
+            partIter.toList
+            .groupBy { e: Edge[ED] => (e.srcId, e.dstId) }
+            .mapValues { ts => f(ts.toIterator) }
+            .toList
+            .toIterator
+            .map { case ((src, dst), data) => Edge(src, dst, data) }
+          }
+          // TODO(crankshaw) eliminate the need to call createETable
+          val newETable = createETable(newEdges, partitioner)
 
-      val newEdges: RDD[Edge[ED2]] = edges.mapPartitions { partIter =>
-        partIter.toList
-        .groupBy { e: Edge[ED] => (e.srcId, e.dstId) }
-        .mapValues { ts => f(ts.toIterator) }
-        .toList
-        .toIterator
-        .map { case ((src, dst), data) => Edge(src, dst, data) }
-      }
-      // TODO(crankshaw) eliminate the need to call createETable
-      val newETable = createETable(newEdges, partitioner)
+          new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        }
 
-      new GraphImpl(vTable, vid2pid, localVidMap, newETable, partitioner)
+        case _ => throw new SparkException(partitioner.getClass.getName
+          + " is incompatible with groupEdges")
+      }
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -315,7 +326,7 @@ object GraphImpl {
     vertices: RDD[(Vid, VD)],
     edges: RDD[Edge[ED]],
     defaultVertexAttr: VD): GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut)
+    apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a, RandomVertexCut())
   }
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -331,7 +342,7 @@ object GraphImpl {
     edges: RDD[Edge[ED]],
     defaultVertexAttr: VD,
     mergeFunc: (VD, VD) => VD): GraphImpl[VD,ED] = {
-    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut)
+    apply(vertices, edges, defaultVertexAttr, mergeFunc, RandomVertexCut())
   }
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -362,14 +373,6 @@ object GraphImpl {
   }
 
 
-
-
-  // TODO(crankshaw) - can I remove this
-  //protected def createETable[ED: ClassManifest](edges: RDD[Edge[ED]])
-  //  : RDD[(Pid, EdgePartition[ED])] = {
-  //    createETable(edges, RandomVertexCut)
-  //}
-
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
-- 
GitLab