Commit 78ffe164 authored by Matei Zaharia

Clone the zero value for each key in foldByKey

The old version reused the same zero value object for every key within a
task, so a mutable zero value (expected to be common in fold) was
overwritten in place as keys were folded.

Conflicts:

	core/src/test/scala/spark/ShuffleSuite.scala
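
For context, a hypothetical sketch (not part of the commit) of why the old implementation broke with mutable zero values: zeroValue was captured once per task, so every key's combiner started from the same shared instance.

import scala.collection.mutable.ArrayBuffer

// Sketch of the old behavior: zeroValue is captured once, so the
// createCombiner closure mutates one shared ArrayBuffer for every key.
val zeroValue = new ArrayBuffer[Int]
val func = (b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]) => b1 ++= b2
val createCombiner = (v: ArrayBuffer[Int]) => func(zeroValue, v)  // mutates the shared zero
createCombiner(ArrayBuffer(1))  // zeroValue is now ArrayBuffer(1)
createCombiner(ArrayBuffer(2))  // a second key's "fresh" combiner starts as ArrayBuffer(1, 2)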
parent 0e0f9d30
 package spark

+import java.nio.ByteBuffer
 import java.util.{Date, HashMap => JHashMap}
 import java.text.SimpleDateFormat
@@ -64,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
         throw new SparkException("Default partitioner cannot partition array keys.")
       }
     }
-    val aggregator =
-      new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+    val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
       self.mapPartitions(aggregator.combineValuesByKey(_), true)
     } else if (mapSideCombine) {
@@ -97,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
    * list concatenation, 0 for addition, or 1 for multiplication.).
    */
   def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
-    combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
+    // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+    val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+    val zeroArray = new Array[Byte](zeroBuffer.limit)
+    zeroBuffer.get(zeroArray)
+
+    // When deserializing, use a lazy val to create just one instance of the serializer per task
+    lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+    def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
+
+    combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
   }

   /**
...
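
The core trick above, cloning a value by round-tripping it through a serializer, can be sketched standalone. A minimal version (using plain Java serialization instead of Spark's closureSerializer; cloneFactory is an illustrative name, not part of the commit):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Serialize the value once up front, then deserialize a fresh,
// independent copy every time the returned thunk is called.
def cloneFactory[T <: java.io.Serializable](value: T): () => T = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(value)
  oos.close()
  val bytes = bos.toByteArray
  () => {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    ois.readObject().asInstanceOf[T]
  }
}

// Each call yields a distinct instance, so in-place mutation is safe:
val createZero = cloneFactory(new scala.collection.mutable.ArrayBuffer[Int])
assert(!(createZero() eq createZero()))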
@@ -392,6 +392,28 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
     assert(nonEmptyBlocks.size <= 4)
   }

+  test("foldByKey") {
+    sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val sums = pairs.foldByKey(0)(_+_).collect()
+    assert(sums.toSet === Set((1, 7), (2, 1)))
+  }
+
+  test("foldByKey with mutable result type") {
+    sc = new SparkContext("local", "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
+    val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
+    // Fold the values using in-place mutation
+    val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
+    assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
+    // Check that the mutable objects in the original RDD were not changed
+    assert(bufs.collect().toSet === Set(
+      (1, ArrayBuffer(1)),
+      (1, ArrayBuffer(2)),
+      (1, ArrayBuffer(3)),
+      (1, ArrayBuffer(1)),
+      (2, ArrayBuffer(1))))
+  }
 }

 object ShuffleSuite {
...
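
As a usage sketch (hypothetical driver code, mirroring the tests above), the fixed foldByKey hands each key its own clone of the zero value, so folding with in-place mutation cannot leak results across keys:

import scala.collection.mutable.ArrayBuffer
import spark.SparkContext
import spark.SparkContext._  // implicit conversion to PairRDDFunctions

val sc = new SparkContext("local", "example")
val pairs = sc.parallelize(Seq((1, 1), (1, 2), (2, 3)))
// Each key folds into its own clone of the empty zero buffer
val folded = pairs.mapValues(v => ArrayBuffer(v)).foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
// folded.toSet == Set((1, ArrayBuffer(1, 2)), (2, ArrayBuffer(3)))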