diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 5b1bf9476e4d5a4076a2f74453c500b48c03f398..cd0aea0cb3d1f18c4af0ae14821c88fa483637ef 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -277,6 +277,29 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
   def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] =
     fromRDD(rdd.subtract(other, p))
 
+  /**
+   * Return an RDD with the pairs from `this` whose keys are not in `other`.
+   *
+   * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
+   * RDD will be <= us.
+   */
+  def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = {
+    implicit val cmw: ClassTag[W] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+    fromRDD(rdd.subtractByKey(other))
+  }
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, V] = {
+    implicit val cmw: ClassTag[W] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+    fromRDD(rdd.subtractByKey(other, numPartitions))
+  }
+
+  /** Return an RDD with the pairs from `this` whose keys are not in `other`. */
+  def subtractByKey[W](other: JavaPairRDD[K, W], p: Partitioner): JavaPairRDD[K, V] = {
+    implicit val cmw: ClassTag[W] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[W]]
+    fromRDD(rdd.subtractByKey(other, p))
+  }
+
   /**
    * Return a copy of the RDD partitioned using the specified partitioner.
    */