diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index d1e65edcda430dd308c0d5307e708eb54873d18e..ac61fe3b54526da22a0d812a485da167651a686e 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -25,34 +25,49 @@ class HashPartitioner(partitions: Int) extends Partitioner {
   }
 }
 
-class RangePartitioner[K <% Ordered[K],V](partitions: Int, rdd: RDD[(K,V)], ascending: Boolean = true)
+class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
+    partitions: Int, rdd: RDD[(K,V)],
+    ascending: Boolean = true)
   extends Partitioner {
-  def numPartitions = partitions
+  private val rangeBounds: Array[K] = {
+    val rddSize = rdd.count()
+    val maxSampleSize = partitions * 10.0
+    val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
+    val rddSample = rdd.sample(true, frac, 1).map(_._1).collect()
+      .sortWith((x, y) => if (ascending) x < y else x > y)
+    if (rddSample.length == 0) {
+      Array()
+    } else {
+      val bounds = new Array[K](partitions)
+      for (i <- 0 until partitions) {
+        bounds(i) = rddSample(i * rddSample.length / partitions)
+      }
+      bounds
+    }
+  }
+
+  def numPartitions = rangeBounds.length
 
-  val rddSize = rdd.count()
-  val maxSampleSize = partitions*10.0
-  val frac = 1.0.min(maxSampleSize / rddSize)
-  val rddSample = rdd.sample(true, frac, 1).collect.toList
-    .sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1)
-    .map(_._1)
-  val bucketSize = rddSample.size / partitions
-  val rangeBounds = rddSample.zipWithIndex.filter(_._2 % bucketSize == 0)
-    .map(_._1).slice(1, partitions)
-
-  def getPartition(key: Any): Int = {
+  def getPartition(key: Any): Int = {
+    // TODO: Use a binary search here if number of partitions is large
     val k = key.asInstanceOf[K]
-    val p = rangeBounds.zipWithIndex.foldLeft(0) {
-      case (part, (bound, index)) =>
-        if (k > bound) index + 1 else part
-    }
-    if (ascending) p else numPartitions-1-p
+    var partition = 0
+    while (partition < rangeBounds.length - 1 && k > rangeBounds(partition)) {
+      partition += 1
+    }
+    if (ascending) {
+      partition
+    } else {
+      rangeBounds.length - 1 - partition
+    }
   }
 
   override def equals(other: Any): Boolean = other match {
     case r: RangePartitioner[_,_] =>
-      r.numPartitions == numPartitions & r.rangeBounds == rangeBounds
-    case _ => false
+      r.rangeBounds.sameElements(rangeBounds)
+    case _ =>
+      false
   }
 }
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 0d1f229eda32bc38a60aee7ec40cd7881e966091..caff8849661aac1b94b94a7215ed14b25614e17e 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -29,11 +29,20 @@ class SortingSuite extends FunSuite {
     sc.stop()
   }
 
-  test("sortHighParallelism") {
+  test("morePartitionsThanElements") {
     val sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
-    val pairArr = Array.fill(3000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr, 300)
+    val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 30)
+    assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+    sc.stop()
+  }
+
+  test("emptyRDD") {
+    val sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = new Array[(Int, Int)](0)
+    val pairs = sc.parallelize(pairArr)
     assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
     sc.stop()
   }
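
For reference, below is a minimal standalone sketch (plain Scala, no Spark dependency) of the sampling-based range partitioning idea introduced by this patch: take evenly spaced bounds from a sorted sample, then scan for the first bound the key does not exceed. The object and method names (RangePartitionSketch, makeBounds, partitionOf) are hypothetical and used only for illustration; they are not part of the patch.

object RangePartitionSketch {
  // Pick `partitions` evenly spaced bounds from a sorted sample (empty sample -> no bounds).
  def makeBounds(sample: Seq[Int], partitions: Int): Array[Int] = {
    val sorted = sample.sorted
    if (sorted.isEmpty) Array.empty[Int]
    else Array.tabulate(partitions)(i => sorted(i * sorted.length / partitions))
  }

  // Linear scan over the bounds, mirroring the loop in the new getPartition;
  // a binary search would be preferable when there are many partitions.
  def partitionOf(key: Int, bounds: Array[Int]): Int = {
    var p = 0
    while (p < bounds.length - 1 && key > bounds(p)) p += 1
    p
  }

  def main(args: Array[String]): Unit = {
    val sample = Seq(42, 7, 99, 13, 58, 3, 76, 21)
    val bounds = makeBounds(sample, 4)
    println(bounds.mkString(", "))   // evenly spaced bounds from the sorted sample
    println(partitionOf(50, bounds)) // partition index chosen for key 50
  }
}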