diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 84e15fc0c818a0ccfcce646b114cfac1e29fe935..1a2ec55876c35089d3b448a078d2212cafda08f4 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -33,28 +33,26 @@ case class Aggregator[K, V, C] (
 
   def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
     val combiners = new AppendOnlyMap[K, C]
-    for ((k, v) <- iter) {
-      combiners.changeValue(k, (hadValue, oldValue) => {
-        if (hadValue) {
-          mergeValue(oldValue, v)
-        } else {
-          createCombiner(v)
-        }
-      })
+    var kv: Product2[K, V] = null
+    val update = (hadValue: Boolean, oldValue: C) => {
+      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+    }
+    while (iter.hasNext) {
+      kv = iter.next()
+      combiners.changeValue(kv._1, update)
     }
     combiners.iterator
   }
 
   def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
     val combiners = new AppendOnlyMap[K, C]
-    for ((k, c) <- iter) {
-      combiners.changeValue(k, (hadValue, oldValue) => {
-        if (hadValue) {
-          mergeCombiners(oldValue, c)
-        } else {
-          c
-        }
-      })
+    var kc: (K, C) = null
+    val update = (hadValue: Boolean, oldValue: C) => {
+      if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+    }
+    while (iter.hasNext) {
+      kc = iter.next()
+      combiners.changeValue(kc._1, update)
     }
     combiners.iterator
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index f41a023bc12f1db63471e6d80417c8437d1aaf86..d237797aa60375ff5645ae5a7b2b1358f1a79b7e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -106,10 +106,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
 
     // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
     val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
-    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
-      map.changeValue(k, (hadValue, oldValue) => {
-        if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any])
-      })
+    val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
+      if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
+    }
+
+    val getSeq = (k: K) => {
+      map.changeValue(k, update)
     }
 
     val ser = SparkEnv.get.serializerManager.get(serializerClass)
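
The change in both files is the same closure-reuse pattern: the old `for ((k, v) <- iter)` desugars to a `foreach` with tuple pattern matching, and because the update function captures the loop variable `v`, a fresh closure object is allocated for every record in the shuffle hot path. Hoisting a single `update` closure that reads the current record through a mutable `var`, and driving the iterator with an explicit `while` loop, cuts the per-record cost to one `next()` call and one hash update. Below is a minimal, self-contained sketch of the pattern; `SimpleMap` and `sumByKey` are hypothetical stand-ins for illustration, not Spark's actual `AppendOnlyMap`, though `changeValue` mirrors its `(hadValue, oldValue) => newValue` callback shape.

```scala
import scala.collection.mutable

// Hypothetical stand-in for AppendOnlyMap: only the changeValue
// callback contract matches the real class.
class SimpleMap[K, V] {
  private val data = mutable.HashMap.empty[K, V]

  // Calls updateFunc(hadValue, oldValue) and stores the result.
  def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    val newValue = data.get(key) match {
      case Some(old) => updateFunc(true, old)
      case None      => updateFunc(false, null.asInstanceOf[V])
    }
    data(key) = newValue
    newValue
  }

  def iterator: Iterator[(K, V)] = data.iterator
}

object ClosureReuseExample {
  def sumByKey(iter: Iterator[(String, Int)]): Iterator[(String, Int)] = {
    val combiners = new SimpleMap[String, Int]
    // One mutable slot plus one closure allocated up front, instead of
    // a fresh closure (and tuple extraction) per record as in the old
    // `for ((k, v) <- iter)` version.
    var kv: (String, Int) = null
    val update = (hadValue: Boolean, oldValue: Int) => {
      if (hadValue) oldValue + kv._2 else kv._2
    }
    while (iter.hasNext) {
      kv = iter.next()
      combiners.changeValue(kv._1, update)
    }
    combiners.iterator
  }

  def main(args: Array[String]): Unit = {
    val input = Iterator(("a", 1), ("b", 2), ("a", 3))
    sumByKey(input).foreach(println) // prints (a,4) and (b,2)
  }
}
```

The same reasoning drives the CoGroupedRDD hunk: `update` never depends on the key, so it can be built once outside `getSeq` rather than re-created on every lookup.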