diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index d295d0127ac72ca5101ec6610f2b63d061c1a57a..f97f329c0e832f960efbf1701290509aea2fb9d3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -24,6 +24,9 @@ import org.apache.spark.util.BoundedPriorityQueue
 import org.apache.spark.util.collection.BitSet
 
 import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.OpenHashSet
+
 
 /**
  * Registers GraphX classes with Kryo for improved performance.
@@ -43,8 +46,8 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[PartitionStrategy])
     kryo.register(classOf[BoundedPriorityQueue[Object]])
     kryo.register(classOf[EdgeDirection])
-
-    // This avoids a large number of hash table lookups.
-    kryo.setReferences(false)
+    kryo.register(classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]])
+    kryo.register(classOf[OpenHashSet[Int]])
+    kryo.register(classOf[OpenHashSet[Long]])
   }
 }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index 871e81f8d245c0c95dd7fc412795e0b4b7f88fdd..a5c9cd1f8b4e632f035a1cab909ddfd5cc11d80b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -20,7 +20,7 @@ package org.apache.spark.graphx.impl
 import scala.reflect.{classTag, ClassTag}
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 /**
  * A collection of edges stored in columnar format, along with any vertex attributes referenced. The
@@ -42,12 +42,12 @@ import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
 private[graphx]
 class EdgePartition[
     @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag](
-    @transient val srcIds: Array[VertexId],
-    @transient val dstIds: Array[VertexId],
-    @transient val data: Array[ED],
-    @transient val index: PrimitiveKeyOpenHashMap[VertexId, Int],
-    @transient val vertices: VertexPartition[VD],
-    @transient val activeSet: Option[VertexSet] = None
+    val srcIds: Array[VertexId] = null,
+    val dstIds: Array[VertexId] = null,
+    val data: Array[ED] = null,
+    val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
+    val vertices: VertexPartition[VD] = null,
+    val activeSet: Option[VertexSet] = None
   ) extends Serializable {
 
   /** Return a new `EdgePartition` with the specified edge data. */
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index ecb49bef42e455223876048102c0305df105f6de..4520beb99151583c0270ba4bbe651ffaaac5940e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -23,7 +23,7 @@ import scala.util.Sorting
 import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector}
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 private[graphx]
 class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
@@ -41,7 +41,7 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
     val srcIds = new Array[VertexId](edgeArray.size)
     val dstIds = new Array[VertexId](edgeArray.size)
     val data = new Array[ED](edgeArray.size)
-    val index = new PrimitiveKeyOpenHashMap[VertexId, Int]
+    val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
     // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
     // adding them to the index
     if (edgeArray.length > 0) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
index ebb0b9418d65dc9e2d3b58d96bbb709295fbc3e5..56f79a7097fce11483bb3ca97424d8bc125a2e75 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.graphx.impl
 import scala.reflect.ClassTag
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 /**
  * The Iterator type returned when constructing edge triplets. This could be an anonymous class in
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 927e32ad0f4487e4e8e3e11a0ae7f9a1be3b45a1..d02e9238adba59a77cd8c9cc7dd9f208f4096eb3 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.ShuffledRDD
 import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 /**
  * A message from the edge partition `pid` to the vertex partition containing `vid` specifying that
@@ -69,7 +69,7 @@ object RoutingTablePartition {
     : Iterator[RoutingTableMessage] = {
     // Determine which positions each vertex id appears in using a map where the low 2 bits
     // represent src and dst
-    val map = new PrimitiveKeyOpenHashMap[VertexId, Byte]
+    val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte]
     edgePartition.srcIds.iterator.foreach { srcId =>
       map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
     }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
index f4e221d4e05ae4e0813722003834e4f0d62a1fcd..dca54b8a7da86faaa4bceed8698d5f6dd9322d43 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.util.collection.{BitSet, PrimitiveVector}
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 /** Stores vertex attributes to ship to an edge partition. */
 private[graphx]
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
index f1d174720a1baa4fe1a5cb2743121f184932d610..55c7a19d1bdab10c3fcdc4aaf3bb6243651d2b44 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.util.collection.BitSet
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 private[graphx] object VertexPartition {
   /** Construct a `VertexPartition` from the given vertices. */
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
index 8d9e0204d27f24b64a22c30560bc6a2707035624..34939b24440aa2f39b4f65f5a366353ce7db0e28 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.util.collection.BitSet
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 private[graphx] object VertexPartitionBase {
   /**
@@ -32,7 +32,7 @@ private[graphx] object VertexPartitionBase {
    */
   def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)])
     : (VertexIdToIndexMap, Array[VD], BitSet) = {
-    val map = new PrimitiveKeyOpenHashMap[VertexId, VD]
+    val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD]
     iter.foreach { pair =>
       map(pair._1) = pair._2
     }
@@ -45,7 +45,7 @@ private[graphx] object VertexPartitionBase {
    */
   def initFrom[VD: ClassTag](iter: Iterator[(VertexId, VD)], mergeFunc: (VD, VD) => VD)
     : (VertexIdToIndexMap, Array[VD], BitSet) = {
-    val map = new PrimitiveKeyOpenHashMap[VertexId, VD]
+    val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD]
     iter.foreach { pair =>
       map.setMerge(pair._1, pair._2, mergeFunc)
     }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
index 21ff615feca6c74c2a7d26e8915246bb15acef30..a4f769b29401078f5658d1823316d33db8f9f0e8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
@@ -25,7 +25,7 @@ import org.apache.spark.Logging
 import org.apache.spark.util.collection.BitSet
 
 import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
 /**
  * An class containing additional operations for subclasses of VertexPartitionBase that provide
@@ -224,7 +224,7 @@ private[graphx] abstract class VertexPartitionBaseOps
    * Construct a new VertexPartition whose index contains only the vertices in the mask.
    */
   def reindex(): Self[VD] = {
-    val hashMap = new PrimitiveKeyOpenHashMap[VertexId, VD]
+    val hashMap = new GraphXPrimitiveKeyOpenHashMap[VertexId, VD]
     val arbitraryMerge = (a: VD, b: VD) => a
     for ((k, v) <- self.iterator) {
       hashMap.setMerge(k, v, arbitraryMerge)
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala
similarity index 98%
rename from graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala
rename to graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala
index 7b02e2ed1a9cb019508c09e5ab1603b00067cb17..57b01b6f2e1fba897382d183e5511ab190beea92 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala
@@ -29,7 +29,7 @@ import scala.reflect._
  * Under the hood, it uses our OpenHashSet implementation.
  */
 private[graphx]
-class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+class GraphXPrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
                               @specialized(Long, Int, Double) V: ClassTag](
     val keySet: OpenHashSet[K], var _values: Array[V])
   extends Iterable[(K, V)]
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index d2e0c01bc35ef87c6e2d699fd61eaa807ae8ec71..28fd112f2b124a23b96b2f00e85b0f5a753c7525 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -22,6 +22,9 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.KryoSerializer
+
 import org.apache.spark.graphx._
 
 class EdgePartitionSuite extends FunSuite {
@@ -120,4 +123,19 @@ class EdgePartitionSuite extends FunSuite {
     assert(!ep.isActive(-1))
     assert(ep.numActives == Some(2))
   }
+
+  test("Kryo serialization") {
+    val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
+    val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
+    val conf = new SparkConf()
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+    val s = new KryoSerializer(conf).newInstance()
+    val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
+    assert(aSer.srcIds.toList === a.srcIds.toList)
+    assert(aSer.dstIds.toList === a.dstIds.toList)
+    assert(aSer.data.toList === a.data.toList)
+    assert(aSer.index != null)
+    assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+  }
 }