From c2b7fd68996f9a91b4409007192badf012dd8f86 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Wed, 2 Nov 2011 15:16:02 -0700
Subject: [PATCH] Make parallelize() work efficiently for ranges of Long,
 Double, etc (splitting them into sub-ranges). Fixes #87.

---
 .../main/scala/spark/ParallelCollection.scala | 23 ++++++++++---
 .../spark/ParallelCollectionSplitSuite.scala  | 34 +++++++++++++++++++
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index b45f29091b..e96f73b3cf 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -1,6 +1,7 @@
 package spark
 
-import java.util.concurrent.atomic.AtomicLong
+import scala.collection.immutable.NumericRange
+import scala.collection.mutable.ArrayBuffer
 
 class ParallelCollectionSplit[T: ClassManifest](
     val rddId: Long, val slice: Int, values: Seq[T])
@@ -40,23 +41,35 @@ extends RDD[T](sc) {
 }
 
 private object ParallelCollection {
+  // Slice a collection into numSlices sub-collections. One extra thing we do here is
+  // to treat Range collections specially, encoding the slices as other Ranges to
+  // minimize memory cost. This makes it efficient to run Spark over RDDs representing
+  // large sets of numbers.
   def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
     if (numSlices < 1)
       throw new IllegalArgumentException("Positive number of slices required")
     seq match {
       case r: Range.Inclusive => {
         val sign = if (r.step < 0) -1 else 1
-        slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]],
-              numSlices)
+        slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
       }
       case r: Range => {
         (0 until numSlices).map(i => {
           val start = ((i * r.length.toLong) / numSlices).toInt
           val end = (((i+1) * r.length.toLong) / numSlices).toInt
-          new Range(
-            r.start + start * r.step, r.start + end * r.step, r.step)
+          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
         }).asInstanceOf[Seq[Seq[T]]]
       }
+      case nr: NumericRange[_] => { // For ranges of Long, Double, BigInt, etc
+        val slices = new ArrayBuffer[Seq[T]](numSlices)
+        val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything
+        var r = nr
+        for (i <- 0 until numSlices) {
+          slices += r.take(sliceSize).asInstanceOf[Seq[T]]
+          r = r.drop(sliceSize)
+        }
+        slices
+      }
       case _ => {
         val array = seq.toArray  // To prevent O(n^2) operations for List etc
         (0 until numSlices).map(i => {
diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
index af6ec8bae5..450c69bd58 100644
--- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
+++ b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala
@@ -1,5 +1,7 @@
 package spark
 
+import scala.collection.immutable.NumericRange
+
 import org.scalatest.FunSuite
 import org.scalatest.prop.Checkers
 import org.scalacheck.Arbitrary._
@@ -158,4 +160,36 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers {
     }
     check(prop)
   }
+  
+  test("exclusive ranges of longs") {
+    val data = 1L until 100L
+    val slices = ParallelCollection.slice(data, 3)
+    assert(slices.size === 3)
+    assert(slices.map(_.size).reduceLeft(_+_) === 99)
+    assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+  }
+  
+  test("inclusive ranges of longs") {
+    val data = 1L to 100L
+    val slices = ParallelCollection.slice(data, 3)
+    assert(slices.size === 3)
+    assert(slices.map(_.size).reduceLeft(_+_) === 100)
+    assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+  }
+  
+  test("exclusive ranges of doubles") {
+    val data = 1.0 until 100.0 by 1.0
+    val slices = ParallelCollection.slice(data, 3)
+    assert(slices.size === 3)
+    assert(slices.map(_.size).reduceLeft(_+_) === 99)
+    assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+  }
+  
+  test("inclusive ranges of doubles") {
+    val data = 1.0 to 100.0 by 1.0
+    val slices = ParallelCollection.slice(data, 3)
+    assert(slices.size === 3)
+    assert(slices.map(_.size).reduceLeft(_+_) === 100)
+    assert(slices.forall(_.isInstanceOf[NumericRange[_]]))
+  }
 }
-- 
GitLab