diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index d1e65edcda430dd308c0d5307e708eb54873d18e..ac61fe3b54526da22a0d812a485da167651a686e 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -25,34 +25,49 @@ class HashPartitioner(partitions: Int) extends Partitioner {
   }
 }
 
-class RangePartitioner[K <% Ordered[K],V](partitions: Int, rdd: RDD[(K,V)], ascending: Boolean = true) 
+class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
+    partitions: Int, rdd: RDD[(K,V)],
+    ascending: Boolean = true)
   extends Partitioner {
 
-  def numPartitions = partitions
+  private val rangeBounds: Array[K] = {
+    val rddSize = rdd.count()
+    val maxSampleSize = partitions * 10.0
+    val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
+    val rddSample = rdd.sample(true, frac, 1).map(_._1).collect()
+      .sortWith((x, y) => if (ascending) x < y else x > y)
+    if (rddSample.length == 0) {
+      Array()
+    } else {
+      val bounds = new Array[K](partitions)
+      for (i <- 0 until partitions) {
+        bounds(i) = rddSample(i * rddSample.length / partitions)
+      }
+      bounds
+    }
+  }
+
+  def numPartitions = rangeBounds.length
 
-  val rddSize = rdd.count()
-  val maxSampleSize = partitions*10.0
-  val frac = 1.0.min(maxSampleSize / rddSize)
-  val rddSample = rdd.sample(true, frac, 1).collect.toList
-    .sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1)
-    .map(_._1)
-  val bucketSize = rddSample.size / partitions
-  val rangeBounds = rddSample.zipWithIndex.filter(_._2 % bucketSize == 0)
-    .map(_._1).slice(1, partitions)
-
-  def getPartition(key: Any): Int = { 
+  def getPartition(key: Any): Int = {
+    // TODO: Use a binary search here if number of partitions is large
     val k = key.asInstanceOf[K]
-    val p = rangeBounds.zipWithIndex.foldLeft(0) {
-        case (part, (bound, index)) =>  
-          if (k > bound) index + 1 else part
-      }   
-    if (ascending) p else numPartitions-1-p
+    var partition = 0
+    while (partition < rangeBounds.length - 1 && k > rangeBounds(partition)) {
+      partition += 1
+    }
+    if (ascending) {
+      partition
+    } else {
+      rangeBounds.length - 1 - partition
+    }
   }
 
   override def equals(other: Any): Boolean = other match {
     case r: RangePartitioner[_,_] =>
-      r.numPartitions == numPartitions & r.rangeBounds == rangeBounds
-    case _ => false
+      r.rangeBounds.sameElements(rangeBounds)
+    case _ =>
+      false
   }
 }
 
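The TODO comment in the new getPartition points at replacing its linear scan with a binary search once the number of partitions grows large. Below is a minimal sketch of that idea, assuming the same rangeBounds layout as above (a sorted array of sampled keys, where a key k goes to the first partition i with k <= rangeBounds(i), and anything larger than every checked bound lands in the last partition). The findPartition name is made up for illustration and the code is not part of this patch.

def findPartition[K <% Ordered[K]](rangeBounds: Array[K], k: K): Int = {
  // Binary search for the first index whose bound is >= k, clamped to the
  // last partition so that keys above every bound still get a valid index.
  var lo = 0
  var hi = rangeBounds.length - 1
  while (lo < hi) {
    val mid = (lo + hi) / 2
    if (k > rangeBounds(mid)) lo = mid + 1 else hi = mid
  }
  lo
}

As in the linear version, the descending case would then return rangeBounds.length - 1 - findPartition(rangeBounds, k).
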
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 0d1f229eda32bc38a60aee7ec40cd7881e966091..caff8849661aac1b94b94a7215ed14b25614e17e 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -29,11 +29,20 @@ class SortingSuite extends FunSuite {
       sc.stop()
   }
 
-  test("sortHighParallelism") {
+  test("morePartitionsThanElements") {
       val sc = new SparkContext("local", "test")
       val rand = new scala.util.Random()
-      val pairArr = Array.fill(3000) { (rand.nextInt(), rand.nextInt()) }
-      val pairs = sc.parallelize(pairArr, 300)
+      val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
+      val pairs = sc.parallelize(pairArr, 30)
+      assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+      sc.stop()
+  }
+
+  test("emptyRDD") {
+      val sc = new SparkContext("local", "test")
+      val rand = new scala.util.Random()
+      val pairArr = new Array[(Int, Int)](0)
+      val pairs = sc.parallelize(pairArr)
       assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
       sc.stop()
   }
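
Taken together, the two new tests exercise the edge cases the rewritten RangePartitioner now handles: more partitions than elements, and a completely empty RDD. A rough standalone sketch of the same usage, written against the old spark.SparkContext local-mode API (the app name and the literal data here are invented for illustration):

import spark.SparkContext
import spark.SparkContext._

val sc = new SparkContext("local", "example")

// More partitions than elements: the sampled bounds mostly repeat and many
// partitions stay empty, but every key still falls into exactly one range,
// so the sorted output is unchanged.
val few = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")), 10)
assert(few.sortByKey().collect().toSeq == Seq((1, "a"), (2, "b"), (3, "c")))

// Empty RDD: the sample is empty, so rangeBounds (and numPartitions) end up
// empty as well, which sortByKey now tolerates.
val none = sc.parallelize(Array.empty[(Int, Int)])
assert(none.sortByKey().collect().isEmpty)

sc.stop()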