From 1ef4f0fbd27e54803f14fed1df541fb341daced8 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Wed, 26 Sep 2012 19:18:47 -0700
Subject: [PATCH] Allow controlling number of splits in sortByKey.

---
 .../main/scala/spark/PairRDDFunctions.scala   |  4 +-
 core/src/main/scala/spark/ShuffledRDD.scala   |  9 ++--
 .../scala/spark/deploy/client/Client.scala    |  1 -
 core/src/test/scala/spark/SortingSuite.scala  | 48 +++++++++++++++++--
 4 files changed, 50 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala
index aa1d00c63c..4752bf8d9f 100644
--- a/core/src/main/scala/spark/PairRDDFunctions.scala
+++ b/core/src/main/scala/spark/PairRDDFunctions.scala
@@ -435,8 +435,8 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
   extends Logging
   with Serializable {
 
-  def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
-    new ShuffledSortedRDD(self, ascending)
+  def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
+    new ShuffledSortedRDD(self, ascending, numSplits)
   }
 }
 
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index be75890a40..7c11925f86 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -16,7 +16,7 @@ class ShuffledRDDSplit(val idx: Int) extends Split {
 abstract class ShuffledRDD[K, V, C](
     @transient parent: RDD[(K, V)],
     aggregator: Aggregator[K, V, C],
-    part : Partitioner)
+    part: Partitioner)
   extends RDD[(K, C)](parent.context) {
 
   override val partitioner = Some(part)
@@ -38,7 +38,7 @@ abstract class ShuffledRDD[K, V, C](
  */
 class RepartitionShuffledRDD[K, V](
     @transient parent: RDD[(K, V)],
-    part : Partitioner)
+    part: Partitioner)
   extends ShuffledRDD[K, V, V](
     parent,
     Aggregator[K, V, V](null, null, null, false),
@@ -60,10 +60,11 @@ class RepartitionShuffledRDD[K, V](
  */
 class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
     @transient parent: RDD[(K, V)],
-    ascending: Boolean)
+    ascending: Boolean,
+    numSplits: Int)
   extends RepartitionShuffledRDD[K, V](
     parent,
-    new RangePartitioner(parent.splits.size, parent, ascending)) {
+    new RangePartitioner(numSplits, parent, ascending)) {
 
   override def compute(split: Split): Iterator[(K, V)] = {
     // By separating this from RepartitionShuffledRDD, we avoided a
diff --git a/core/src/main/scala/spark/deploy/client/Client.scala b/core/src/main/scala/spark/deploy/client/Client.scala
index c7fa8a3874..a2f88fc5e5 100644
--- a/core/src/main/scala/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/spark/deploy/client/Client.scala
@@ -42,7 +42,6 @@ class Client(
       val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
       try {
         master = context.actorFor(akkaUrl)
-        //master ! RegisterWorker(ip, port, cores, memory)
         master ! RegisterJob(jobDescription)
         context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
         context.watch(master)  // Doesn't work with remote actors, but useful for testing
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 188a9b564e..c87595ecb3 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -17,7 +17,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
   
   test("sortByKey") {
     sc = new SparkContext("local", "test")
-    val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)))
+    val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
     assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))      
   }
 
@@ -25,18 +25,56 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
     val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr)
-    assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey()
+    assert(sorted.splits.size === 2)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
   }
 
+  test("large array with one split") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey(true, 1)
+    assert(sorted.splits.size === 1)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
+  }
+  
+  test("large array with many splits") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    val sorted = pairs.sortByKey(true, 20)
+    assert(sorted.splits.size === 20)
+    assert(sorted.collect() === pairArr.sortBy(_._1))
+  }
+  
   test("sort descending") {
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
     val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
-    val pairs = sc.parallelize(pairArr)
+    val pairs = sc.parallelize(pairArr, 2)
     assert(pairs.sortByKey(false).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
   }
 
+  test("sort descending with one split") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 1)
+    assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+  }
+  
+  test("sort descending with many splits") {
+    sc = new SparkContext("local", "test")
+    val rand = new scala.util.Random()
+    val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
+    val pairs = sc.parallelize(pairArr, 2)
+    assert(pairs.sortByKey(false, 20).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
+  }
+
   test("more partitions than elements") {
     sc = new SparkContext("local", "test")
     val rand = new scala.util.Random()
@@ -48,7 +86,7 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
   test("empty RDD") {
     sc = new SparkContext("local", "test")
     val pairArr = new Array[(Int, Int)](0)
-    val pairs = sc.parallelize(pairArr)
+    val pairs = sc.parallelize(pairArr, 2)
     assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
   }
 
-- 
GitLab