From c2b7fd68996f9a91b4409007192badf012dd8f86 Mon Sep 17 00:00:00 2001 From: Matei Zaharia <matei@eecs.berkeley.edu> Date: Wed, 2 Nov 2011 15:16:02 -0700 Subject: [PATCH] Make parallelize() work efficiently for ranges of Long, Double, etc (splitting them into sub-ranges). Fixes #87. --- .../main/scala/spark/ParallelCollection.scala | 23 ++++++++++--- .../spark/ParallelCollectionSplitSuite.scala | 34 +++++++++++++++++++ 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index b45f29091b..e96f73b3cf 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -1,6 +1,7 @@ package spark -import java.util.concurrent.atomic.AtomicLong +import scala.collection.immutable.NumericRange +import scala.collection.mutable.ArrayBuffer class ParallelCollectionSplit[T: ClassManifest]( val rddId: Long, val slice: Int, values: Seq[T]) @@ -40,23 +41,35 @@ extends RDD[T](sc) { } private object ParallelCollection { + // Slice a collection into numSlices sub-collections. One extra thing we do here is + // to treat Range collections specially, encoding the slices as other Ranges to + // minimize memory cost. This makes it efficient to run Spark over RDDs representing + // large sets of numbers. def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) throw new IllegalArgumentException("Positive number of slices required") seq match { case r: Range.Inclusive => { val sign = if (r.step < 0) -1 else 1 - slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], - numSlices) + slice(new Range(r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices) } case r: Range => { (0 until numSlices).map(i => { val start = ((i * r.length.toLong) / numSlices).toInt val end = (((i+1) * r.length.toLong) / numSlices).toInt - new Range( - r.start + start * r.step, r.start + end * r.step, r.step) + new Range(r.start + start * r.step, r.start + end * r.step, r.step) }).asInstanceOf[Seq[Seq[T]]] } + case nr: NumericRange[_] => { // For ranges of Long, Double, BigInteger, etc + val slices = new ArrayBuffer[Seq[T]](numSlices) + val sliceSize = (nr.size + numSlices - 1) / numSlices // Round up to catch everything + var r = nr + for (i <- 0 until numSlices) { + slices += r.take(sliceSize).asInstanceOf[Seq[T]] + r = r.drop(sliceSize) + } + slices + } case _ => { val array = seq.toArray // To prevent O(n^2) operations for List etc (0 until numSlices).map(i => { diff --git a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala index af6ec8bae5..450c69bd58 100644 --- a/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/spark/ParallelCollectionSplitSuite.scala @@ -1,5 +1,7 @@ package spark +import scala.collection.immutable.NumericRange + import org.scalatest.FunSuite import org.scalatest.prop.Checkers import org.scalacheck.Arbitrary._ @@ -158,4 +160,36 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } check(prop) } + + test("exclusive ranges of longs") { + val data = 1L until 100L + val slices = ParallelCollection.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("inclusive ranges of longs") { + val data = 1L to 100L + val slices = ParallelCollection.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("exclusive ranges of doubles") { + val data = 1.0 until 100.0 by 1.0 + val slices = ParallelCollection.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 99) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } + + test("inclusive ranges of doubles") { + val data = 1.0 to 100.0 by 1.0 + val slices = ParallelCollection.slice(data, 3) + assert(slices.size === 3) + assert(slices.map(_.size).reduceLeft(_+_) === 100) + assert(slices.forall(_.isInstanceOf[NumericRange[_]])) + } } -- GitLab