Commit 22b8fcf6 authored by Matei Zaharia

Added fold() and aggregate() operations that reuse an object to merge results
into rather than requiring a new object allocation for each element merged.
Fixes #95.
parent 09dd58b3
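
For context, a usage sketch of the two new operations (illustrative only, not
part of the commit; it assumes an existing SparkContext named sc, like the one
created in the tests below):

    import scala.collection.mutable.ArrayBuffer

    val nums = sc.makeRDD(1 to 100, 4)

    // fold: like reduce, but seeded with a neutral zero value
    val sum = nums.fold(0)(_ + _)

    // aggregate: the result type (ArrayBuffer[Int]) differs from the element
    // type (Int); both functions may mutate and return their first argument,
    // so no new object is allocated per element merged
    val evens = nums.aggregate(new ArrayBuffer[Int])(
      (buf, x) => { if (x % 2 == 0) buf += x; buf },  // merge a T into a U
      (b1, b2) => { b1 ++= b2; b1 }                   // merge two U's
    )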
RDD.scala
@@ -133,7 +133,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     val cleanF = sc.clean(f)
     val reducePartition: Iterator[T] => Option[T] = iter => {
       if (iter.hasNext)
-        Some(iter.reduceLeft(f))
+        Some(iter.reduceLeft(cleanF))
       else
         None
     }
@@ -144,7 +144,36 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
     if (results.size == 0)
       throw new UnsupportedOperationException("empty collection")
     else
-      return results.reduceLeft(f)
+      return results.reduceLeft(cleanF)
   }
+
+  /**
+   * Aggregate the elements of each partition, and then the results for all the
+   * partitions, using a given associative function and a neutral "zero value".
+   * The function op(t1, t2) is allowed to modify t1 and return it as its result
+   * value to avoid object allocation; however, it should not modify t2.
+   */
+  def fold(zeroValue: T)(op: (T, T) => T): T = {
+    val cleanOp = sc.clean(op)
+    val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp))
+    return results.fold(zeroValue)(cleanOp)
+  }
+
+  /**
+   * Aggregate the elements of each partition, and then the results for all the
+   * partitions, using given combine functions and a neutral "zero value". This
+   * function can return a different result type, U, than the type of this RDD, T.
+   * Thus, we need one operation for merging a T into a U and one operation for
+   * merging two U's, as in scala.TraversableOnce. Both of these functions are
+   * allowed to modify and return their first argument instead of creating a new U
+   * to avoid memory allocation.
+   */
+  def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+    val cleanSeqOp = sc.clean(seqOp)
+    val cleanCombOp = sc.clean(combOp)
+    val results = sc.runJob(this,
+      (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
+    return results.fold(zeroValue)(cleanCombOp)
+  }

   def count(): Long = {
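
To see why the zero value must be neutral for the merge function, here is how
the new fold evaluates, sketched on plain Scala collections (a model only, not
code from the commit; `partitions` stands in for an RDD's splits):

    val partitions = Seq(Seq(1, 2), Seq(3, 4))  // stands in for an RDD's splits
    val zero = 0
    // one runJob task per partition, each folding locally from the zero value
    val partials = partitions.map(p => p.fold(zero)(_ + _))
    // the driver then folds the per-partition results, again from the zero value
    val total = partials.fold(zero)(_ + _)
    assert(total == 10)

Because zeroValue is applied once in every partition and once more on the
driver, a non-neutral zero (say, 1 with +) would be counted multiple times.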
RDDSuite.scala
 package spark
+import scala.collection.mutable.HashMap
 import org.scalatest.FunSuite
 import SparkContext._
@@ -9,6 +10,7 @@ class RDDSuite extends FunSuite {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
     assert(nums.reduce(_ + _) === 10)
+    assert(nums.fold(0)(_ + _) === 10)
     assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4"))
     assert(nums.filter(_ > 2).collect().toList === List(3, 4))
     assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
@@ -18,4 +20,26 @@ class RDDSuite extends FunSuite {
     assert(partitionSums.collect().toList === List(3, 7))
     sc.stop()
   }
+
+  test("aggregate") {
+    val sc = new SparkContext("local", "test")
+    val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
+    type StringMap = HashMap[String, Int]
+    val emptyMap = new StringMap {
+      override def default(key: String): Int = 0
+    }
+    val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => {
+      map(pair._1) += pair._2
+      map
+    }
+    val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
+      for ((key, value) <- map2) {
+        map1(key) += value
+      }
+      map1
+    }
+    val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps)
+    assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
+    sc.stop()
+  }
 }
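
One design note on the test: the anonymous subclass overrides HashMap's
default so that map(key) += value works even when the key is absent, which is
what keeps mergeElement and mergeMaps allocation-free. A minimal standalone
sketch of the same trick (not part of the commit):

    import scala.collection.mutable.HashMap

    val counts = new HashMap[String, Int] {
      override def default(key: String): Int = 0  // apply() falls back to 0 for missing keys
    }
    counts("a") += 1  // reads the default 0, then stores 1
    counts("a") += 2
    assert(counts("a") == 3)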