From 46b195253ecb54ff8a202a53773fc9388b2c753c Mon Sep 17 00:00:00 2001
From: "Joseph E. Gonzalez" <joseph.e.gonzalez@gmail.com>
Date: Tue, 22 Oct 2013 15:01:49 -0700
Subject: [PATCH] Adding some additional graph generators to support unit
 testing of the analytics package.

---
 .../spark/graph/util/GraphGenerators.scala    | 46 ++++++++++++++++++-
 1 file changed, 45 insertions(+), 1 deletion(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/util/GraphGenerators.scala b/graph/src/main/scala/org/apache/spark/graph/util/GraphGenerators.scala
index 895c65c14c..1bbcce5039 100644
--- a/graph/src/main/scala/org/apache/spark/graph/util/GraphGenerators.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/util/GraphGenerators.scala
@@ -236,7 +236,51 @@ object GraphGenerators {
     }
   }
 
-}
+
+
+  /**
+   * Create a `rows` by `cols` grid graph in which each vertex is connected
+   * to its row+1 and col+1 neighbors.  Vertex ids are assigned in row-major
+   * order.
+   *
+   * @param sc the spark context in which to construct the graph
+   * @param rows the number of rows
+   * @param cols the number of columns
+   *
+   * @return A graph whose vertices carry their (row, column) position as
+   * the attribute and whose edges all have the value 1.0.
+   */
+  def gridGraph(sc: SparkContext, rows: Int, cols: Int): Graph[(Int, Int), Double] = {
+    // Convert a (row, column) address into a vertex id (row-major order)
+    def sub2ind(r: Int, c: Int): Vid = r * cols + c
+
+    val vertices: RDD[(Vid, (Int, Int))] =
+      sc.parallelize(0 until rows).flatMap(r => (0 until cols).map(c => (sub2ind(r, c), (r, c))))
+    val edges: RDD[Edge[Double]] =
+      vertices.flatMap { case (vid, (r, c)) =>
+        (if (r + 1 < rows) { Seq((sub2ind(r, c), sub2ind(r + 1, c))) } else { Seq.empty }) ++
+        (if (c + 1 < cols) { Seq((sub2ind(r, c), sub2ind(r, c + 1))) } else { Seq.empty })
+      }.map { case (src, dst) => Edge(src, dst, 1.0) }
+    Graph(vertices, edges)
+  } // end of gridGraph
+
+  /**
+   * Create a star graph with vertex 0 being the center.
+   *
+   * @param sc the spark context in which to construct the graph
+   * @param nverts the number of vertices in the star
+   *
+   * @return A star graph containing `nverts` vertices, with vertex 0
+   * being the center vertex.
+   */
+  def starGraph(sc: SparkContext, nverts: Int): Graph[Int, Int] = {
+    val edges: RDD[(Vid, Vid)] = sc.parallelize(1 until nverts).map(vid => (vid, 0))
+    Graph(edges, false)
+  } // end of starGraph
+
+
+
+} // end of GraphGenerators
 
 
 
-- 
GitLab
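
For reference, a minimal usage sketch of the two new generators. This is not part of the patch: the driver object, the local SparkContext setup, and the printed counts are illustrative assumptions layered on the org.apache.spark.graph API as it stands at this commit (in particular, it assumes `Graph.edges` is an RDD whose count can be taken).

    import org.apache.spark.SparkContext
    import org.apache.spark.graph.util.GraphGenerators

    object GraphGeneratorsExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "graph-generators-example")

        // 3x4 grid: 12 vertices in row-major order, each connected to its
        // row+1 and col+1 neighbors where those exist.
        val grid = GraphGenerators.gridGraph(sc, rows = 3, cols = 4)
        // rows*(cols-1) + (rows-1)*cols = 9 + 8 = 17 edges expected
        println("grid edges: " + grid.edges.count)

        // Star on 10 vertices: vertices 1..9 each point at the center vertex 0.
        val star = GraphGenerators.starGraph(sc, 10)
        println("star edges: " + star.edges.count)  // 9 edges expected

        sc.stop()
      }
    }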