diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index 7d2a586037530a27da6499e1c46a4b6485010286..09a1af63a6713f3b6cc94553c59e03ae770ed39a 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -227,12 +227,11 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * }}}
    *
    */
-  def aggregateNeighbors[VD2: ClassManifest](
-      mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[VD2],
-      mergeFunc: (VD2, VD2) => VD2,
+  def aggregateNeighbors[A: ClassManifest](
+      mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[A],
+      mergeFunc: (A, A) => A,
       direction: EdgeDirection)
-    : RDD[(Vid, VD2)]
-
+    : Graph[(VD, Option[A]), ED]
 
   /**
    * This function is used to compute a statistic for the neighborhood of each
@@ -276,12 +275,12 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * @todo Should this return a graph with the new vertex values?
    *
    */
-  def aggregateNeighbors[VD2: ClassManifest](
-      mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[VD2],
-      reduceFunc: (VD2, VD2) => VD2,
-      default: VD2, // Should this be a function or a value?
+  def aggregateNeighbors[A: ClassManifest](
+      mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[A],
+      reduceFunc: (A, A) => A,
+      default: A, // Should this be a function or a value?
       direction: EdgeDirection)
-    : RDD[(Vid, VD2)]
+    : Graph[(VD, Option[A]), ED]
 
 
   /**
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
index 1dba813e91cd8bf40ce561e93b0b050b2d19afb6..01f24a13024c7344fc1050baf84a6ba815d5c48f 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphLab.scala
@@ -44,13 +44,13 @@ object GraphLab {
 
 
     // Add an active attribute to all vertices to track convergence.
-    var activeGraph = graph.mapVertices {
+    var activeGraph: Graph[(Boolean, VD), ED] = graph.mapVertices {
       case Vertex(id, data) => (true, data)
     }.cache()
 
     // The gather function wrapper strips the active attribute and
     // only invokes the gather function on active vertices
-    def gather(vid: Vid, e: EdgeTriplet[(Boolean, VD), ED]) = {
+    def gather(vid: Vid, e: EdgeTriplet[(Boolean, VD), ED]): Option[A] = {
       if (e.vertex(vid).data._1) {
         val edge = new EdgeTriplet[VD,ED]
         edge.src = Vertex(e.src.id, e.src.data._2)
@@ -64,14 +64,15 @@ object GraphLab {
 
     // The apply function wrapper strips the vertex of the active attribute
     // and only invokes the apply function on active vertices
-    def apply(v: Vertex[(Boolean, VD)], accum: Option[A]) = {
-      if (v.data._1) (true, applyFunc(Vertex(v.id, v.data._2), accum))
-      else (false, v.data._2)
+    def apply(v: Vertex[((Boolean, VD), Option[A])]): (Boolean, VD) = {
+      val ((active, vData), accum) = v.data
+      if (active) (true, applyFunc(Vertex(v.id, vData), accum))
+      else (false, vData)
     }
 
     // The scatter function wrapper strips the vertex of the active attribute
     // and only invokes the scatter function on active vertices
-    def scatter(rawVid: Vid, e: EdgeTriplet[(Boolean, VD), ED]) = {
+    def scatter(rawVid: Vid, e: EdgeTriplet[(Boolean, VD), ED]): Option[Boolean] = {
       val vid = e.otherVertex(rawVid).id
       if (e.vertex(vid).data._1) {
         val edge = new EdgeTriplet[VD,ED]
@@ -88,24 +89,31 @@ object GraphLab {
     }
 
     // Used to set the active status of vertices for the next round
-    def applyActive(v: Vertex[(Boolean, VD)], accum: Option[Boolean]) =
-      (accum.getOrElse(false), v.data._2)
+    def applyActive(v: Vertex[((Boolean, VD), Option[Boolean])]): (Boolean, VD) = {
+      val ((prevActive, vData), newActive) = v.data
+      (newActive.getOrElse(false), vData)
+    }
 
     // Main Loop ---------------------------------------------------------------------
     var i = 0
     var numActive = activeGraph.numVertices
     while (i < numIter && numActive > 0) {
 
-      val accUpdates: RDD[(Vid, A)] =
+      val gathered: Graph[((Boolean, VD), Option[A]), ED] =
         activeGraph.aggregateNeighbors(gather, mergeFunc, gatherDirection)
 
-      activeGraph = activeGraph.leftJoinVertices(accUpdates, apply).cache()
+      val applied: Graph[(Boolean, VD), ED] = gathered.mapVertices(apply).cache()
+
+      activeGraph = applied.cache()
 
       // Scatter is basically a gather in the opposite direction so we reverse the edge direction
-      val activeVertices: RDD[(Vid, Boolean)] =
+      // activeGraph: Graph[(Boolean, VD), ED]
+      val scattered: Graph[((Boolean, VD), Option[Boolean]), ED] =
         activeGraph.aggregateNeighbors(scatter, _ || _, scatterDirection.reverse)
+      val newActiveGraph: Graph[(Boolean, VD), ED] =
+        scattered.mapVertices(applyActive)
 
-      activeGraph = activeGraph.leftJoinVertices(activeVertices, applyActive).cache()
+      activeGraph = newActiveGraph.cache()
 
       numActive = activeGraph.vertices.map(v => if (v.data._1) 1 else 0).reduce(_ + _)
       println("Number active vertices: " + numActive)
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala b/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala
index 8de96680b80865d6a1975f7f5d87d98b99d2fe52..9e8cc0a6d52a7ec7d9ef37b5d289463977b37b38 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphOps.scala
@@ -9,22 +9,29 @@ class GraphOps[VD: ClassManifest, ED: ClassManifest](g: Graph[VD, ED]) {
 
   lazy val numVertices: Long = g.vertices.count()
 
-  lazy val inDegrees: RDD[(Vid, Int)] = {
-    g.aggregateNeighbors((vid, edge) => Some(1), _+_, EdgeDirection.In)
-  }
+  lazy val inDegrees: RDD[(Vid, Int)] = degreesRDD(EdgeDirection.In)
 
-  lazy val outDegrees: RDD[(Vid, Int)] = {
-    g.aggregateNeighbors((vid, edge) => Some(1), _+_, EdgeDirection.Out)
-  }
+  lazy val outDegrees: RDD[(Vid, Int)] = degreesRDD(EdgeDirection.Out)
 
-  lazy val degrees: RDD[(Vid, Int)] = {
-    g.aggregateNeighbors((vid, edge) => Some(1), _+_, EdgeDirection.Both)
-  }
+  lazy val degrees: RDD[(Vid, Int)] = degreesRDD(EdgeDirection.Both)
 
   def collectNeighborIds(edgeDirection: EdgeDirection) : RDD[(Vid, Array[Vid])] = {
-    g.aggregateNeighbors(
+    val graph: Graph[(VD, Option[Array[Vid]]), ED] = g.aggregateNeighbors(
       (vid, edge) => Some(Array(edge.otherVertex(vid).id)),
       (a, b) => a ++ b,
       edgeDirection)
+    graph.vertices.map(v => {
+      val (_, neighborIds) = v.data
+      (v.id, neighborIds.getOrElse(Array()))
+    })
+  }
+
+  private def degreesRDD(edgeDirection: EdgeDirection): RDD[(Vid, Int)] = {
+    val degreeGraph: Graph[(VD, Option[Int]), ED] =
+      g.aggregateNeighbors((vid, edge) => Some(1), _+_, edgeDirection)
+    degreeGraph.vertices.map(v => {
+      val (_, degree) = v.data
+      (v.id, degree.getOrElse(0))
+    })
   }
 }
diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
index 27b75a7988013e3e206bf067daff4c68916d9975..09bcc67c8ced9c384cc5fe759856f4c8f087ca1c 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
@@ -19,18 +19,25 @@ object Pregel {
 
     def mapF(vid: Vid, edge: EdgeTriplet[VD,ED]) = sendMsg(edge.otherVertex(vid).id, edge)
 
-    def runProg(v: Vertex[VD], msg: Option[A]): VD = {
-      if (msg.isEmpty) v.data else vprog(v, msg.get)
+    def runProg(vertexWithMsgs: Vertex[(VD, Option[A])]): VD = {
+      val (vData, msg) = vertexWithMsgs.data
+      val v = Vertex(vertexWithMsgs.id, vData)
+      msg match {
+        case Some(m) => vprog(v, m)
+        case None => v.data
+      }
     }
 
-    var msgs: RDD[(Vid, A)] = g.vertices.map{ v => (v.id, initialMsg) }
+    var graphWithMsgs: Graph[(VD, Option[A]), ED] =
+      g.mapVertices(v => (v.data, Some(initialMsg)))
 
     while (i < numIter) {
-      g = g.leftJoinVertices(msgs, runProg).cache()
-      msgs = g.aggregateNeighbors(mapF, mergeMsg, EdgeDirection.In)
+      val newGraph: Graph[VD, ED] = graphWithMsgs.mapVertices(runProg).cache()
+      graphWithMsgs = newGraph.aggregateNeighbors(mapF, mergeMsg, EdgeDirection.In)
       i += 1
     }
-    g
+    graphWithMsgs.mapVertices(vertexWithMsgs => vertexWithMsgs.data match {
+      case (vData, _) => vData
+    })
   }
-
 }
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 68ac9f724c3bb0f4acbd8e88d3a8261cf789321b..e397293a3d63d623556b471e372f1280d0a18c32 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -122,112 +122,118 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   // Lower level transformation methods
   //////////////////////////////////////////////////////////////////////////////////////////////////
 
-  override def aggregateNeighbors[VD2: ClassManifest](
-      mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[VD2],
-      reduceFunc: (VD2, VD2) => VD2,
-      default: VD2,
+  override def aggregateNeighbors[A: ClassManifest](
+      mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[A],
+      reduceFunc: (A, A) => A,
+      default: A,
       gatherDirection: EdgeDirection)
-    : RDD[(Vid, VD2)] = {
+    : Graph[(VD, Option[A]), ED] = {
 
     ClosureCleaner.clean(mapFunc)
     ClosureCleaner.clean(reduceFunc)
 
     val newVTable = vTableReplicated.mapPartitions({ part =>
-        part.map { v => (v._1, MutableTuple2(v._2, Option.empty[VD2])) }
+        part.map { v => (v._1, MutableTuple2(v._2, Option.empty[A])) }
       }, preservesPartitioning = true)
 
-    new EdgeTripletRDD[MutableTuple2[VD, Option[VD2]], ED](newVTable, eTable)
-      .mapPartitions { part =>
-        val (vmap, edges) = part.next()
-        val edgeSansAcc = new EdgeTriplet[VD, ED]()
-        edgeSansAcc.src = new Vertex[VD]
-        edgeSansAcc.dst = new Vertex[VD]
-        edges.foreach { e: EdgeTriplet[MutableTuple2[VD, Option[VD2]], ED] =>
-          edgeSansAcc.data = e.data
-          edgeSansAcc.src.data = e.src.data._1
-          edgeSansAcc.dst.data = e.dst.data._1
-          edgeSansAcc.src.id = e.src.id
-          edgeSansAcc.dst.id = e.dst.id
-          if (gatherDirection == EdgeDirection.In || gatherDirection == EdgeDirection.Both) {
-            e.dst.data._2 =
-              if (e.dst.data._2.isEmpty) {
-                mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
-              } else {
-                val tmp = mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
-                if (!tmp.isEmpty) Some(reduceFunc(e.dst.data._2.get, tmp.get)) else e.dst.data._2
-              }
+    val newVertices: RDD[(Vid, A)] =
+      new EdgeTripletRDD[MutableTuple2[VD, Option[A]], ED](newVTable, eTable)
+        .mapPartitions { part =>
+          val (vmap, edges) = part.next()
+          val edgeSansAcc = new EdgeTriplet[VD, ED]()
+          edgeSansAcc.src = new Vertex[VD]
+          edgeSansAcc.dst = new Vertex[VD]
+          edges.foreach { e: EdgeTriplet[MutableTuple2[VD, Option[A]], ED] =>
+            edgeSansAcc.data = e.data
+            edgeSansAcc.src.data = e.src.data._1
+            edgeSansAcc.dst.data = e.dst.data._1
+            edgeSansAcc.src.id = e.src.id
+            edgeSansAcc.dst.id = e.dst.id
+            if (gatherDirection == EdgeDirection.In || gatherDirection == EdgeDirection.Both) {
+              e.dst.data._2 =
+                if (e.dst.data._2.isEmpty) {
+                  mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
+                } else {
+                  val tmp = mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
+                  if (!tmp.isEmpty) Some(reduceFunc(e.dst.data._2.get, tmp.get)) else e.dst.data._2
+                }
+            }
+            if (gatherDirection == EdgeDirection.Out || gatherDirection == EdgeDirection.Both) {
+              e.dst.data._2 =
+                if (e.dst.data._2.isEmpty) {
+                  mapFunc(edgeSansAcc.src.id, edgeSansAcc)
+                } else {
+                  val tmp = mapFunc(edgeSansAcc.src.id, edgeSansAcc)
+                  if (!tmp.isEmpty) Some(reduceFunc(e.src.data._2.get, tmp.get)) else e.src.data._2
+                }
+            }
           }
-          if (gatherDirection == EdgeDirection.Out || gatherDirection == EdgeDirection.Both) {
-            e.dst.data._2 =
-              if (e.dst.data._2.isEmpty) {
-                mapFunc(edgeSansAcc.src.id, edgeSansAcc)
-              } else {
-                val tmp = mapFunc(edgeSansAcc.src.id, edgeSansAcc)
-                if (!tmp.isEmpty) Some(reduceFunc(e.src.data._2.get, tmp.get)) else e.src.data._2
-              }
+          vmap.long2ObjectEntrySet().fastIterator().filter(!_.getValue()._2.isEmpty).map{ entry =>
+            (entry.getLongKey(), entry.getValue()._2)
           }
         }
-        vmap.long2ObjectEntrySet().fastIterator().filter(!_.getValue()._2.isEmpty).map{ entry =>
-          (entry.getLongKey(), entry.getValue()._2)
-        }
-      }
-      .map{ case (vid, aOpt) => (vid, aOpt.get) }
-      .combineByKey((v: VD2) => v, reduceFunc, null, vertexPartitioner, false)
+        .map{ case (vid, aOpt) => (vid, aOpt.get) }
+        .combineByKey((v: A) => v, reduceFunc, null, vertexPartitioner, false)
+
+    this.leftJoinVertices(newVertices, (v: Vertex[VD], a: Option[A]) => (v.data, a))
   }
 
   /**
    * Same as aggregateNeighbors but map function can return none and there is no default value.
    * As a consequence, the resulting table may be much smaller than the set of vertices.
    */
-  override def aggregateNeighbors[VD2: ClassManifest](
-    mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[VD2],
-    reduceFunc: (VD2, VD2) => VD2,
-    gatherDirection: EdgeDirection): RDD[(Vid, VD2)] = {
+  override def aggregateNeighbors[A: ClassManifest](
+    mapFunc: (Vid, EdgeTriplet[VD, ED]) => Option[A],
+    reduceFunc: (A, A) => A,
+    gatherDirection: EdgeDirection): Graph[(VD, Option[A]), ED] = {
 
     ClosureCleaner.clean(mapFunc)
     ClosureCleaner.clean(reduceFunc)
 
     val newVTable = vTableReplicated.mapPartitions({ part =>
-        part.map { v => (v._1, MutableTuple2(v._2, Option.empty[VD2])) }
+        part.map { v => (v._1, MutableTuple2(v._2, Option.empty[A])) }
       }, preservesPartitioning = true)
 
-    new EdgeTripletRDD[MutableTuple2[VD, Option[VD2]], ED](newVTable, eTable)
-      .mapPartitions { part =>
-        val (vmap, edges) = part.next()
-        val edgeSansAcc = new EdgeTriplet[VD, ED]()
-        edgeSansAcc.src = new Vertex[VD]
-        edgeSansAcc.dst = new Vertex[VD]
-        edges.foreach { e: EdgeTriplet[MutableTuple2[VD, Option[VD2]], ED] =>
-          edgeSansAcc.data = e.data
-          edgeSansAcc.src.data = e.src.data._1
-          edgeSansAcc.dst.data = e.dst.data._1
-          edgeSansAcc.src.id = e.src.id
-          edgeSansAcc.dst.id = e.dst.id
-          if (gatherDirection == EdgeDirection.In || gatherDirection == EdgeDirection.Both) {
-            e.dst.data._2 =
-              if (e.dst.data._2.isEmpty) {
-                mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
-              } else {
-                val tmp = mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
-                if (!tmp.isEmpty) Some(reduceFunc(e.dst.data._2.get, tmp.get)) else e.dst.data._2
-              }
+    val newVertices: RDD[(Vid, A)] =
+      new EdgeTripletRDD[MutableTuple2[VD, Option[A]], ED](newVTable, eTable)
+        .mapPartitions { part =>
+          val (vmap, edges) = part.next()
+          val edgeSansAcc = new EdgeTriplet[VD, ED]()
+          edgeSansAcc.src = new Vertex[VD]
+          edgeSansAcc.dst = new Vertex[VD]
+          edges.foreach { e: EdgeTriplet[MutableTuple2[VD, Option[A]], ED] =>
+            edgeSansAcc.data = e.data
+            edgeSansAcc.src.data = e.src.data._1
+            edgeSansAcc.dst.data = e.dst.data._1
+            edgeSansAcc.src.id = e.src.id
+            edgeSansAcc.dst.id = e.dst.id
+            if (gatherDirection == EdgeDirection.In || gatherDirection == EdgeDirection.Both) {
+              e.dst.data._2 =
+                if (e.dst.data._2.isEmpty) {
+                  mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
+                } else {
+                  val tmp = mapFunc(edgeSansAcc.dst.id, edgeSansAcc)
+                  if (!tmp.isEmpty) Some(reduceFunc(e.dst.data._2.get, tmp.get)) else e.dst.data._2
+                }
+            }
+            if (gatherDirection == EdgeDirection.Out || gatherDirection == EdgeDirection.Both) {
+              e.src.data._2 =
+                if (e.src.data._2.isEmpty) {
+                  mapFunc(edgeSansAcc.src.id, edgeSansAcc)
+                } else {
+                  val tmp = mapFunc(edgeSansAcc.src.id, edgeSansAcc)
+                  if (!tmp.isEmpty) Some(reduceFunc(e.src.data._2.get, tmp.get)) else e.src.data._2
+                }
+            }
           }
-          if (gatherDirection == EdgeDirection.Out || gatherDirection == EdgeDirection.Both) {
-            e.src.data._2 =
-              if (e.src.data._2.isEmpty) {
-                mapFunc(edgeSansAcc.src.id, edgeSansAcc)
-              } else {
-                val tmp = mapFunc(edgeSansAcc.src.id, edgeSansAcc)
-                if (!tmp.isEmpty) Some(reduceFunc(e.src.data._2.get, tmp.get)) else e.src.data._2
-              }
+          vmap.long2ObjectEntrySet().fastIterator().filter(!_.getValue()._2.isEmpty).map{ entry =>
+            (entry.getLongKey(), entry.getValue()._2)
           }
         }
-        vmap.long2ObjectEntrySet().fastIterator().filter(!_.getValue()._2.isEmpty).map{ entry =>
-          (entry.getLongKey(), entry.getValue()._2)
-        }
-      }
-      .map{ case (vid, aOpt) => (vid, aOpt.get) }
-      .combineByKey((v: VD2) => v, reduceFunc, null, vertexPartitioner, false)
+        .map{ case (vid, aOpt) => (vid, aOpt.get) }
+        .combineByKey((v: A) => v, reduceFunc, null, vertexPartitioner, false)
+
+    this.leftJoinVertices(newVertices, (v: Vertex[VD], a: Option[A]) => (v.data, a))
   }
 
   override def leftJoinVertices[U: ClassManifest, VD2: ClassManifest](
diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index 6f070aac5917b14a0846e7ad70d4cf726bf98e84..aa885de957939f9727f545f706db30584392c056 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -20,7 +20,27 @@ class GraphSuite extends FunSuite with LocalSparkContext {
   }
 
   test("aggregateNeighbors") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val star = Graph(sc.parallelize(List((0, 1), (0, 2), (0, 3))))
+
+      val indegrees = star.aggregateNeighbors(
+        (vid, edge) => Some(1),
+        (a: Int, b: Int) => a + b,
+        EdgeDirection.In).vertices.map(v => (v.id, v.data._2.getOrElse(0)))
+      assert(indegrees.collect().toSet === Set((0, 0), (1, 1), (2, 1), (3, 1)))
 
+      val outdegrees = star.aggregateNeighbors(
+        (vid, edge) => Some(1),
+        (a: Int, b: Int) => a + b,
+        EdgeDirection.Out).vertices.map(v => (v.id, v.data._2.getOrElse(0)))
+      assert(outdegrees.collect().toSet === Set((0, 3), (1, 0), (2, 0), (3, 0)))
+
+      val noVertexValues = star.aggregateNeighbors[Int](
+        (vid: Vid, edge: EdgeTriplet[Int, Int]) => None,
+        (a: Int, b: Int) => throw new Exception("reduceFunc called unexpectedly"),
+        EdgeDirection.In).vertices.map(v => (v.id, v.data._2))
+      assert(noVertexValues.collect().toSet === Set((0, None), (1, None), (2, None), (3, None)))
+    }
   }
 
  /* test("joinVertices") {