diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index b88c952feb65d426ab8fafa8fdbd6304d91968c4..d0df35d4226f724b3b419653b7277894261d7015 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -5,7 +5,6 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-
 import org.apache.spark.SparkContext._
 import org.apache.spark.HashPartitioner 
 import org.apache.spark.util.ClosureCleaner
@@ -72,8 +71,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   def this() = this(null, null, null, null)
 
-
-
   /**
    * (localVidMap: RDD[(Pid, VertexIdToIndexMap)]) is the index companion to
    * the replicated vertex data. Within each partition, it holds a map
@@ -86,22 +83,18 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient val vTableReplicatedValues: RDD[(Pid, Array[VD])] =
     createVTableReplicated(vTable, vid2pid, localVidMap)
 
-
   /** Return an RDD of vertices. */
   @transient override val vertices = vTable
 
-
   /** Return an RDD of edges. */
   @transient override val edges: RDD[Edge[ED]] = {
     eTable.mapPartitions(iter => iter.next()._2.iterator, preservesPartitioning = true)
   }
 
-
   /** Return an RDD that brings together each edge with its source and destination vertices. */
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
     makeTriplets(localVidMap, vTableReplicatedValues, eTable)
 
-
   override def cache(): Graph[VD, ED] = {
     eTable.cache()
     vid2pid.cache()
@@ -109,7 +102,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     this
   }
 
-
   override def statistics: Map[String, Any] = {
     val numVertices = this.numVertices
     val numEdges = this.numEdges
@@ -125,7 +117,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       "Min Load" -> minLoad, "Max Load" -> maxLoad) 
   }
 
-
   /**
    * Display the lineage information for this graph.
    */
@@ -183,14 +174,12 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     println(visited)
   } // end of print lineage
 
-
   override def reverse: Graph[VD, ED] = {
     val newEtable = eTable.mapPartitions( _.map{ case (pid, epart) => (pid, epart.reverse) }, 
       preservesPartitioning = true)
     new GraphImpl(vTable, vid2pid, localVidMap, newEtable)
   }
 
-
   override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
     val newVTable = vTable.mapValuesWithKeys((vid, data) => f(vid, data))
     new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
@@ -202,11 +191,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     new GraphImpl(vTable, vid2pid, localVidMap, newETable)
   }
 
-
   override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
     GraphImpl.mapTriplets(this, f)
 
-
   override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true), 
     vpred: (Vid, VD) => Boolean = ((a,b) => true) ): Graph[VD, ED] = {
 
@@ -246,7 +233,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     new GraphImpl(newVTable, newVid2Pid, localVidMap, newETable)
   }
 
-
   override def groupEdgeTriplets[ED2: ClassManifest](
     f: Iterator[EdgeTriplet[VD,ED]] => ED2 ): Graph[VD,ED2] = {
       val newEdges: RDD[Edge[ED2]] = triplets.mapPartitions { partIter =>
@@ -271,7 +257,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       new GraphImpl(vTable, vid2pid, localVidMap, newETable)
   }
 
-
   override def groupEdges[ED2: ClassManifest](f: Iterator[Edge[ED]] => ED2 ):
     Graph[VD,ED2] = {
 
@@ -289,8 +274,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
       new GraphImpl(vTable, vid2pid, localVidMap, newETable)
   }
 
-
-
   //////////////////////////////////////////////////////////////////////////////////////////////////
   // Lower level transformation methods
   //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -301,7 +284,6 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     : VertexSetRDD[A] = 
     GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc)
 
-
   override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest]
     (updates: RDD[(Vid, U)])(updateF: (Vid, VD, Option[U]) => VD2)
     : Graph[VD2, ED] = {
@@ -309,15 +291,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     val newVTable = vTable.leftJoin(updates)(updateF)
     new GraphImpl(newVTable, vid2pid, localVidMap, eTable)
   }
-
-
 } // end of class GraphImpl
 
 
-
-
-
-
 object GraphImpl {
 
   def apply[VD: ClassManifest, ED: ClassManifest](
@@ -327,7 +303,6 @@ object GraphImpl {
     apply(vertices, edges, defaultVertexAttr, (a:VD, b:VD) => a)
   }
 
-
   def apply[VD: ClassManifest, ED: ClassManifest](
     vertices: RDD[(Vid, VD)], 
     edges: RDD[Edge[ED]],
@@ -353,7 +328,6 @@ object GraphImpl {
     new GraphImpl(vtable, vid2pid, localVidMap, etable)
   }
 
-
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
@@ -389,7 +363,6 @@ object GraphImpl {
     }, preservesPartitioning = true).cache()
   }
 
-
   protected def createVid2Pid[ED: ClassManifest](
     eTable: RDD[(Pid, EdgePartition[ED])],
     vTableIndex: VertexSetIndex): VertexSetRDD[Array[Pid]] = {
@@ -406,7 +379,6 @@ object GraphImpl {
       .mapValues(a => a.toArray).cache()
   }
 
-
   protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]): 
     RDD[(Pid, VertexIdToIndexMap)] = {
     eTable.mapPartitions( _.map{ case (pid, epart) =>
@@ -419,7 +391,6 @@ object GraphImpl {
     }, preservesPartitioning = true).cache()
   }
 
-
   protected def createVTableReplicated[VD: ClassManifest](
       vTable: VertexSetRDD[VD], 
       vid2pid: VertexSetRDD[Array[Pid]],
@@ -428,9 +399,9 @@ object GraphImpl {
     // Join vid2pid and vTable, then shuffle the resulting vertex messages so
     // each one lands on the partition holding the edges that reference it.
     val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
-      pids.iterator.map { pid => MessageToPartition(pid, (vid, vdata)) }
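+      // Send a flat VertexMessage instead of MessageToPartition wrapping a
+      // (Vid, VD) tuple; this lets partitionBy below swap in the specialized
+      // shuffle serializers when VD is Int or Double.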
+      pids.iterator.map { pid => new VertexMessage[VD](pid, vid, vdata) }
     }.partitionBy(replicationMap.partitioner.get).cache()
-   
+
     replicationMap.zipPartitions(msgsByPartition){ 
       (mapIter, msgsIter) =>
       val (pid, vidToIndex) = mapIter.next()
@@ -438,8 +409,8 @@ object GraphImpl {
       // Populate the vertex array using the vidToIndex map
       val vertexArray = new Array[VD](vidToIndex.capacity)
       for (msg <- msgsIter) {
-        val ind = vidToIndex.getPos(msg.data._1) & OpenHashSet.POSITION_MASK
-        vertexArray(ind) = msg.data._2
+        val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
+        vertexArray(ind) = msg.data
       }
       Iterator((pid, vertexArray))
     }.cache()
@@ -447,7 +418,6 @@ object GraphImpl {
     // @todo assert edge table has partitioner
   }
 
-
   def makeTriplets[VD: ClassManifest, ED: ClassManifest]( 
     localVidMap: RDD[(Pid, VertexIdToIndexMap)],
     vTableReplicatedValues: RDD[(Pid, Array[VD])],
@@ -461,7 +431,6 @@ object GraphImpl {
     }
   }
 
-
   def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
     g: GraphImpl[VD, ED],   
     f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
@@ -483,7 +452,6 @@ object GraphImpl {
     new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
   }
 
-
   def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](
     g: GraphImpl[VD, ED],
     mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
@@ -495,33 +463,34 @@ object GraphImpl {
     // Map and preaggregate 
     val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ 
       (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
-      val (pid, edgePartition) = edgePartitionIter.next()
+      val (_, edgePartition) = edgePartitionIter.next()
       val (_, vidToIndex) = vidToIndexIter.next()
       val (_, vertexArray) = vertexArrayIter.next()
       assert(!edgePartitionIter.hasNext)
       assert(!vidToIndexIter.hasNext)
       assert(!vertexArrayIter.hasNext)
       assert(vidToIndex.capacity == vertexArray.size)
+      // Reuse the vidToIndex map to run aggregation.
       val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
-      // We can reuse the vidToIndex map for aggregation here as well.
-      /** @todo Since this has the downside of not allowing "messages" to arbitrary
-       * vertices we should consider just using a fresh map.
-       */
+      // TODO(jegonzal): This doesn't allow users to send messages to arbitrary vertices.
       val msgArray = new Array[A](vertexArray.size)
       val msgBS = new BitSet(vertexArray.size)
       // Iterate over the partition
       val et = new EdgeTriplet[VD, ED]
-      edgePartition.foreach{e => 
+      edgePartition.foreach { e =>
         et.set(e)
         et.srcAttr = vmap(e.srcId)
         et.dstAttr = vmap(e.dstId)
+        // TODO(rxin): rewrite the foreach using a simple while loop to speed things up.
+        // Also given we are only allowing zero, one, or two messages, we can completely unroll
+        // the for loop.
         mapFunc(et).foreach{ case (vid, msg) =>
           // verify that the vid is valid
           assert(vid == et.srcId || vid == et.dstId)
           // Get the index of the key
           val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
           // Populate the aggregator map
-          if(msgBS.get(ind)) {
+          if (msgBS.get(ind)) {
             msgArray(ind) = reduceFunc(msgArray(ind), msg)
           } else { 
             msgArray(ind) = msg
@@ -536,14 +505,11 @@ object GraphImpl {
     VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
   }
 
-
   protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val mixingPrime: Vid = 1125899906842597L 
     (math.abs(src) * mixingPrime).toInt % numParts
   }
 
-
-
   /**
    * This function implements a classic 2D partitioning of a sparse matrix.
    * Suppose we have a graph with 11 vertices that we want to partition 
@@ -596,7 +562,6 @@ object GraphImpl {
     (col * ceilSqrtNumParts + row) % numParts
   }
 
-
   /**
    * Assign edges to an arbitrary machine corresponding to a
    * random vertex cut.
@@ -605,7 +570,6 @@ object GraphImpl {
     math.abs((src, dst).hashCode()) % numParts
   }
 
-
   /**
    * @todo This will only partition edges to the upper diagonal
    * of the 2D processor space.
@@ -622,4 +586,3 @@ object GraphImpl {
   }
 
 } // end of object GraphImpl
-
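
For context, the pre-aggregation loop above is the hot path behind mapReduceTriplets. A minimal usage sketch (illustrative, not part of this patch; assumes a graph `g: GraphImpl[VD, ED]` is in scope) that computes in-degrees through it:

    // Send a 1 to each edge's destination vertex, then sum messages per vertex.
    // Message targets must be et.srcId or et.dstId (the loop above asserts this).
    val inDegrees: VertexSetRDD[Int] =
      g.mapReduceTriplets[Int](
        et => Array((et.dstId, 1)),
        (a, b) => a + b)
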
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
index b7bbf257a4a5692163d5e53c07473ca35d3dab9b..9ac2c59bf844ebf49ace731a396bef3efbfcd731 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
@@ -1,10 +1,24 @@
 package org.apache.spark.graph.impl
 
 import org.apache.spark.Partitioner
-import org.apache.spark.graph.Pid
+import org.apache.spark.graph.{Pid, Vid}
 import org.apache.spark.rdd.{ShuffledRDD, RDD}
 
 
+class VertexMessage[@specialized(Int, Long, Double, Boolean/*, AnyRef*/) T](
+    @transient var partition: Pid,
+    var vid: Vid,
+    var data: T)
+  extends Product2[Pid, (Vid, T)] {
+
+  override def _1 = partition
+
+  override def _2 = (vid, data)
+
+  override def canEqual(that: Any): Boolean = that.isInstanceOf[VertexMessage[_]]
+}
+
+
 /**
  * A message used to send a specific value to a partition.
  * @param partition index of the target partition.
@@ -30,6 +44,21 @@ object MessageToPartition {
 }
 
 
+class VertexMessageRDDFunctions[T: ClassManifest](self: RDD[VertexMessage[T]]) {
+  def partitionBy(partitioner: Partitioner): RDD[VertexMessage[T]] = {
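+    // VertexMessage extends Product2[Pid, (Vid, T)], so ShuffledRDD keys on the
+    // target partition id and the custom serializers can write vid/data directly.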
+    val rdd = new ShuffledRDD[Pid, (Vid, T), VertexMessage[T]](self, partitioner)
+
+    // Set a custom serializer if the data is an Int or a Double.
+    if (classManifest[T] == ClassManifest.Int) {
+      rdd.setSerializer(classOf[IntVertexMessageSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Double) {
+      rdd.setSerializer(classOf[DoubleVertexMessageSerializer].getName)
+    }
+    rdd
+  }
+}
+
+
 class MessageToPartitionRDDFunctions[T: ClassManifest](self: RDD[MessageToPartition[T]]) {
 
   /**
@@ -46,4 +75,8 @@ object MessageToPartitionRDDFunctions {
   implicit def rdd2PartitionRDDFunctions[T: ClassManifest](rdd: RDD[MessageToPartition[T]]) = {
     new MessageToPartitionRDDFunctions(rdd)
   }
+
+  implicit def rdd2vertexMessageRDDFunctions[T: ClassManifest](rdd: RDD[VertexMessage[T]]) = {
+    new VertexMessageRDDFunctions(rdd)
+  }
 }
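
A hedged usage sketch (illustrative, not part of this patch): with MessageToPartitionRDDFunctions._ imported, partitionBy on an RDD[VertexMessage[Double]] resolves through rdd2vertexMessageRDDFunctions, so the shuffle transparently uses DoubleVertexMessageSerializer:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.rdd.RDD
    import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._

    // Any RDD[VertexMessage[Double]] will do; `msgs` is a stand-in name.
    def shuffleMessages(msgs: RDD[VertexMessage[Double]]): RDD[VertexMessage[Double]] =
      msgs.partitionBy(new HashPartitioner(msgs.partitions.length))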
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0092aa7c6b538bfbcc0d8ec04ddcc9f922e41377
--- /dev/null
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
@@ -0,0 +1,125 @@
+package org.apache.spark.graph.impl
+
+import java.io.{InputStream, OutputStream}
+import java.nio.ByteBuffer
+
+import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance, Serializer}
+
+
+/** A special shuffle serializer for VertexMessage[Int]. */
+class IntVertexMessageSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexMessage[Int]]
+        writeLong(msg.vid)
+        writeInt(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        new VertexMessage[Int](0, readLong(), readInt()).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+
+/** A special shuffle serializer for VertexMessage[Double]. */
+class DoubleVertexMessageSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexMessage[Double]]
+        writeLong(msg.vid)
+        writeDouble(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      def readObject[T](): T = {
+        new VertexMessage[Double](0, readLong(), readDouble()).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Helper classes to shorten the implementation of those special serializers.
+////////////////////////////////////////////////////////////////////////////////
+
+sealed abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
+  // The implementation should override this one.
+  def writeObject[T](t: T): SerializationStream
+
+  def writeInt(v: Int) {
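+    // Big-endian, most significant byte first. OutputStream.write keeps only the
+    // low 8 bits of its argument, so no explicit masking is needed on the way out.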
+    s.write(v >> 24)
+    s.write(v >> 16)
+    s.write(v >> 8)
+    s.write(v)
+  }
+
+  def writeLong(v: Long) {
+    s.write((v >>> 56).toInt)
+    s.write((v >>> 48).toInt)
+    s.write((v >>> 40).toInt)
+    s.write((v >>> 32).toInt)
+    s.write((v >>> 24).toInt)
+    s.write((v >>> 16).toInt)
+    s.write((v >>> 8).toInt)
+    s.write(v.toInt)
+  }
+
+  def writeDouble(v: Double) {
+    writeLong(java.lang.Double.doubleToLongBits(v))
+  }
+
+  override def flush(): Unit = s.flush()
+
+  override def close(): Unit = s.close()
+}
+
+
+sealed abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
+  // The implementation should override this one.
+  def readObject[T](): T
+
+  def readInt(): Int = {
+    (s.read() & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF)
+  }
+
+  def readLong(): Long = {
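+    // Reassemble big-endian: read() returns each byte as an Int in [0, 255]
+    // (or -1 at end of stream), and each byte is shifted back into position.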
+    (s.read().toLong << 56) |
+      (s.read() & 0xFF).toLong << 48 |
+      (s.read() & 0xFF).toLong << 40 |
+      (s.read() & 0xFF).toLong << 32 |
+      (s.read() & 0xFF).toLong << 24 |
+      (s.read() & 0xFF) << 16 |
+      (s.read() & 0xFF) << 8 |
+      (s.read() & 0xFF)
+  }
+
+  def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong())
+
+  override def close(): Unit = s.close()
+}
+
+
+sealed trait ShuffleSerializerInstance extends SerializerInstance {
+
+  override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
+
+  override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
+
+  override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
+    throw new UnsupportedOperationException
+
+  // The implementation should override the following two.
+  override def serializeStream(s: OutputStream): SerializationStream
+  override def deserializeStream(s: InputStream): DeserializationStream
+}
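
As a sanity check on the wire format, a minimal round-trip sketch (illustrative only, not part of this patch): a VertexMessage[Int] costs exactly 12 bytes on the wire, an 8-byte big-endian vid followed by the 4-byte payload. The partition id is deliberately not written, since the shuffle bucket already encodes it.

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

    val out = new ByteArrayOutputStream()
    val ser = new IntVertexMessageSerializer().newInstance().serializeStream(out)
    ser.writeObject(new VertexMessage[Int](0, 42L, 7))
    ser.close()
    assert(out.toByteArray.length == 12)  // 8-byte vid + 4-byte Int payload

    val deser = new IntVertexMessageSerializer().newInstance()
      .deserializeStream(new ByteArrayInputStream(out.toByteArray))
    val msg = deser.readObject[VertexMessage[Int]]()
    assert(msg.vid == 42L && msg.data == 7)  // partition comes back as 0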