diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index 14ae50e6657fd1eb1f4026156c54af0c73258a66..4db45c9af8fae765c9edbc132be3e647caa82290 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -138,7 +138,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * }}}
    *
    */
-  def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED]
+  def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2)
+    (implicit eq: VD =:= VD2 = null): Graph[VD2, ED]
 
   /**
    * Transforms each edge attribute in the graph using the map function.  The map function is not
@@ -348,7 +349,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * }}}
    */
   def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)])
-      (mapFunc: (VertexId, VD, Option[U]) => VD2)
+      (mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null)
     : Graph[VD2, ED]
 
   /**
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 15ea05cbe281d55a9788533d7d114724de9605b9..ccdaa82eb91626092560dc65b1c2e7d3f04e8b7e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -104,8 +104,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse())
   }
 
-  override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = {
-    if (classTag[VD] equals classTag[VD2]) {
+  override def mapVertices[VD2: ClassTag]
+    (f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
+    // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
+    // null if not
+    if (eq != null) {
       vertices.cache()
       // The map preserves type, so we can use incremental replication
       val newVerts = vertices.mapVertexPartitions(_.map(f)).cache()
@@ -232,8 +235,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
 
   override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
       (other: RDD[(VertexId, U)])
-      (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = {
-    if (classTag[VD] equals classTag[VD2]) {
+      (updateF: (VertexId, VD, Option[U]) => VD2)
+      (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = {
+    // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left
+    // null if not
+    if (eq != null) {
       vertices.cache()
       // updateF preserves type, so we can use incremental replication
       val newVerts = vertices.leftJoin(other)(updateF).cache()
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
index 776bfb8dd6bfa8e3e4e6855bbc7efa67961eeded..82e9e065151794edf3f7b474eec6469ee77e5577 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala
@@ -41,7 +41,7 @@ object LabelPropagation {
    *
    * @return a graph with vertex attributes containing the label of community affiliation
    */
-  def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = {
+  def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = {
     val lpaGraph = graph.mapVertices { case (vid, _) => vid }
     def sendMessage(e: EdgeTriplet[VertexId, ED]) = {
       Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L)))
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
index bba070f256d801448e5e2756386c3dd7204faa2d..590f0474957dd739431284ba31161bdcddbfdaf6 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
@@ -49,7 +49,7 @@ object ShortestPaths {
    * @return a graph where each vertex attribute is a map containing the shortest-path distance to
    * each reachable landmark vertex.
    */
-  def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
+  def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = {
     val spGraph = graph.mapVertices { (vid, attr) =>
       if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap()
     }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index abc25d0671133986adaf814b651672ee5beaf3bc..6506bac73d71c9aca6467199c24fc7c94431c9c9 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("mapVertices changing type with same erased type") {
+    withSpark { sc =>
+      val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])](
+        (1L, Some(1)),
+        (2L, Some(2)),
+        (3L, Some(3))
+      ))
+      val edges = sc.parallelize(Array(
+        Edge(1L, 2L, 0),
+        Edge(2L, 3L, 0),
+        Edge(3L, 1L, 0)
+      ))
+      val graph0 = Graph(vertices, edges)
+      // Trigger initial vertex replication
+      graph0.triplets.foreach(x => {})
+      // Change type of replicated vertices, but preserve erased type
+      val graph1 = graph0.mapVertices {
+        case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double)
+      }
+      // Access replicated vertices, exposing the erased type
+      val graph2 = graph1.mapTriplets(t => t.srcAttr.get)
+      assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0))
+    }
+  }
+
   test("mapEdges") {
     withSpark { sc =>
       val n = 3