diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 594dbd235fb4b0de38429f299ce38678f0c7f347..8293048caa5d7c43db343b620870cea0f7e96ecd 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -27,16 +27,36 @@ class ShuffledRDD[K, V, C](
 
   override def compute(split: Split): Iterator[(K, C)] = {
     val combiners = new JHashMap[K, C]
-    def mergePair(k: K, c: C) {
-      val oldC = combiners.get(k)
-      if (oldC == null) {
-        combiners.put(k, c)
-      } else {
-        combiners.put(k, aggregator.mergeCombiners(oldC, c))
+    val fetcher = SparkEnv.get.shuffleFetcher
+
+    if (aggregator.mergeCombiners != null) {
+      // If mergeCombiners is specified, combiners are applied on the map
+      // partitions. In this case, post-shuffle we get a list of outputs from
+      // the combiners and merge them using mergeCombiners.
+      def mergePairWithMapSideCombiners(k: K, c: C) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, c)
+        } else {
+          combiners.put(k, aggregator.mergeCombiners(oldC, c))
+        }
+      }
+      fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners)
+    } else {
+      // If mergeCombiners is not specified, no combiner is applied on the map
+      // partitions (i.e. map side aggregation is turned off). Post-shuffle we
+      // get a list of values and we use mergeValue to merge them.
+      def mergePairWithoutMapSideCombiners(k: K, v: V) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, aggregator.createCombiner(v))
+        } else {
+          combiners.put(k, aggregator.mergeValue(oldC, v))
+        }
       }
+      fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners)
     }
-    val fetcher = SparkEnv.get.shuffleFetcher
-    fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
+
     return new Iterator[(K, C)] {
       var iter = combiners.entrySet().iterator()
 
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index e0e050d7c96498915e614349aa1f0a584a8941e0..4828039bbd64596b55102fe304c72546bb033c75 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -104,27 +104,44 @@ class ShuffleMapTask(
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
-    for (elem <- rdd.iterator(split)) {
-      val (k, v) = elem.asInstanceOf[(Any, Any)]
-      var bucketId = partitioner.getPartition(k)
-      val bucket = buckets(bucketId)
-      var existing = bucket.get(k)
-      if (existing == null) {
-        bucket.put(k, aggregator.createCombiner(v))
+
+    val bucketIterators =
+      if (aggregator.mergeCombiners != null) {
+        // Apply combiners (map-side aggregation) to the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+        for (elem <- rdd.iterator(split)) {
+          val (k, v) = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(k)
+          val bucket = buckets(bucketId)
+          val existing = bucket.get(k)
+          if (existing == null) {
+            bucket.put(k, aggregator.createCombiner(v))
+          } else {
+            bucket.put(k, aggregator.mergeValue(existing, v))
+          }
+        }
+        buckets.map(_.iterator)
       } else {
-        bucket.put(k, aggregator.mergeValue(existing, v))
+        // No combiners (no map-side aggregation). Simply partition the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new ArrayBuffer[(Any, Any)])
+        for (elem <- rdd.iterator(split)) {
+          val pair = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(pair._1)
+          buckets(bucketId) += pair
+        }
+        buckets.map(_.iterator)
       }
-    }
+
     val ser = SparkEnv.get.serializer.newInstance()
     val blockManager = SparkEnv.get.blockManager
     for (i <- 0 until numOutputSplits) {
       val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
       // Get a scala iterator from java map
-      val iter: Iterator[(Any, Any)] = buckets(i).iterator
+      val iter: Iterator[(Any, Any)] = bucketIterators(i)
       // TODO: This should probably be DISK_ONLY
       blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
     }
+
     return SparkEnv.get.blockManager.blockManagerId
   }
 
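
Note (illustration only, not part of the patch): the two branches added to ShuffledRDD.compute differ only in what the shuffle fetch delivers. When aggregator.mergeCombiners is non-null, map-side combine has already run, so the reduce side receives combiners of type C and folds them with mergeCombiners; when it is null, the reduce side receives raw values of type V and builds combiners itself with createCombiner and mergeValue. The standalone Scala sketch below mirrors those two merge paths; SimpleAggregator and MergePathsDemo are hypothetical names used only here and stand in for Spark's Aggregator.

    import java.util.{HashMap => JHashMap}

    // Hypothetical stand-in for Spark's Aggregator; by the convention in the
    // patch, a null mergeCombiners means map-side aggregation is disabled.
    case class SimpleAggregator[K, V, C](
        createCombiner: V => C,
        mergeValue: (C, V) => C,
        mergeCombiners: (C, C) => C)

    object MergePathsDemo {
      // Path taken when mergeCombiners != null: the fetched values are already
      // combiners (C) produced map-side, so they are folded with mergeCombiners.
      def mergeCombinerOutputs[K, V, C](fetched: Seq[(K, C)],
                                        agg: SimpleAggregator[K, V, C]): JHashMap[K, C] = {
        val combiners = new JHashMap[K, C]
        for ((k, c) <- fetched) {
          val oldC = combiners.get(k)
          if (oldC == null) combiners.put(k, c)
          else combiners.put(k, agg.mergeCombiners(oldC, c))
        }
        combiners
      }

      // Path taken when mergeCombiners == null: the fetched values are raw (V),
      // so combiners are built entirely on the reduce side.
      def mergeRawValues[K, V, C](fetched: Seq[(K, V)],
                                  agg: SimpleAggregator[K, V, C]): JHashMap[K, C] = {
        val combiners = new JHashMap[K, C]
        for ((k, v) <- fetched) {
          val oldC = combiners.get(k)
          if (oldC == null) combiners.put(k, agg.createCombiner(v))
          else combiners.put(k, agg.mergeValue(oldC, v))
        }
        combiners
      }

      def main(args: Array[String]): Unit = {
        val sum = SimpleAggregator[String, Int, Int](v => v, _ + _, _ + _)
        val fetched = Seq("a" -> 1, "b" -> 2, "a" -> 3)
        println(mergeRawValues(fetched, sum))        // {a=4, b=2}
        println(mergeCombinerOutputs(fetched, sum))  // {a=4, b=2}
      }
    }

Either path yields the same per-key result; the ShuffleMapTask side of the patch makes the matching choice, hashing into per-bucket HashMaps when combiners run map-side and simply appending raw pairs to per-bucket ArrayBuffers when they do not.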