diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 756e8f35fb03d46ac1cd0b8c65354521912ed5a4..3934bdda0a466dada0c44f6499e3a4a4e510ce38 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,7 @@ from tempfile import NamedTemporaryFile
 from threading import Thread
 import warnings
 import heapq
+import bisect
 from random import Random
 from math import sqrt, log
 
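The new import bisect supports the binary search introduced in rangePartitionFunc further down: instead of scanning the bounds list linearly, bisect.bisect_left finds the target partition index in logarithmic time. A minimal sketch of that lookup, assuming ascending bounds and an identity key function (the values below are made up for illustration):

    >>> import bisect
    >>> bounds = [10, 20, 30]           # split points for 4 partitions
    >>> bisect.bisect_left(bounds, 5)   # below the first bound -> partition 0
    0
    >>> bisect.bisect_left(bounds, 20)  # a key equal to a bound stays left of it
    1
    >>> bisect.bisect_left(bounds, 99)  # past the last bound -> last partition
    3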
@@ -574,6 +575,8 @@ class RDD(object):
         # noqa
 
         >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
+        >>> sc.parallelize(tmp).sortByKey(True, 1).collect()
+        [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> sc.parallelize(tmp).sortByKey(True, 2).collect()
         [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
         >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)]
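The added doctest exercises the new single-partition code path; with one partition the result is simply a local sort of the pairs by key. A quick sanity check in plain Python, assuming ordinary string comparison on the keys:

    >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
    >>> sorted(tmp, key=lambda kv: kv[0])
    [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]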
@@ -584,42 +587,40 @@ class RDD(object):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        bounds = list()
+        if numPartitions == 1:
+            if self.getNumPartitions() > 1:
+                self = self.coalesce(1)
+
+            def sort(iterator):
+                return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+
+            return self.mapPartitions(sort)
 
         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
         # number of (key, value) pairs falling into them
-        if numPartitions > 1:
-            rddSize = self.count()
-            # constant from Spark's RangePartitioner
-            maxSampleSize = numPartitions * 20.0
-            fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
-
-            samples = self.sample(False, fraction, 1).map(
-                lambda (k, v): k).collect()
-            samples = sorted(samples, reverse=(not ascending), key=keyfunc)
-
-            # we have numPartitions many parts but one of the them has
-            # an implicit boundary
-            for i in range(0, numPartitions - 1):
-                index = (len(samples) - 1) * (i + 1) / numPartitions
-                bounds.append(samples[index])
+        rddSize = self.count()
+        maxSampleSize = numPartitions * 20.0  # constant from Spark's RangePartitioner
+        fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
+        samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
+        samples = sorted(samples, reverse=(not ascending), key=keyfunc)
+
+        # we have numPartitions parts, but one of them has
+        # an implicit boundary
+        bounds = [samples[len(samples) * (i + 1) / numPartitions]
+                  for i in range(0, numPartitions - 1)]
 
         def rangePartitionFunc(k):
-            p = 0
-            while p < len(bounds) and keyfunc(k) > bounds[p]:
-                p += 1
+            p = bisect.bisect_left(bounds, keyfunc(k))
             if ascending:
                 return p
             else:
                 return numPartitions - 1 - p
 
         def mapFunc(iterator):
-            yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
+            return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
 
-        return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc)
-                    .mapPartitions(mapFunc, preservesPartitioning=True)
-                    .flatMap(lambda x: x, preservesPartitioning=True))
+        return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
 
     def sortBy(self, keyfunc, ascending=True, numPartitions=None):
         """