diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 144f679a594606ff8f2c8c611456aba84360a3f5..6fdfdb734d1b8b3fb6e0e28c3d28be8a8d2c2565 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -75,4 +75,27 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) } + /** + * Returns an RDD containing only the elements in the the inclusive range `lower` to `upper`. + * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be + * performed efficiently by only scanning the partitions that might contain matching elements. + * Otherwise, a standard `filter` is applied to all partitions. + */ + def filterByRange(lower: K, upper: K): RDD[P] = { + + def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) + + val rddToFilter: RDD[P] = self.partitioner match { + case Some(rp: RangePartitioner[K, V]) => { + val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { + case (l, u) => Math.min(l, u) to Math.max(l, u) + } + PartitionPruningRDD.create(self, partitionIndicies.contains) + } + case _ => + self + } + rddToFilter.filter { case (k, v) => inRange(k) } + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index a40f2ffeffdf97f8bd0606c0e2f28dd9717f4a2e..64b1c24c47168c8147b0560ad1fc31912efa0fd1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -119,5 +119,33 @@ class SortingSuite extends FunSuite with SharedSparkContext with Matchers with L partitions(1).last should be > partitions(2).head partitions(2).last should be > partitions(3).head } + + test("get a range of elements in a sorted RDD that is on one partition") { + val pairArr = (1 to 1000).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey() + val range = sorted.filterByRange(20, 40).collect() + assert((20 to 40).toArray === range.map(_._1)) + } + + test("get a range of elements over multiple partitions in a descendingly sorted RDD") { + val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val range = sorted.filterByRange(200, 800).collect() + assert((800 to 200 by -1).toArray === range.map(_._1)) + } + + test("get a range of elements in an array not partitioned by a range partitioner") { + val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairs = sc.parallelize(pairArr,10) + val range = pairs.filterByRange(200, 800).collect() + assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) + } + + test("get a range of elements over multiple partitions but not taking up full partitions") { + val pairArr = (1000 to 1 by -1).map(x => (x, x)).toArray + val sorted = sc.parallelize(pairArr, 10).sortByKey(false) + val range = sorted.filterByRange(250, 850).collect() + assert((850 to 250 by -1).toArray === range.map(_._1)) + } }