Commit 1ef4f0fb authored by Matei Zaharia

Allow controlling number of splits in sortByKey.

parent 874a9fd4
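
For illustration, a minimal sketch of the new call, in the style of the tests below. The local SparkContext and the particular split counts are assumptions for the sketch, not part of this commit:

    // Sort 1000 random pairs, shuffling into 4 output splits rather than
    // inheriting the input RDD's 2. Both parameters have defaults, so
    // sortByKey() alone still sorts ascending with the parent's split count.
    val sc = new SparkContext("local", "test")
    val rand = new scala.util.Random()
    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
    val pairs = sc.parallelize(pairArr, 2)
    val sorted = pairs.sortByKey(true, 4)
    assert(sorted.splits.size == 4)
    assert(sorted.collect().toSeq == pairArr.sortBy(_._1).toSeq)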
@@ -435,8 +435,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
   extends Logging
   with Serializable {
 
-  def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
-    new ShuffledSortedRDD(self, ascending)
+  def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+    new ShuffledSortedRDD(self, ascending, numSplits)
   }
 }
@@ -16,7 +16,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
 abstract class ShuffledRDD[K, V, C](
     @transient parent: RDD[(K, V)],
     aggregator: Aggregator[K, V, C],
-    part : Partitioner)
+    part: Partitioner)
   extends RDD[(K, C)](parent.context) {
 
   override val partitioner = Some(part)
@@ -38,7 +38,7 @@ abstract class ShuffledRDD[K, V, C](
  */
 class RepartitionShuffledRDD[K, V](
     @transient parent: RDD[(K, V)],
-    part : Partitioner)
+    part: Partitioner)
   extends ShuffledRDD[K, V, V](
     parent,
     Aggregator[K, V, V](null, null, null, false),
@@ -60,10 +60,11 @@ class RepartitionShuffledRDD[K, V](
  */
 class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
     @transient parent: RDD[(K, V)],
-    ascending: Boolean)
+    ascending: Boolean,
+    numSplits: Int)
   extends RepartitionShuffledRDD[K, V](
     parent,
-    new RangePartitioner(parent.splits.size, parent, ascending)) {
+    new RangePartitioner(numSplits, parent, ascending)) {
 
   override def compute(split: Split): Iterator[(K, V)] = {
     // By separating this from RepartitionShuffledRDD, we avoided a
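
The output split count is thus decoupled from the parent's: numSplits flows straight into the RangePartitioner, which samples the parent RDD to choose its range boundaries. A sketch of what pairs.sortByKey(true, 3) expands to, reusing pairs from the sketch above (direct construction shown only for illustration):

    // Range-partition by key into 3 ordered splits, then sort each split
    // locally in compute(); this is what sortByKey(true, 3) builds above.
    val sorted3 = new ShuffledSortedRDD(pairs, true, 3)
    assert(sorted3.splits.size == 3)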
@@ -42,7 +42,6 @@ class Client(
       val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
       try {
         master = context.actorFor(akkaUrl)
-        //master ! RegisterWorker(ip, port, cores, memory)
         master ! RegisterJob(jobDescription)
         context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
         context.watch(master)  // Doesn't work with remote actors, but useful for testing
@@ -17,7 +17,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
   test("sortByKey") {
     sc = new SparkContext("local", "test")
-    val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
+    val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
     assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
   }
@@ -25,18 +25,56 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
     val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr)
-    assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey()
+    assert(sorted.splits.size === 2)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
   }
 
+  test("large array with one split") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey(true, 1)
+    assert(sorted.splits.size === 1)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
+  }
+
+  test("large array with many splits") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey(true, 20)
+    assert(sorted.splits.size === 20)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
+  }
+
   test("sort descending") {
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
     val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr)
+    val pairs = sc.parallelize(pairArr, 2)
     assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
   }
 
+  test("sort descending with one split") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 1)
+    assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+  }
+
+  test("sort descending with many splits") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+  }
+
   test("more partitions than elements") {
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
@@ -48,7 +86,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
   test("empty RDD") {
     sc = new SparkContext("local", "test")
     val pairArr = new Array[(Int, Int)](0)
-    val pairs = sc.parallelize(pairArr)
+    val pairs = sc.parallelize(pairArr, 2)
     assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
   }