From d71a358c4619037053e7f723568282f52c2ad2e0 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Wed, 26 Sep 2012 00:25:34 -0700
Subject: [PATCH] Fixed a test that was getting extremely lucky before, and
 increased the number of samples used for sorting

---
 core/src/main/scala/spark/Partitioner.scala  |  4 ++--
 core/src/main/scala/spark/ShuffledRDD.scala  |  6 +++++-
 core/src/test/scala/spark/SortingSuite.scala | 18 +++++++++---------
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/spark/Partitioner.scala b/core/src/main/scala/spark/Partitioner.scala
index 643541429f..20c31714ae 100644
--- a/core/src/main/scala/spark/Partitioner.scala
+++ b/core/src/main/scala/spark/Partitioner.scala
@@ -41,9 +41,9 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
       Array()
     } else {
       val rddSize = rdd.count()
-      val maxSampleSize = partitions * 10.0
+      val maxSampleSize = partitions * 20.0
       val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
-      val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
+      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
       if (rddSample.length == 0) {
         Array()
       } else {
diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index a7346060b3..be75890a40 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -71,7 +71,11 @@ class ShuffledSortedRDD[K <% Ordered[K]: ClassManifest, V](
     val buf = new ArrayBuffer[(K, V)]
     def addTupleToBuffer(k: K, v: V) = { buf += Tuple(k, v) }
     SparkEnv.get.shuffleFetcher.fetch[K, V](dep.shuffleId, split.index, addTupleToBuffer)
-    buf.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
+    if (ascending) {
+      buf.sortWith((x, y) => x._1 < y._1).iterator
+    } else {
+      buf.sortWith((x, y) => x._1 > y._1).iterator
+    }
   }
 }
 
diff --git a/core/src/test/scala/spark/SortingSuite.scala b/core/src/test/scala/spark/SortingSuite.scala
index 8fa1442a4d..188a9b564e 100644
--- a/core/src/test/scala/spark/SortingSuite.scala
+++ b/core/src/test/scala/spark/SortingSuite.scala
@@ -58,11 +58,11 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
     val sorted = sc.parallelize(pairArr, 4).sortByKey()
     assert(sorted.collect() === pairArr.sortBy(_._1))
     val partitions = sorted.collectPartitions()
-    logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
-    partitions(0).length should be > 200
-    partitions(1).length should be > 200
-    partitions(2).length should be > 200
-    partitions(3).length should be > 200
+    logInfo("Partition lengths: " + partitions.map(_.length).mkString(", "))
+    partitions(0).length should be > 180
+    partitions(1).length should be > 180
+    partitions(2).length should be > 180
+    partitions(3).length should be > 180
     partitions(0).last should be < partitions(1).head
     partitions(1).last should be < partitions(2).head
     partitions(2).last should be < partitions(3).head
@@ -75,10 +75,10 @@ class SortingSuite extends FunSuite with BeforeAndAfter with ShouldMatchers with
     assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
     val partitions = sorted.collectPartitions()
     logInfo("partition lengths: " + partitions.map(_.length).mkString(", "))
-    partitions(0).length should be > 200
-    partitions(1).length should be > 200
-    partitions(2).length should be > 200
-    partitions(3).length should be > 200
+    partitions(0).length should be > 180
+    partitions(1).length should be > 180
+    partitions(2).length should be > 180
+    partitions(3).length should be > 180
     partitions(0).last should be > partitions(1).head
     partitions(1).last should be > partitions(2).head
     partitions(2).last should be > partitions(3).head
--
GitLab
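
Note on the one-flag change in Partitioner.scala: switching rdd.sample(true, frac, 1)
to rdd.sample(false, frac, 1) moves from sampling with replacement to sampling without
replacement. With replacement, the same key can be drawn repeatedly, so the effective
sample of distinct keys shrinks and the derived range bounds can skew, which is why the
old partition-size test only passed by luck. The standalone Scala sketch below
illustrates that difference; it is not Spark code, and the helper names are made up for
illustration only.

    import scala.util.Random

    object SamplingSketch {
      // Hypothetical helper: draws n elements WITH replacement, so duplicates
      // are possible and the set of distinct sampled keys can shrink.
      def sampleWithReplacement(data: IndexedSeq[Int], n: Int, rng: Random): Seq[Int] =
        Seq.fill(n)(data(rng.nextInt(data.length)))

      // Hypothetical helper: draws n elements WITHOUT replacement, so every
      // sampled key is distinct (assuming the input keys are distinct).
      def sampleWithoutReplacement(data: IndexedSeq[Int], n: Int, rng: Random): Seq[Int] =
        rng.shuffle(data).take(n)

      def main(args: Array[String]): Unit = {
        val rng  = new Random(1)
        val data = (1 to 1000).toIndexedSeq  // stand-in for the RDD's keys
        val n    = 4 * 20                    // partitions * 20, as in the patch
        // With replacement, duplicates appear and fewer distinct keys survive.
        println(sampleWithReplacement(data, n, rng).distinct.length)    // typically < 80
        println(sampleWithoutReplacement(data, n, rng).distinct.length) // always 80
      }
    }

Doubling maxSampleSize from partitions * 10 to partitions * 20 further evens out the
sampled bounds, and the test thresholds are relaxed from 200 to 180 so the suite asserts
reasonably balanced partitions without depending on a lucky draw.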