diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index f8d54a8f738ce741693a70001bb385b9b5785389..e86d7ef76779867181ab3885c257864375623919 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -157,6 +157,16 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
   /** Return the value at the specified position. */
   def getValue(pos: Int): T = _data(pos)
 
+  def iterator(): Iterator[T] = new Iterator[T] {
+    var pos = nextPos(0)
+    override def hasNext: Boolean = pos != INVALID_POS
+    override def next(): T = {
+      val tmp = getValue(pos)
+      pos = nextPos(pos + 1)
+      tmp
+    }
+  }
+
   /** Return the value at the specified position, asserting that the position holds an element. */
   def getValueSafe(pos: Int): T = {
     assert(_bitset.get(pos))
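The `iterator` added above walks the position table via `nextPos`, so callers can traverse the elements of the set without reaching into the bitset themselves. A minimal usage sketch (not part of the patch; the capacity and element values are illustrative):

```scala
import org.apache.spark.util.collection.OpenHashSet

// Populate a small set of vertex ids and walk it with the new iterator.
// The initial capacity (64) is arbitrary; the set grows as needed.
val vids = new OpenHashSet[Long](64)
Seq(1L, 5L, 9L).foreach(vids.add)

// Each stored element is visited exactly once, in position-table order.
vids.iterator.foreach(println)
```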
diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index acfdc4378b0a0cd612832e8e3860e53690215827..f5b4c57f72902167d0560ed98a1cafbd72b2a439 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -1,7 +1,7 @@
 package org.apache.spark.graph
 
 import org.apache.spark.rdd.RDD
-
+import org.apache.spark.storage.StorageLevel
 
 /**
  * The Graph abstractly represents a graph with arbitrary objects
@@ -12,21 +12,21 @@ import org.apache.spark.rdd.RDD
  * operations return new graphs.
  *
  * @see GraphOps for additional graph member functions.
- * 
+ *
  * @note The majority of the graph operations are implemented in
  * `GraphOps`.  All the convenience operations are defined in the
  * `GraphOps` class which may be shared across multiple graph
  * implementations.
  *
  * @tparam VD the vertex attribute type
- * @tparam ED the edge attribute type 
+ * @tparam ED the edge attribute type
  */
 abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
 
   /**
    * Get the vertices and their data.
    *
-   * @note vertex ids are unique. 
+   * @note vertex ids are unique.
    * @return An RDD containing the vertices in this graph
    *
    * @see Vertex for the vertex type.
@@ -70,6 +70,11 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    */
   val triplets: RDD[EdgeTriplet[VD, ED]]
 
+
+  /** Persist this graph's underlying RDDs at the specified storage level and return the graph. */
+  def persist(newLevel: StorageLevel): Graph[VD, ED]
+
+
   /**
    * Return a graph that is cached when first created. This is used to
    * pin a graph in memory enabling multiple queries to reuse the same
@@ -100,7 +105,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * @tparam VD2 the new vertex data type
    *
    * @example We might use this operation to change the vertex values
-   * from one type to another to initialize an algorithm.   
+   * from one type to another to initialize an algorithm.
    * {{{
    * val rawGraph: Graph[(), ()] = Graph.textFile("hdfs://file")
    * val root = 42
@@ -190,7 +195,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * @return the subgraph containing only the vertices and edges that
    * satisfy the predicates.
    */
-  def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true), 
+  def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true),
     vpred: (Vid, VD) => Boolean = ((v,d) => true) ): Graph[VD, ED]
 
 
@@ -255,12 +260,12 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * @param reduceFunc the user defined reduce function which should
   * be commutative and associative and is used to combine the output
    * of the map phase.
-   * 
+   *
    * @example We can use this function to compute the inDegree of each
    * vertex
    * {{{
    * val rawGraph: Graph[(),()] = Graph.textFile("twittergraph")
-   * val inDeg: RDD[(Vid, Int)] = 
+   * val inDeg: RDD[(Vid, Int)] =
    *   mapReduceTriplets[Int](et => Array((et.dst.id, 1)), _ + _)
    * }}}
    *
@@ -269,12 +274,12 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
   * Graph API in that it enables neighborhood-level computation. For
   * example, this function can be used to count neighbors satisfying a
   * predicate or to implement PageRank.
-   * 
+   *
    */
   def mapReduceTriplets[A: ClassManifest](
       mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
       reduceFunc: (A, A) => A)
-    : VertexSetRDD[A] 
+    : VertexSetRDD[A]
 
 
   /**
@@ -296,11 +301,11 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * @example This function is used to update the vertices with new
    * values based on external data.  For example we could add the out
    * degree to each vertex record
-   * 
+   *
    * {{{
    * val rawGraph: Graph[(),()] = Graph.textFile("webgraph")
    * val outDeg: RDD[(Vid, Int)] = rawGraph.outDegrees()
-   * val graph = rawGraph.outerJoinVertices(outDeg) { 
+   * val graph = rawGraph.outerJoinVertices(outDeg) {
    *   (vid, data, optDeg) => optDeg.getOrElse(0)
    * }
    * }}}
@@ -337,7 +342,7 @@ object Graph {
    * (i.e., the undirected degree).
    *
    * @param rawEdges the RDD containing the set of edges in the graph
-   * 
+   *
    * @return a graph with edge attributes containing the count of
    * duplicate edges and vertex attributes containing the total degree
    * of each vertex.
@@ -368,10 +373,10 @@ object Graph {
         rawEdges.map { case (s, t) => Edge(s, t, 1) }
       }
     // Determine unique vertices
-    /** @todo Should this reduceByKey operation be indexed? */ 
-    val vertices: RDD[(Vid, Int)] = 
+    /** @todo Should this reduceByKey operation be indexed? */
+    val vertices: RDD[(Vid, Int)] =
       edges.flatMap{ case Edge(s, t, cnt) => Array((s, 1), (t, 1)) }.reduceByKey(_ + _)
- 
+
     // Return graph
     GraphImpl(vertices, edges, 0)
   }
@@ -392,7 +397,7 @@ object Graph {
    *
    */
   def apply[VD: ClassManifest, ED: ClassManifest](
-      vertices: RDD[(Vid,VD)], 
+      vertices: RDD[(Vid,VD)],
       edges: RDD[Edge[ED]]): Graph[VD, ED] = {
     val defaultAttr: VD = null.asInstanceOf[VD]
     Graph(vertices, edges, defaultAttr, (a:VD,b:VD) => a)
@@ -416,7 +421,7 @@ object Graph {
    *
    */
   def apply[VD: ClassManifest, ED: ClassManifest](
-      vertices: RDD[(Vid,VD)], 
+      vertices: RDD[(Vid,VD)],
       edges: RDD[Edge[ED]],
       defaultVertexAttr: VD,
       mergeFunc: (VD, VD) => VD): Graph[VD, ED] = {
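With `persist(newLevel)` now part of the abstract `Graph` API (implemented in `GraphImpl` below), callers can pick a storage level instead of the MEMORY_ONLY behavior implied by `cache()`. A hedged sketch of how that might look, using the `Graph(vertices, edges)` constructor from this file; the vertex and edge data are made up:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.graph._
import org.apache.spark.storage.StorageLevel

val sc = new SparkContext("local", "persist-example")
val vertices = sc.parallelize(Seq((1L, "a"), (2L, "b"), (3L, "c")))
val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1)))

// Build the graph and pin it at a disk-backed level rather than MEMORY_ONLY.
val graph = Graph(vertices, edges).persist(StorageLevel.MEMORY_AND_DISK)
```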
diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
index f65f96ed0c1b2dd62f1d5bf0b20b77cb060267af..82b9198e432c728179de848b2b8b1672e7dbbc4d 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
@@ -2,7 +2,7 @@ package org.apache.spark.graph
 
 import com.esotericsoftware.kryo.Kryo
 
-import org.apache.spark.graph.impl.{EdgePartition, MessageToPartition}
+import org.apache.spark.graph.impl._
 import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.util.collection.BitSet
 
@@ -12,6 +12,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[Edge[Object]])
     kryo.register(classOf[MutableTuple2[Object, Object]])
     kryo.register(classOf[MessageToPartition[Object]])
+    kryo.register(classOf[VertexBroadcastMsg[Object]])
+    kryo.register(classOf[AggregationMsg[Object]])
     kryo.register(classOf[(Vid, Object)])
     kryo.register(classOf[EdgePartition[Object]])
     kryo.register(classOf[BitSet])
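Registering the new message classes only matters if Kryo is actually enabled with this registrator; the new SerializerSuite at the end of this patch does that through system properties. A minimal sketch of the same wiring:

```scala
// Point Spark (pre-SparkConf style) at Kryo and at the graph registrator so that
// VertexBroadcastMsg and AggregationMsg use the classes registered above.
System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
```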
diff --git a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
index 501e593e917eae3cf4df6940b138d2ff9a5d2c0b..3b4d3c0df2a51ca178194f4623ff02e4baa960ec 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Pregel.scala
@@ -98,14 +98,14 @@ object Pregel {
     : Graph[VD, ED] = {
 
     // Receive the first set of messages
-    var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg))
+    var g = graph.mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg)).cache
 
     var i = 0
     while (i < numIter) {
       // compute the messages
       val messages = g.mapReduceTriplets(sendMsg, mergeMsg)
       // receive the messages
-      g = g.joinVertices(messages)(vprog)
+      g = g.joinVertices(messages)(vprog).cache
       // count the iteration
       i += 1
     }
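Marking each superstep's graph as cached matters because the same graph is used both to send messages (`mapReduceTriplets`) and to receive them (`joinVertices`); without caching, those repeated uses can force the graph's lineage to be recomputed from scratch on every iteration. The same pattern shows up with plain RDDs; a rough sketch (the numbers are arbitrary and the loop stands in for Pregel's supersteps):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

val sc = new SparkContext("local", "iterative-cache-sketch")

// Iteratively derive a new RDD from the previous one, caching each step so that
// actions against the latest result do not replay the entire lineage.
var ranks = sc.parallelize(1 to 1000).map(i => (i.toLong, 1.0)).cache()
for (step <- 0 until 10) {
  ranks = ranks.mapValues(_ * 0.85).cache()
  ranks.count()  // stands in for the job each superstep eventually runs
}
```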
diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
index 8611d2f0ce1a98da2adcfd08e6c232dd107b5904..62608e506d85b820275d9a5a627ddb7c05cf7173 100644
--- a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
@@ -22,13 +22,14 @@ import org.apache.spark.SparkContext._
 import org.apache.spark.rdd._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
-
+import org.apache.spark.graph.impl.AggregationMsg
+import org.apache.spark.graph.impl.MsgRDDFunctions._
 
 /**
  * The `VertexSetIndex` maintains the per-partition mapping from
  * vertex id to the corresponding location in the per-partition values
  * array.  This class is meant to be an opaque type.
- * 
+ *
  */
 class VertexSetIndex(private[spark] val rdd: RDD[VertexIdToIndexMap]) {
   /**
@@ -55,7 +56,7 @@ class VertexSetIndex(private[spark] val rdd: RDD[VertexIdToIndexMap]) {
  * In addition to providing the basic RDD[(Vid,V)] functionality the
  * VertexSetRDD exposes an index member which can be used to "key"
  * other VertexSetRDDs
- * 
+ *
  * @tparam V the vertex attribute associated with each vertex in the
  * set.
  *
@@ -84,7 +85,7 @@ class VertexSetIndex(private[spark] val rdd: RDD[VertexIdToIndexMap]) {
 class VertexSetRDD[@specialized V: ClassManifest](
     @transient val index:  VertexSetIndex,
     @transient val valuesRDD: RDD[ ( Array[V], BitSet) ])
-  extends RDD[(Vid, V)](index.rdd.context, 
+  extends RDD[(Vid, V)](index.rdd.context,
     List(new OneToOneDependency(index.rdd), new OneToOneDependency(valuesRDD)) ) {
 
 
@@ -100,32 +101,32 @@ class VertexSetRDD[@specialized V: ClassManifest](
    * An internal representation which joins the block indices with the values
    * This is used by the compute function to emulate RDD[(Vid, V)]
    */
-  protected[spark] val tuples = 
+  protected[spark] val tuples =
     new ZippedRDD(index.rdd.context, index.rdd, valuesRDD)
 
 
   /**
-   * The partitioner is defined by the index.  
+   * The partitioner is defined by the index.
    */
   override val partitioner = index.rdd.partitioner
-  
+
 
   /**
    * The actual partitions are defined by the tuples.
    */
-  override def getPartitions: Array[Partition] = tuples.getPartitions 
-  
+  override def getPartitions: Array[Partition] = tuples.getPartitions
+
 
   /**
-   * The preferred locations are computed based on the preferred 
-   * locations of the tuples. 
+   * The preferred locations are computed based on the preferred
+   * locations of the tuples.
    */
-  override def getPreferredLocations(s: Partition): Seq[String] = 
+  override def getPreferredLocations(s: Partition): Seq[String] =
     tuples.getPreferredLocations(s)
 
 
   /**
-   * Caching an VertexSetRDD causes the index and values to be cached separately. 
+   * Caching a VertexSetRDD causes the index and values to be cached separately.
    */
   override def persist(newLevel: StorageLevel): VertexSetRDD[V] = {
     index.persist(newLevel)
@@ -143,7 +144,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
 
 
   /**
-   * Provide the RDD[(K,V)] equivalent output. 
+   * Provide the RDD[(K,V)] equivalent output.
    */
   override def compute(part: Partition, context: TaskContext): Iterator[(Vid, V)] = {
     tuples.compute(part, context).flatMap { case (indexMap, (values, bs) ) =>
@@ -154,19 +155,19 @@ class VertexSetRDD[@specialized V: ClassManifest](
 
   /**
    * Restrict the vertex set to the set of vertices satisfying the
-   * given predicate. 
-   * 
+   * given predicate.
+   *
    * @param pred the user defined predicate
    *
    * @note The vertex set preserves the original index structure
    * which means that the returned RDD can be easily joined with
-   * the original vertex-set.  Furthermore, the filter only 
-   * modifies the bitmap index and so no new values are allocated. 
+   * the original vertex-set.  Furthermore, the filter only
+   * modifies the bitmap index and so no new values are allocated.
    */
   override def filter(pred: Tuple2[Vid,V] => Boolean): VertexSetRDD[V] = {
     val cleanPred = index.rdd.context.clean(pred)
-    val newValues = index.rdd.zipPartitions(valuesRDD){ 
-      (keysIter: Iterator[VertexIdToIndexMap], 
+    val newValues = index.rdd.zipPartitions(valuesRDD){
+      (keysIter: Iterator[VertexIdToIndexMap],
        valuesIter: Iterator[(Array[V], BitSet)]) =>
       val index = keysIter.next()
       assert(keysIter.hasNext == false)
@@ -174,7 +175,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
       assert(valuesIter.hasNext == false)
       // Allocate the array to store the results into
       val newBS = new BitSet(index.capacity)
-      // Iterate over the active bits in the old bitset and 
+      // Iterate over the active bits in the old bitset and
       // evaluate the predicate
       var ind = bs.nextSetBit(0)
       while(ind >= 0) {
@@ -193,7 +194,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
   /**
    * Pass each vertex attribute through a map function and retain the
    * original RDD's partitioning and index.
-   * 
+   *
    * @tparam U the type returned by the map function
    *
    * @param f the function applied to each value in the RDD
@@ -204,12 +205,12 @@ class VertexSetRDD[@specialized V: ClassManifest](
   def mapValues[U: ClassManifest](f: V => U): VertexSetRDD[U] = {
     val cleanF = index.rdd.context.clean(f)
     val newValuesRDD: RDD[ (Array[U], BitSet) ] =
-      valuesRDD.mapPartitions(iter => iter.map{ 
+      valuesRDD.mapPartitions(iter => iter.map{
         case (values, bs: BitSet) =>
           val newValues = new Array[U](values.size)
           bs.iterator.foreach { ind => newValues(ind) = cleanF(values(ind)) }
           (newValues, bs)
-      }, preservesPartitioning = true)   
+      }, preservesPartitioning = true)
     new VertexSetRDD[U](index, newValuesRDD)
   } // end of mapValues
 
@@ -217,7 +218,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
   /**
    * Pass each vertex attribute along with the vertex id through a map
    * function and retain the original RDD's partitioning and index.
-   * 
+   *
    * @tparam U the type returned by the map function
    *
    * @param f the function applied to each vertex id and vertex
@@ -229,8 +230,8 @@ class VertexSetRDD[@specialized V: ClassManifest](
   def mapValuesWithKeys[U: ClassManifest](f: (Vid, V) => U): VertexSetRDD[U] = {
     val cleanF = index.rdd.context.clean(f)
     val newValues: RDD[ (Array[U], BitSet) ] =
-      index.rdd.zipPartitions(valuesRDD){ 
-        (keysIter: Iterator[VertexIdToIndexMap], 
+      index.rdd.zipPartitions(valuesRDD){
+        (keysIter: Iterator[VertexIdToIndexMap],
          valuesIter: Iterator[(Array[V], BitSet)]) =>
         val index = keysIter.next()
         assert(keysIter.hasNext == false)
@@ -254,7 +255,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
    * vertices that are in both this and the other vertex set.
    *
    * @tparam W the attribute type of the other VertexSet
-   * 
+   *
    * @param other the other VertexSet with which to join.
    * @return a VertexSetRDD containing only the vertices in both this
    * and the other VertexSet and with tuple attributes.
@@ -324,7 +325,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
    * any vertex in this VertexSet then a `None` attribute is generated
    *
    * @tparam W the attribute type of the other VertexSet
-   * 
+   *
    * @param other the other VertexSet with which to join.
    * @return a VertexSetRDD containing all the vertices in this
    * VertexSet with `None` attributes used for Vertices missing in the
@@ -365,7 +366,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
    * VertexSet then a `None` attribute is generated
    *
    * @tparam W the attribute type of the other VertexSet
-   * 
+   *
    * @param other the other VertexSet with which to join.
    * @param merge the function used combine duplicate vertex
    * attributes
@@ -398,28 +399,28 @@ class VertexSetRDD[@specialized V: ClassManifest](
 
 
   /**
-   * For each key k in `this` or `other`, return a resulting RDD that contains a 
+   * For each key k in `this` or `other`, return a resulting RDD that contains a
    * tuple with the list of values for that key in `this` as well as `other`.
    */
    /*
-  def cogroup[W: ClassManifest](other: RDD[(Vid, W)], partitioner: Partitioner): 
+  def cogroup[W: ClassManifest](other: RDD[(Vid, W)], partitioner: Partitioner):
   VertexSetRDD[(Seq[V], Seq[W])] = {
     //RDD[(K, (Seq[V], Seq[W]))] = {
     other match {
       case other: VertexSetRDD[_] if index == other.index => {
-        // if both RDDs share exactly the same index and therefore the same 
-        // super set of keys then we simply merge the value RDDs. 
-        // However it is possible that both RDDs are missing a value for a given key in 
+        // if both RDDs share exactly the same index and therefore the same
+        // superset of keys, then we simply merge the value RDDs.
+        // However, it is possible that both RDDs are missing a value for a given key, in
         // which case the returned RDD should have a null value
-        val newValues: RDD[(IndexedSeq[(Seq[V], Seq[W])], BitSet)] = 
+        val newValues: RDD[(IndexedSeq[(Seq[V], Seq[W])], BitSet)] =
           valuesRDD.zipPartitions(other.valuesRDD){
-          (thisIter, otherIter) => 
+          (thisIter, otherIter) =>
             val (thisValues, thisBS) = thisIter.next()
             assert(!thisIter.hasNext)
             val (otherValues, otherBS) = otherIter.next()
             assert(!otherIter.hasNext)
-            /** 
-             * @todo consider implementing this with a view as in leftJoin to 
+            /**
+             * @todo consider implementing this with a view as in leftJoin to
              * reduce array allocations
              */
             val newValues = new Array[(Seq[V], Seq[W])](thisValues.size)
@@ -428,20 +429,20 @@ class VertexSetRDD[@specialized V: ClassManifest](
             var ind = newBS.nextSetBit(0)
             while(ind >= 0) {
               val a = if (thisBS.get(ind)) Seq(thisValues(ind)) else Seq.empty[V]
-              val b = if (otherBS.get(ind)) Seq(otherValues(ind)) else Seq.empty[W]   
+              val b = if (otherBS.get(ind)) Seq(otherValues(ind)) else Seq.empty[W]
               newValues(ind) = (a, b)
               ind = newBS.nextSetBit(ind+1)
             }
             Iterator((newValues.toIndexedSeq, newBS))
         }
-        new VertexSetRDD(index, newValues) 
+        new VertexSetRDD(index, newValues)
       }
-      case other: VertexSetRDD[_] 
+      case other: VertexSetRDD[_]
         if index.rdd.partitioner == other.index.rdd.partitioner => {
         // If both RDDs are indexed using different indices but with the same partitioners
         // then we need to first merge the indices and then use the merged index to
         // merge the values.
-        val newIndex = 
+        val newIndex =
           index.rdd.zipPartitions(other.index.rdd)(
             (thisIter, otherIter) => {
             val thisIndex = thisIter.next()
@@ -463,7 +464,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
             List(newIndex).iterator
           }).cache()
         // Use the new index along with the this and the other indices to merge the values
-        val newValues: RDD[(IndexedSeq[(Seq[V], Seq[W])], BitSet)] = 
+        val newValues: RDD[(IndexedSeq[(Seq[V], Seq[W])], BitSet)] =
           newIndex.zipPartitions(tuples, other.tuples)(
             (newIndexIter, thisTuplesIter, otherTuplesIter) => {
               // Get the new index for this partition
@@ -507,7 +508,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
           case None => throw new SparkException("An index must have a partitioner.")
         }
         // Shuffle the other RDD using the partitioner for this index
-        val otherShuffled = 
+        val otherShuffled =
           if (other.partitioner == Some(partitioner)) {
             other
           } else {
@@ -527,7 +528,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
             // populate the newValues with the values in this VertexSetRDD
             for ((k,i) <- thisIndex) {
               if (thisBS.get(i)) {
-                newValues(i) = (Seq(thisValues(i)), ArrayBuffer.empty[W]) 
+                newValues(i) = (Seq(thisValues(i)), ArrayBuffer.empty[W])
                 newBS.set(i)
               }
             }
@@ -538,28 +539,28 @@ class VertexSetRDD[@specialized V: ClassManifest](
                 if(newBS.get(ind)) {
                   newValues(ind)._2.asInstanceOf[ArrayBuffer[W]].append(w)
                 } else {
-                  // If the other key was in the index but not in the values 
-                  // of this indexed RDD then create a new values entry for it 
+                  // If the other key was in the index but not in the values
+                  // of this indexed RDD then create a new values entry for it
                   newBS.set(ind)
                   newValues(ind) = (Seq.empty[V], ArrayBuffer(w))
-                }              
+                }
               } else {
                 // update the index
                 val ind = newIndex.size
                 newIndex.put(k, ind)
                 newBS.set(ind)
                 // Update the values
-                newValues.append( (Seq.empty[V], ArrayBuffer(w) ) ) 
+                newValues.append( (Seq.empty[V], ArrayBuffer(w) ) )
               }
             }
             Iterator( (newIndex, (newValues.toIndexedSeq, newBS)) )
           }).cache()
 
-        // Extract the index and values from the above RDD  
+        // Extract the index and values from the above RDD
         val newIndex = groups.mapPartitions(_.map{ case (kMap,vAr) => kMap }, true)
-        val newValues: RDD[(IndexedSeq[(Seq[V], Seq[W])], BitSet)] = 
+        val newValues: RDD[(IndexedSeq[(Seq[V], Seq[W])], BitSet)] =
           groups.mapPartitions(_.map{ case (kMap,vAr) => vAr }, true)
-          
+
         new VertexSetRDD[(Seq[V], Seq[W])](new VertexSetIndex(newIndex), newValues)
       }
     }
@@ -583,7 +584,7 @@ object VertexSetRDD {
    *
    * @param rdd the collection of vertex-attribute pairs
    */
-  def apply[V: ClassManifest](rdd: RDD[(Vid,V)]): VertexSetRDD[V] = 
+  def apply[V: ClassManifest](rdd: RDD[(Vid,V)]): VertexSetRDD[V] =
     apply(rdd, (a:V, b:V) => a )
 
   /**
@@ -591,7 +592,7 @@ object VertexSetRDD {
    * where duplicate entries are merged using the reduceFunc
    *
    * @tparam V the vertex attribute type
-   * 
+   *
    * @param rdd the collection of vertex-attribute pairs
    * @param reduceFunc the function used to merge attributes of
    * duplicate vertices.
@@ -602,12 +603,12 @@ object VertexSetRDD {
     // Preaggregate and shuffle if necessary
     val preAgg = rdd.partitioner match {
       case Some(p) => rdd
-      case None => 
+      case None =>
         val partitioner = new HashPartitioner(rdd.partitions.size)
         // Preaggregation.
         val aggregator = new Aggregator[Vid, V, V](v => v, cReduceFunc, cReduceFunc)
         rdd.mapPartitions(aggregator.combineValuesByKey, true).partitionBy(partitioner)
-    } 
+    }
 
     val groups = preAgg.mapPartitions( iter => {
       val hashMap = new PrimitiveKeyOpenHashMap[Vid, V]
@@ -629,8 +630,8 @@ object VertexSetRDD {
 
   /**
    * Construct a vertex set from an RDD using an existing index.
-   * 
-   * @note duplicate vertices are discarded arbitrarily 
+   *
+   * @note duplicate vertices are discarded arbitrarily
    *
    * @tparam V the vertex attribute type
    * @param rdd the rdd containing vertices
@@ -638,13 +639,13 @@ object VertexSetRDD {
    * in RDD
    */
   def apply[V: ClassManifest](
-    rdd: RDD[(Vid,V)], index: VertexSetIndex): VertexSetRDD[V] = 
+    rdd: RDD[(Vid,V)], index: VertexSetIndex): VertexSetRDD[V] =
     apply(rdd, index, (a:V,b:V) => a)
 
 
   /**
    * Construct a vertex set from an RDD using an existing index and a
-   * user defined `combiner` to merge duplicate vertices. 
+   * user defined `combiner` to merge duplicate vertices.
    *
    * @tparam V the vertex attribute type
    * @param rdd the rdd containing vertices
@@ -655,13 +656,50 @@ object VertexSetRDD {
    */
   def apply[V: ClassManifest](
     rdd: RDD[(Vid,V)], index: VertexSetIndex,
-    reduceFunc: (V, V) => V): VertexSetRDD[V] = 
+    reduceFunc: (V, V) => V): VertexSetRDD[V] =
     apply(rdd,index, (v:V) => v, reduceFunc, reduceFunc)
-  
+
+  /** Build a vertex set from pre-partitioned aggregation messages, merging duplicates with `reduceFunc`. */
+  def aggregate[V: ClassManifest](
+    rdd: RDD[AggregationMsg[V]], index: VertexSetIndex,
+    reduceFunc: (V, V) => V): VertexSetRDD[V] = {
+
+    val cReduceFunc = index.rdd.context.clean(reduceFunc)
+    assert(rdd.partitioner == index.rdd.partitioner)
+    // Use the index to build the new values table
+    val values: RDD[ (Array[V], BitSet) ] = index.rdd.zipPartitions(rdd)( (indexIter, tblIter) => {
+      // There is only one map
+      val index = indexIter.next()
+      assert(!indexIter.hasNext)
+      val values = new Array[V](index.capacity)
+      val bs = new BitSet(index.capacity)
+      for (msg <- tblIter) {
+        // Get the location of the key in the index
+        val pos = index.getPos(msg.vid)
+        if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
+          throw new SparkException("Error: Trying to bind an external index " +
+            "to an RDD which contains keys that are not in the index.")
+        } else {
+          // Get the actual index
+          val ind = pos & OpenHashSet.POSITION_MASK
+          // If this value has already been seen then merge
+          if (bs.get(ind)) {
+            values(ind) = cReduceFunc(values(ind), msg.data)
+          } else { // otherwise just store the new value
+            bs.set(ind)
+            values(ind) = msg.data
+          }
+        }
+      }
+      Iterator((values, bs))
+    })
+    new VertexSetRDD(index, values)
+  }
+
 
   /**
    * Construct a vertex set from an RDD using an existing index and a
-   * user defined `combiner` to merge duplicate vertices. 
+   * user defined `combiner` to merge duplicate vertices.
    *
    * @tparam V the vertex attribute type
    * @param rdd the rdd containing vertices
@@ -675,11 +713,11 @@ object VertexSetRDD {
    *
    */
   def apply[V: ClassManifest, C: ClassManifest](
-    rdd: RDD[(Vid,V)], 
-    index: VertexSetIndex,
-    createCombiner: V => C,
-    mergeValue: (C, V) => C,
-    mergeCombiners: (C, C) => C): VertexSetRDD[C] = {
+      rdd: RDD[(Vid,V)],
+      index: VertexSetIndex,
+      createCombiner: V => C,
+      mergeValue: (C, V) => C,
+      mergeCombiners: (C, C) => C): VertexSetRDD[C] = {
     val cCreateCombiner = index.rdd.context.clean(createCombiner)
     val cMergeValue = index.rdd.context.clean(mergeValue)
     val cMergeCombiners = index.rdd.context.clean(mergeCombiners)
@@ -689,7 +727,7 @@ object VertexSetRDD {
       case None => throw new SparkException("An index must have a partitioner.")
     }
     // Preaggregate and shuffle if necessary
-    val partitioned = 
+    val partitioned =
       if (rdd.partitioner != Some(partitioner)) {
         // Preaggregation.
         val aggregator = new Aggregator[Vid, V, C](cCreateCombiner, cMergeValue,
@@ -732,23 +770,23 @@ object VertexSetRDD {
 
   /**
    * Construct an index of the unique vertices.  The resulting index
-   * can be used to build VertexSets over subsets of the vertices in 
+   * can be used to build VertexSets over subsets of the vertices in
    * the input.
    */
-  def makeIndex(keys: RDD[Vid], 
+  def makeIndex(keys: RDD[Vid],
     partitioner: Option[Partitioner] = None): VertexSetIndex = {
     // @todo: I don't need the boolean; it's only there to be the second type since I want to shuffle a single RDD
-    // Ugly hack :-(.  In order to partition the keys they must have values. 
+    // Ugly hack :-(.  In order to partition the keys they must have values.
     val tbl = keys.mapPartitions(_.map(k => (k, false)), true)
     // Shuffle the table (if necessary)
     val shuffledTbl = partitioner match {
       case None =>  {
         if (tbl.partitioner.isEmpty) {
-          // @todo: I don't need the boolean its only there to be the second type of the shuffle. 
+          // @todo: I don't need the boolean; it's only there to be the second type of the shuffle.
           new ShuffledRDD[Vid, Boolean, (Vid, Boolean)](tbl, Partitioner.defaultPartitioner(tbl))
         } else { tbl }
       }
-      case Some(partitioner) => 
+      case Some(partitioner) =>
         tbl.partitionBy(partitioner)
     }
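The constructors in this object are easiest to see end to end with a small example. A hedged sketch of building a vertex set, merging duplicate ids, and reusing its index for a second set (the ids and values are made up):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.graph._

val sc = new SparkContext("local", "vertexset-sketch")

// Duplicate vertex ids (1L appears twice) are merged with the reduce function.
val degrees = sc.parallelize(Seq((1L, 1), (2L, 1), (1L, 1)))
val degreeSet: VertexSetRDD[Int] = VertexSetRDD(degrees, (a: Int, b: Int) => a + b)

// Reuse the same index for a second vertex set; both sets then share the same
// partitioning and can be joined without re-hashing the keys.
val index: VertexSetIndex = degreeSet.index
val labels = VertexSetRDD(sc.parallelize(Seq((1L, "a"), (2L, "b"))), index)
```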
 
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index c38780a265707b154a8389cf4613dfed9b1aae53..0d7546b57594cdbfe27c4dd2ab95e5e9c15569f6 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -13,6 +13,7 @@ import org.apache.spark.graph._
 import org.apache.spark.graph.impl.GraphImpl._
 import org.apache.spark.graph.impl.MsgRDDFunctions._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveKeyOpenHashMap}
 
 
@@ -95,13 +96,17 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues, eTable)
 
-  override def cache(): Graph[VD, ED] = {
-    eTable.cache()
-    vid2pid.cache()
-    vTable.cache()
+  override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
+    eTable.persist(newLevel)
+    vid2pid.persist(newLevel)
+    vTable.persist(newLevel)
+    localVidMap.persist(newLevel)
+    // vTableReplicatedValues.persist(newLevel)
     this
   }
 
+  override def cache(): Graph[VD, ED] = persist(StorageLevel.MEMORY_ONLY)
+
   override def statistics: Map[String, Any] = {
     val numVertices = this.numVertices
     val numEdges = this.numEdges
@@ -371,7 +376,7 @@ object GraphImpl {
       val vSet = new VertexSet
       edgePartition.foreach(e => {vSet.add(e.srcId); vSet.add(e.dstId)})
       vSet.iterator.map { vid => (vid.toLong, pid) }
-    }
+    }.partitionBy(vTableIndex.rdd.partitioner.get)
     VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,
       (p: Pid) => ArrayBuffer(p),
       (ab: ArrayBuffer[Pid], p:Pid) => {ab.append(p); ab},
@@ -508,7 +513,7 @@ object GraphImpl {
       }
     }.partitionBy(g.vTable.index.rdd.partitioner.get)
     // do the final reduction reusing the index map
-    VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
+    VertexSetRDD.aggregate(preAgg, g.vTable.index, reduceFunc)
   }
 
   protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
index 3fc0b7c0f7588d3b2aff84868367edc7f45dcd39..d0a5adb85cd8a3d6bb28ff588e527465a0b3c3dc 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
@@ -55,6 +55,8 @@ class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcast
     // Set a custom serializer if the data is of int or double type.
     if (classManifest[T] == ClassManifest.Int) {
       rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Long) {
+      rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName)
     } else if (classManifest[T] == ClassManifest.Double) {
       rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
     }
@@ -70,6 +72,8 @@ class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[
     // Set a custom serializer if the data is of int or double type.
     if (classManifest[T] == ClassManifest.Int) {
       rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Long) {
+      rdd.setSerializer(classOf[LongAggMsgSerializer].getName)
     } else if (classManifest[T] == ClassManifest.Double) {
       rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
     }
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
index 8b4c0868b1a6dff011cab21938f60fc1632c6ece..54fd65e7381f2512a8c069bc91abf95f7d58c320 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
@@ -27,6 +27,28 @@ class IntVertexBroadcastMsgSerializer extends Serializer {
   }
 }
 
+/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
+class LongVertexBroadcastMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
+        writeLong(msg.vid)
+        writeLong(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        val a = readLong()
+        val b = readLong()
+        new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
+      }
+    }
+  }
+}
 
 /** A special shuffle serializer for VertexBroadcastMessage[Double]. */
 class DoubleVertexBroadcastMsgSerializer extends Serializer {
@@ -43,7 +65,9 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer {
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       def readObject[T](): T = {
-        new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T]
+        val a = readLong()
+        val b = readDouble()
+        new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
       }
     }
   }
@@ -65,7 +89,32 @@ class IntAggMsgSerializer extends Serializer {
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       override def readObject[T](): T = {
-        new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T]
+        val a = readLong()
+        val b = readInt()
+        new AggregationMsg[Int](a, b).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+/** A special shuffle serializer for AggregationMessage[Long]. */
+class LongAggMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[AggregationMsg[Long]]
+        writeLong(msg.vid)
+        writeLong(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        val a = readLong()
+        val b = readLong()
+        new AggregationMsg[Long](a, b).asInstanceOf[T]
       }
     }
   }
@@ -87,7 +136,9 @@ class DoubleAggMsgSerializer extends Serializer {
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       def readObject[T](): T = {
-        new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T]
+        val a = readLong()
+        val b = readDouble()
+        new AggregationMsg[Double](a, b).asInstanceOf[T]
       }
     }
   }
diff --git a/graph/src/main/scala/org/apache/spark/graph/package.scala b/graph/src/main/scala/org/apache/spark/graph/package.scala
index ee28d1429e017c8d9a5a1847665322dca9cfc01c..7b53e9cce82a3a7eb32ff7daa82015965d0e1cd2 100644
--- a/graph/src/main/scala/org/apache/spark/graph/package.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/package.scala
@@ -8,10 +8,9 @@ package object graph {
   type Vid = Long
   type Pid = Int
 
-  type VertexHashMap[T] = it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap[T]
-  type VertexSet = it.unimi.dsi.fastutil.longs.LongOpenHashSet
+  type VertexSet = OpenHashSet[Vid]
   type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList
-  
+
   //  type VertexIdToIndexMap = it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap
   type VertexIdToIndexMap = OpenHashSet[Vid]
 
diff --git a/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5a59fd912a519a7e45ba504e4c4f3ae1608fee2d
--- /dev/null
+++ b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
@@ -0,0 +1,139 @@
+package org.apache.spark.graph
+
+import org.scalatest.FunSuite
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark._
+import org.apache.spark.SparkContext
+import org.apache.spark.graph.LocalSparkContext._
+import org.apache.spark.graph.impl._
+import org.apache.spark.graph.impl.MsgRDDFunctions._
+
+
+class SerializerSuite extends FunSuite with LocalSparkContext {
+
+  System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+  System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
+
+  test("TestVertexBroadcastMessageInt") {
+    val outMsg = new VertexBroadcastMsg[Int](3,4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestVertexBroadcastMessageLong") {
+    val outMsg = new VertexBroadcastMsg[Long](3,4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestVertexBroadcastMessageDouble") {
+    val outMsg = new VertexBroadcastMsg[Double](3,4,5.0)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageInt") {
+    val outMsg = new AggregationMsg[Int](4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Int] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Int] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageLong") {
+    val outMsg = new AggregationMsg[Long](4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Long] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Long] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageDouble") {
+    val outMsg = new AggregationMsg[Double](4,5.0)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Double] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Double] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestShuffleVertexBroadcastMsg") {
+    withSpark(new SparkContext("local[2]", "test")) { sc =>
+      val bmsgs = sc.parallelize(
+        (0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
+      val partitioner = new HashPartitioner(3)
+      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+    }
+  }
+
+  test("TestShuffleAggregationMsg") {
+    withSpark(new SparkContext("local[2]", "test")) { sc =>
+      val bmsgs = sc.parallelize(
+        (0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
+      val partitioner = new HashPartitioner(3)
+      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+    }
+  }
+
+}
\ No newline at end of file