diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 3827ac8d0fd6aaceb5b250a91e88c8ca7227e80d..502b112d31c2edf3cce6bb5838b34444148577ef 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -119,7 +119,7 @@ object RoutingTablePartition {
  */
 private[graphx]
 class RoutingTablePartition(
-    private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) {
+    private val routingTable: Array[(Array[VertexId], BitSet, BitSet)]) extends Serializable {
   /** The maximum number of edge partitions this `RoutingTablePartition` is built to join with. */
   val numEdgePartitions: Int = routingTable.size
 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
index 34939b24440aa2f39b4f65f5a366353ce7db0e28..5ad6390a56c4f6965c5732ccea9e7f756a975c5b 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala
@@ -60,7 +60,8 @@ private[graphx] object VertexPartitionBase {
  * `VertexPartitionBaseOpsConstructor` typeclass (for example,
  * [[VertexPartition.VertexPartitionOpsConstructor]]).
  */
-private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag] {
+private[graphx] abstract class VertexPartitionBase[@specialized(Long, Int, Double) VD: ClassTag]
+  extends Serializable {
 
   def index: VertexIdToIndexMap
   def values: Array[VD]
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
index a4f769b29401078f5658d1823316d33db8f9f0e8..b40aa1b417a0fa76a660fa770664d974c6a46dfb 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala
@@ -35,7 +35,7 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 private[graphx] abstract class VertexPartitionBaseOps
     [VD: ClassTag, Self[X] <: VertexPartitionBase[X] : VertexPartitionBaseOpsConstructor]
     (self: Self[VD])
-    extends Logging {
+  extends Serializable with Logging {
 
   def withIndex(index: VertexIdToIndexMap): Self[VD]
   def withValues[VD2: ClassTag](values: Array[VD2]): Self[VD2]
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 28fd112f2b124a23b96b2f00e85b0f5a753c7525..9d00f76327e4c38af2fce2514b4d7136a7903615 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -23,6 +23,7 @@ import scala.util.Random
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.serializer.KryoSerializer
 
 import org.apache.spark.graphx._
@@ -124,18 +125,21 @@ class EdgePartitionSuite extends FunSuite {
     assert(ep.numActives == Some(2))
   }
 
-  test("Kryo serialization") {
+  test("serialization") {
     val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
     val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
-    val conf = new SparkConf()
+    val javaSer = new JavaSerializer(new SparkConf())
+    val kryoSer = new KryoSerializer(new SparkConf()
       .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
-    val s = new KryoSerializer(conf).newInstance()
-    val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
-    assert(aSer.srcIds.toList === a.srcIds.toList)
-    assert(aSer.dstIds.toList === a.dstIds.toList)
-    assert(aSer.data.toList === a.data.toList)
-    assert(aSer.index != null)
-    assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+
+    for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
+      val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
+      assert(aSer.srcIds.toList === a.srcIds.toList)
+      assert(aSer.dstIds.toList === a.dstIds.toList)
+      assert(aSer.data.toList === a.data.toList)
+      assert(aSer.index != null)
+      assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+    }
   }
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index 8bf1384d514c13bff7f9cf1733f24317f320846c..f9e771a9000130fe9ac5478067cd1a1c58e0f7c2 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark.graphx.impl
 
-import org.apache.spark.graphx._
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.KryoSerializer
+
+import org.apache.spark.graphx._
+
 class VertexPartitionSuite extends FunSuite {
 
   test("isDefined, filter") {
@@ -116,4 +121,17 @@ class VertexPartitionSuite extends FunSuite {
     assert(vp3.index.getPos(2) === -1)
   }
 
+  test("serialization") {
+    val verts = Set((0L, 1), (1L, 1), (2L, 1))
+    val vp = VertexPartition(verts.iterator)
+    val javaSer = new JavaSerializer(new SparkConf())
+    val kryoSer = new KryoSerializer(new SparkConf()
+      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+
+    for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
+      val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp))
+      assert(vpSer.iterator.toSet === verts)
+    }
+  }
 }