diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 9e912d3adbb5e4ef7208420ac78fd019c28d5097..1d71875ed13af9303b7098018edd7232c1d37b05 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -244,6 +244,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }
 
+  /**
+   * Return a list that contains all of the elements in a specific partition of this RDD.
+   */
+  def collectPartition(partitionId: Int): JList[T] = {
+    import scala.collection.JavaConversions._
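+    // Prune the parent RDD down to the requested partition so collect() only runs a task for it.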
+    val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
+    new java.util.ArrayList(partition.collect().toSeq)
+  }
+
   /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
    */
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a659cc06c2beee91ef40ac79bb6754948b68000b..ca42c769286e6532b3c8e74c09d4995656fdba1e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -235,10 +235,6 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
-    implicit val cm : ClassTag[T] = rdd.elementClassTag
-    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
-  }
 }
 
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 4234f6eac72f4cebd63232e7c2d33c2521805ce1..2862ed30197d89f7de1a13356729d499baebad98 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -897,4 +897,32 @@ public class JavaAPISuite implements Serializable {
         new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
 
   }
+
+  @Test
+  public void collectPartition() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+      @Override
+      public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+        return new Tuple2<Integer, Integer>(i, i % 2);
+      }
+    });
+
+    Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+                                      new Tuple2<Integer, Integer>(2, 0)),
+                        rdd2.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+                                      new Tuple2<Integer, Integer>(4, 0)),
+                        rdd2.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+                                      new Tuple2<Integer, Integer>(6, 0),
+                                      new Tuple2<Integer, Integer>(7, 1)),
+                        rdd2.collectPartition(2));
+  }
+
 }
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index cbd41e58c4a780392b2a6b8c58320535e416cd36..0604f6836ca42b3b99cd44a19af0e55fc68415a2 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -43,7 +43,6 @@ class SparkContext(object):
     _gateway = None
     _jvm = None
     _writeToFile = None
-    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -134,8 +133,6 @@ class SparkContext(object):
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeToFile = \
                     SparkContext._jvm.PythonRDD.writeToFile
-                SparkContext._takePartition = \
-                    SparkContext._jvm.PythonRDD.takePartition
 
             if instance:
                 if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 61720dcf1af9d4614f2bc7902efb2b2c93af1b40..d81b7c90c1336148835fa41ed220be8f8ac42f8a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -577,7 +577,7 @@ class RDD(object):
         mapped = self.mapPartitions(takeUpToNum)
         items = []
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
+            iterator = mapped._jrdd.collectPartition(partition).iterator()
             items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break