diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index feeb6c02caa7836114ecd7045ebe2b07733b71f9..880f61c49726ebc2ae7b27670facffce45a99b99 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -758,6 +758,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
     rdd.saveAsHadoopDataset(conf)
   }
 
+  /**
+   * Repartition the RDD according to the given partitioner and, within each resulting partition,
+   * sort records by their keys.
+   *
+   * This is more efficient than calling `repartition` and then sorting within each partition
+   * because it can push the sorting down into the shuffle machinery.
+   */
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = {
+    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
+    repartitionAndSortWithinPartitions(partitioner, comp)
+  }
+
+  /**
+   * Repartition the RDD according to the given partitioner and, within each resulting partition,
+   * sort records by their keys.
+   *
+   * This is more efficient than calling `repartition` and then sorting within each partition
+   * because it can push the sorting down into the shuffle machinery.
+   */
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K])
+    : JavaPairRDD[K, V] = {
+    implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering.
+    fromRDD(
+      new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner))
+  }
+
   /**
    * Sort the RDD by key, so that each partition contains a sorted range of the elements in
    * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
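A note on the two overloads above: the no-comparator variant relies on Guava's `Ordering.natural()`, so it only behaves sensibly when `K` is `Comparable`; the two-argument variant accepts any `Comparator`. A minimal sketch of the comparator overload in use, with illustrative names (`sortDescending`, `pairs`) that are not part of the patch:

```scala
// Sketch, not part of the patch: exercising the comparator overload of the
// new Java API with a reversed natural order.
import java.util.Comparator
import org.apache.spark.HashPartitioner
import org.apache.spark.api.java.JavaPairRDD

def sortDescending(pairs: JavaPairRDD[Integer, String]): JavaPairRDD[Integer, String] = {
  // Collections.reverseOrder() inverts natural ordering, so every partition
  // comes back sorted by key in descending order.
  val desc: Comparator[Integer] = java.util.Collections.reverseOrder[Integer]()
  pairs.repartitionAndSortWithinPartitions(new HashPartitioner(4), desc)
}
```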
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index e98bad2026e32637d90d0cee35d16bbcb62accc8..d0dbfef35d03c9d3c63455a9535a5db9220ba442 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Logging, RangePartitioner}
+import org.apache.spark.{Logging, Partitioner, RangePartitioner}
 import org.apache.spark.annotation.DeveloperApi
 
 /**
@@ -64,4 +64,16 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
     new ShuffledRDD[K, V, V](self, part)
       .setKeyOrdering(if (ascending) ordering else ordering.reverse)
   }
+
+  /**
+   * Repartition the RDD according to the given partitioner and, within each resulting partition,
+   * sort records by their keys.
+   *
+   * This is more efficient than calling `repartition` and then sorting within each partition
+   * because it can push the sorting down into the shuffle machinery.
+   */
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
+    new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
+  }
+
 }
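To make the scaladoc's efficiency claim concrete: `setKeyOrdering` hands the ordering to the shuffle, which can sort records while merging map outputs instead of buffering a finished partition and sorting it in a second pass. A rough sketch of that two-pass alternative, assuming the Spark 1.x pair-RDD implicits from `SparkContext._` (the helper name is hypothetical):

```scala
// Sketch, not part of the patch: the naive two-pass equivalent that the new
// method improves on. partitionBy shuffles first; sorting is then a separate
// pass that must materialize each partition in memory.
import scala.reflect.ClassTag
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

def partitionThenSort[K : Ordering : ClassTag, V : ClassTag](
    rdd: RDD[(K, V)], numPartitions: Int): RDD[(K, V)] = {
  rdd.partitionBy(new HashPartitioner(numPartitions))
    .mapPartitions(iter => iter.toArray.sortBy(_._1).iterator, preservesPartitioning = true)
}
```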
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e1c13de04a0be12fb753d2ff9ab5c9e2f449103f..be99dc501c4b28189b4654a79d2fbf9dc35799be 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -189,6 +189,36 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(new Tuple2<Integer, Integer>(3, 2), sortedPairs.get(2));
   }
 
+  @Test
+  public void repartitionAndSortWithinPartitions() {
+    List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+    pairs.add(new Tuple2<Integer, Integer>(0, 5));
+    pairs.add(new Tuple2<Integer, Integer>(3, 8));
+    pairs.add(new Tuple2<Integer, Integer>(2, 6));
+    pairs.add(new Tuple2<Integer, Integer>(0, 8));
+    pairs.add(new Tuple2<Integer, Integer>(3, 8));
+    pairs.add(new Tuple2<Integer, Integer>(1, 3));
+
+    JavaPairRDD<Integer, Integer> rdd = sc.parallelizePairs(pairs);
+
+    Partitioner partitioner = new Partitioner() {
+      public int numPartitions() {
+        return 2;
+      }
+      public int getPartition(Object key) {
+        return ((Integer)key).intValue() % 2;
+      }
+    };
+
+    JavaPairRDD<Integer, Integer> repartitioned =
+        rdd.repartitionAndSortWithinPartitions(partitioner);
+    List<List<Tuple2<Integer, Integer>>> partitions = repartitioned.glom().collect();
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(0, 5),
+        new Tuple2<Integer, Integer>(0, 8), new Tuple2<Integer, Integer>(2, 6)), partitions.get(0));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 3),
+        new Tuple2<Integer, Integer>(3, 8), new Tuple2<Integer, Integer>(3, 8)), partitions.get(1));
+  }
+
   @Test
   public void emptyRDD() {
     JavaRDD<String> rdd = sc.emptyRDD();
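One detail about the anonymous partitioner in this test: Spark expects `getPartition` to return a value in `[0, numPartitions)`, and Java's `%` is negative for negative operands. The test's keys are all non-negative, so `key % 2` is safe here; a reusable version would guard the sign, as in this sketch (class name is illustrative):

```scala
// Sketch, not part of the patch: the test's mod partitioner as a reusable
// class that stays non-negative for any Int key.
import org.apache.spark.Partitioner

class ModPartitioner(override val numPartitions: Int) extends Partitioner {
  def getPartition(key: Any): Int = {
    val mod = key.asInstanceOf[Int] % numPartitions
    if (mod < 0) mod + numPartitions else mod  // map negative remainders back into range
  }
}
```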
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 499dcda3dae8fd3b8236f19e3ddc3af91a4186b4..c1b501a75c8b8e9f6c338712e3baec6d32063acc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -682,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
   }
 
+  test("repartitionAndSortWithinPartitions") {
+    val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2)
+
+    val partitioner = new Partitioner {
+      def numPartitions: Int = 2
+      def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
+    }
+
+    val repartitioned = data.repartitionAndSortWithinPartitions(partitioner)
+    val partitions = repartitioned.glom().collect()
+    assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
+    assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
+  }
+
   test("intersection") {
     val all = sc.parallelize(1 to 10)
     val evens = sc.parallelize(2 to 10 by 2)
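The Scala test mirrors the Java one. As a usage note, a common reason to reach for this primitive is a Hadoop-style secondary sort: partition on part of a composite key while the shuffle orders records by the whole key. A hedged sketch with illustrative names, again assuming the Spark 1.x implicits from `SparkContext._`:

```scala
// Sketch, not part of the patch: secondary sort on a composite (user, time)
// key. Records are partitioned by user alone, but each partition arrives
// already sorted by (user, time) thanks to the implicit tuple Ordering.
import org.apache.spark.Partitioner
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD

class UserPartitioner(override val numPartitions: Int) extends Partitioner {
  def getPartition(key: Any): Int = {
    val user = key.asInstanceOf[(String, Long)]._1
    val mod = user.hashCode % numPartitions
    if (mod < 0) mod + numPartitions else mod
  }
}

def eventsSortedPerUser(events: RDD[((String, Long), String)]): RDD[((String, Long), String)] =
  events.repartitionAndSortWithinPartitions(new UserPartitioner(8))
```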
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 266090e3ae8f3ce56dc4d5c3f43f133c311b7b53..5667154cb84a8465536d0d12f878c27c1826170c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -520,6 +520,30 @@ class RDD(object):
             raise TypeError
         return self.union(other)
 
+    def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
+                                           ascending=True, keyfunc=lambda x: x):
+        """
+        Repartition the RDD according to the given partitioner and, within each resulting partition,
+        sort records by their keys.
+
+        >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
+        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
+        >>> rdd2.glom().collect()
+        [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
+        """
+        if numPartitions is None:
+            numPartitions = self._defaultReducePartitions()
+
+        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true")
+        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        serializer = self._jrdd_deserializer
+
+        def sortPartition(iterator):
+            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+            return iter(sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending)))
+
+        return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
+
     def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         """
         Sorts this RDD, which is assumed to consist of (key, value) pairs.
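Unlike the Scala and Java paths, the Python version cannot attach an ordering to the JVM shuffle: it partitions with `partitionBy` and then sorts each partition inside the Python worker, spilling through `ExternalSorter` when `spark.shuffle.spill` is enabled. Its `keyfunc`/`ascending` parameters amount to deriving a key ordering, roughly what this hypothetical Scala helper expresses:

```scala
// Sketch, not part of the patch: the JVM-side analogue of the Python
// keyfunc/ascending parameters, expressed as an Ordering over keys.
def orderingFromKeyfunc[K, S](keyfunc: K => S, ascending: Boolean)
    (implicit ord: Ordering[S]): Ordering[K] = {
  val byKey: Ordering[K] = Ordering.by(keyfunc)  // compare keys via keyfunc(k)
  if (ascending) byKey else byKey.reverse
}
```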
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 9fbeb36f4f1dd41854159111db8a0c7b00ba6bbc..0bd2a9e6c507d136999c7a591bd2a3d972cedca7 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -545,6 +545,14 @@ class TestRDDFunctions(PySparkTestCase):
         self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
+    def test_repartitionAndSortWithinPartitions(self):
+        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
+
+        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
+        partitions = repartitioned.glom().collect()
+        self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
+        self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])
+
 
 class TestSQL(PySparkTestCase):