diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7dfa17f68a943749d7a6098ca6a4b70fc91d0b91..3325b65f8b60039bc6c60701a228665bec672219 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -608,7 +608,7 @@ class RDD(object):
         sort records by their keys.
 
         >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
-        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2)
+        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
         >>> rdd2.glom().collect()
         [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
         """
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bb13de563cdd4d167a367d62716b2bc11afe5d9d..73ab442dfd791f2ae7f3de78adba48b2a3cb7393 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1019,14 +1019,22 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
         self.assertRaises(TypeError, lambda: rdd.histogram(2))
 
-    def test_repartitionAndSortWithinPartitions(self):
+    def test_repartitionAndSortWithinPartitions_asc(self):
         rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
 
-        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
+        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True)
         partitions = repartitioned.glom().collect()
         self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
         self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
 
+    def test_repartitionAndSortWithinPartitions_desc(self):
+        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
+
+        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
+        partitions = repartitioned.glom().collect()
+        self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)])
+        self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)])
+
     def test_repartition_no_skewed(self):
         num_partitions = 20
         a = self.sc.parallelize(range(int(1000)), 2)