diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
index f65f96ed0c1b2dd62f1d5bf0b20b77cb060267af..82b9198e432c728179de848b2b8b1672e7dbbc4d 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
@@ -2,7 +2,7 @@ package org.apache.spark.graph
 
 import com.esotericsoftware.kryo.Kryo
 
-import org.apache.spark.graph.impl.{EdgePartition, MessageToPartition}
+import org.apache.spark.graph.impl._
 import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.util.collection.BitSet
 
@@ -12,6 +12,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[Edge[Object]])
     kryo.register(classOf[MutableTuple2[Object, Object]])
     kryo.register(classOf[MessageToPartition[Object]])
+    kryo.register(classOf[VertexBroadcastMsg[Object]])
+    kryo.register(classOf[AggregationMsg[Object]])
     kryo.register(classOf[(Vid, Object)])
     kryo.register(classOf[EdgePartition[Object]])
     kryo.register(classOf[BitSet])
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
index 3fc0b7c0f7588d3b2aff84868367edc7f45dcd39..d0a5adb85cd8a3d6bb28ff588e527465a0b3c3dc 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/MessageToPartition.scala
@@ -55,6 +55,8 @@ class VertexBroadcastMsgRDDFunctions[T: ClassManifest](self: RDD[VertexBroadcast
     // Set a custom serializer if the data is of int or double type.
     if (classManifest[T] == ClassManifest.Int) {
       rdd.setSerializer(classOf[IntVertexBroadcastMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Long) {
+      rdd.setSerializer(classOf[LongVertexBroadcastMsgSerializer].getName)
     } else if (classManifest[T] == ClassManifest.Double) {
       rdd.setSerializer(classOf[DoubleVertexBroadcastMsgSerializer].getName)
     }
@@ -70,6 +72,8 @@ class AggregationMessageRDDFunctions[T: ClassManifest](self: RDD[AggregationMsg[
     // Set a custom serializer if the data is of int or double type.
     if (classManifest[T] == ClassManifest.Int) {
       rdd.setSerializer(classOf[IntAggMsgSerializer].getName)
+    } else if (classManifest[T] == ClassManifest.Long) {
+      rdd.setSerializer(classOf[LongAggMsgSerializer].getName)
     } else if (classManifest[T] == ClassManifest.Double) {
       rdd.setSerializer(classOf[DoubleAggMsgSerializer].getName)
     }
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
index 8b4c0868b1a6dff011cab21938f60fc1632c6ece..54fd65e7381f2512a8c069bc91abf95f7d58c320 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Serializers.scala
@@ -27,6 +27,28 @@ class IntVertexBroadcastMsgSerializer extends Serializer {
   }
 }
 
+/** A special shuffle serializer for VertexBroadcastMessage[Long]. */
+class LongVertexBroadcastMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
+        writeLong(msg.vid)
+        writeLong(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        val a = readLong()
+        val b = readLong()
+        new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
+      }
+    }
+  }
+}
 
 /** A special shuffle serializer for VertexBroadcastMessage[Double]. */
 class DoubleVertexBroadcastMsgSerializer extends Serializer {
@@ -43,7 +65,9 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer {
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       def readObject[T](): T = {
-        new VertexBroadcastMsg[Double](0, readLong(), readDouble()).asInstanceOf[T]
+        val a = readLong()
+        val b = readDouble()
+        new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
       }
     }
   }
@@ -65,7 +89,32 @@ class IntAggMsgSerializer extends Serializer {
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       override def readObject[T](): T = {
-        new AggregationMsg[Int](readLong(), readInt()).asInstanceOf[T]
+        val a = readLong()
+        val b = readInt()
+        new AggregationMsg[Int](a, b).asInstanceOf[T]
+      }
+    }
+  }
+}
+
+/** A special shuffle serializer for AggregationMessage[Long]. */
+class LongAggMsgSerializer extends Serializer {
+  override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
+
+    override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
+      def writeObject[T](t: T) = {
+        val msg = t.asInstanceOf[AggregationMsg[Long]]
+        writeLong(msg.vid)
+        writeLong(msg.data)
+        this
+      }
+    }
+
+    override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
+      override def readObject[T](): T = {
+        val a = readLong()
+        val b = readLong()
+        new AggregationMsg[Long](a, b).asInstanceOf[T]
       }
     }
   }
