diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
index 6dad167fa7411d1294ca1171ed59004728d4c2a0..904be213147dc5dbba8f9b4c96150044c9399c31 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
@@ -104,8 +104,14 @@ class VertexRDDImpl[VD] private[graphx] (
     this.mapVertexPartitions(_.map(f))
 
   override def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
+    val otherPartition = other match {
+      case other: VertexRDD[_] if this.partitioner == other.partitioner =>
+        other.partitionsRDD
+      case _ =>
+        VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD
+    }
     val newPartitionsRDD = partitionsRDD.zipPartitions(
-      other.partitionsRDD, preservesPartitioning = true
+      otherPartition, preservesPartitioning = true
     ) { (thisIter, otherIter) =>
       val thisPart = thisIter.next()
       val otherPart = otherIter.next()
@@ -133,7 +139,7 @@ class VertexRDDImpl[VD] private[graphx] (
     // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
     // If the other set is a VertexRDD then we use the much more efficient leftZipJoin
     other match {
-      case other: VertexRDD[_] =>
+      case other: VertexRDD[_] if this.partitioner == other.partitioner =>
         leftZipJoin(other)(f)
       case _ =>
         this.withPartitionsRDD[VD3](
@@ -162,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] (
     // Test if the other vertex is a VertexRDD to choose the optimal join strategy.
     // If the other set is a VertexRDD then we use the much more efficient innerZipJoin
     other match {
-      case other: VertexRDD[_] =>
+      case other: VertexRDD[_] if this.partitioner == other.partitioner =>
         innerZipJoin(other)(f)
       case _ =>
         this.withPartitionsRDD(