@@ -87,7 +136,9 @@ class DoubleAggMsgSerializer extends Serializer {
 
     override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
       def readObject[T](): T = {
-        new AggregationMsg[Double](readLong(), readDouble()).asInstanceOf[T]
+        val a = readLong()
+        val b = readDouble()
+        new AggregationMsg[Double](a, b).asInstanceOf[T]
       }
     }
   }
diff --git a/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5a59fd912a519a7e45ba504e4c4f3ae1608fee2d
--- /dev/null
+++ b/graph/src/test/scala/org/apache/spark/graph/SerializerSuite.scala
@@ -0,0 +1,139 @@
+package org.apache.spark.graph
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graph.LocalSparkContext._
+import java.io.ByteArrayInputStream
+import java.io.ByteArrayOutputStream
+import org.apache.spark.graph.impl._
+import org.apache.spark.graph.impl.MsgRDDFunctions._
+import org.apache.spark._
+
+
+class SerializerSuite extends FunSuite with LocalSparkContext {
+
+  System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+  System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
+
+  test("TestVertexBroadcastMessageInt") {
+    val outMsg = new VertexBroadcastMsg[Int](3,4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new IntVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new IntVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Int] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Int] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestVertexBroadcastMessageLong") {
+    val outMsg = new VertexBroadcastMsg[Long](3,4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new LongVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new LongVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Long] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Long] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestVertexBroadcastMessageDouble") {
+    val outMsg = new VertexBroadcastMsg[Double](3,4,5.0)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new DoubleVertexBroadcastMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: VertexBroadcastMsg[Double] = inStrm.readObject()
+    val inMsg2: VertexBroadcastMsg[Double] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageInt") {
+    val outMsg = new AggregationMsg[Int](4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Int] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Int] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageLong") {
+    val outMsg = new AggregationMsg[Long](4,5)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Long] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Long] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestAggregationMessageDouble") {
+    val outMsg = new AggregationMsg[Double](4,5.0)
+    val bout = new ByteArrayOutputStream
+    val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout)
+    outStrm.writeObject(outMsg)
+    outStrm.writeObject(outMsg)
+    bout.flush
+    val bin = new ByteArrayInputStream(bout.toByteArray)
+    val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin)
+    val inMsg1: AggregationMsg[Double] = inStrm.readObject()
+    val inMsg2: AggregationMsg[Double] = inStrm.readObject()
+    assert(outMsg.vid === inMsg1.vid)
+    assert(outMsg.vid === inMsg2.vid)
+    assert(outMsg.data === inMsg1.data)
+    assert(outMsg.data === inMsg2.data)
+  }
+
+  test("TestShuffleVertexBroadcastMsg") {
+    withSpark(new SparkContext("local[2]", "test")) { sc =>
+      val bmsgs = sc.parallelize(
+        (0 until 100).map(pid => new VertexBroadcastMsg[Int](pid, pid, pid)), 10)
+      val partitioner = new HashPartitioner(3)
+      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+    }
+  }
+
+  test("TestShuffleAggregationMsg") {
+    withSpark(new SparkContext("local[2]", "test")) { sc =>
+      val bmsgs = sc.parallelize(
+        (0 until 100).map(pid => new AggregationMsg[Int](pid, pid)), 10)
+      val partitioner = new HashPartitioner(3)
+      val bmsgsArray = bmsgs.partitionBy(partitioner).collect
+    }
+  }
+
+}
\ No newline at end of file