From 940869dfdad5c785404e16f63681a96b885c749a Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@cs.berkeley.edu>
Date: Wed, 29 Aug 2012 23:00:02 -0700
Subject: [PATCH] Disable running combiners on map tasks when mergeCombiners
 function is not specified by the user.

---
 core/src/main/scala/spark/ShuffledRDD.scala   | 36 +++++++++++++----
 .../spark/scheduler/ShuffleMapTask.scala      | 39 +++++++++++++------
 2 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/core/src/main/scala/spark/ShuffledRDD.scala b/core/src/main/scala/spark/ShuffledRDD.scala
index 594dbd235f..8293048caa 100644
--- a/core/src/main/scala/spark/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/ShuffledRDD.scala
@@ -27,16 +27,36 @@ class ShuffledRDD[K, V, C](
 
   override def compute(split: Split): Iterator[(K, C)] = {
     val combiners = new JHashMap[K, C]
-    def mergePair(k: K, c: C) {
-      val oldC = combiners.get(k)
-      if (oldC == null) {
-        combiners.put(k, c)
-      } else {
-        combiners.put(k, aggregator.mergeCombiners(oldC, c))
+    val fetcher = SparkEnv.get.shuffleFetcher
+
+    if (aggregator.mergeCombiners != null) {
+      // If mergeCombiners is specified, combiners are applied on the map
+      // partitions. In this case, post-shuffle we get a list of outputs from
+      // the combiners and merge them using mergeCombiners.
+      def mergePairWithMapSideCombiners(k: K, c: C) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, c)
+        } else {
+          combiners.put(k, aggregator.mergeCombiners(oldC, c))
+        }
+      }
+      fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners)
+    } else {
+      // If mergeCombiners is not specified, no combiner is applied on the map
+      // partitions (i.e. map side aggregation is turned off). Post-shuffle we
+      // get a list of values and we use mergeValue to merge them.
+      def mergePairWithoutMapSideCombiners(k: K, v: V) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, aggregator.createCombiner(v))
+        } else {
+          combiners.put(k, aggregator.mergeValue(oldC, v))
+        }
       }
+      fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners)
     }
-    val fetcher = SparkEnv.get.shuffleFetcher
-    fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
+
     return new Iterator[(K, C)] {
       var iter = combiners.entrySet().iterator()
 
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index e0e050d7c9..4828039bbd 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -104,27 +104,44 @@ class ShuffleMapTask(
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
-    for (elem <- rdd.iterator(split)) {
-      val (k, v) = elem.asInstanceOf[(Any, Any)]
-      var bucketId = partitioner.getPartition(k)
-      val bucket = buckets(bucketId)
-      var existing = bucket.get(k)
-      if (existing == null) {
-        bucket.put(k, aggregator.createCombiner(v))
+
+    val bucketIterators =
+      if (aggregator.mergeCombiners != null) {
+        // Apply combiners (map-side aggregation) to the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+        for (elem <- rdd.iterator(split)) {
+          val (k, v) = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(k)
+          val bucket = buckets(bucketId)
+          val existing = bucket.get(k)
+          if (existing == null) {
+            bucket.put(k, aggregator.createCombiner(v))
+          } else {
+            bucket.put(k, aggregator.mergeValue(existing, v))
+          }
+        }
+        buckets.map(_.iterator)
       } else {
-        bucket.put(k, aggregator.mergeValue(existing, v))
+        // No combiners (no map-side aggregation). Simply partition the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new ArrayBuffer[(Any, Any)])
+        for (elem <- rdd.iterator(split)) {
+          val pair = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(pair._1)
+          buckets(bucketId) += pair
+        }
+        buckets.map(_.iterator)
       }
-    }
+
     val ser = SparkEnv.get.serializer.newInstance()
     val blockManager = SparkEnv.get.blockManager
     for (i <- 0 until numOutputSplits) {
       val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
       // Get a scala iterator from java map
-      val iter: Iterator[(Any, Any)] = buckets(i).iterator
+      val iter: Iterator[(Any, Any)] = bucketIterators(i)
       // TODO: This should probably be DISK_ONLY
       blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
     }
+
     return SparkEnv.get.blockManager.blockManagerId
   }
 
-- 
GitLab
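
For reference, below is a minimal standalone Scala sketch of the rule this patch
establishes: map-side combining runs only when mergeCombiners is non-null; otherwise
the map task just partitions the raw (key, value) pairs and the reduce side builds
combiners with createCombiner/mergeValue. The names Agg, mapSide, reduceSide and
ShufflePathSketch are hypothetical stand-ins for illustration, not Spark's classes.

// Standalone sketch (assumption: not the actual Spark classes or shuffle path).
// It only demonstrates where aggregation happens depending on mergeCombiners.
import scala.collection.mutable

case class Agg[V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C)  // null means "no map-side aggregation"

object ShufflePathSketch {

  // Map side: pre-aggregate into combiners only when mergeCombiners exists;
  // otherwise forward the raw (key, value) pairs unchanged.
  def mapSide[K, V, C](records: Seq[(K, V)], agg: Agg[V, C]): Seq[(K, Any)] =
    if (agg.mergeCombiners != null) {
      val m = mutable.HashMap.empty[K, C]
      for ((k, v) <- records)
        m(k) = m.get(k).map(agg.mergeValue(_, v)).getOrElse(agg.createCombiner(v))
      m.toSeq
    } else {
      records
    }

  // Reduce side: merge combiners if the map side produced them, otherwise
  // build combiners here from the raw values.
  def reduceSide[K, V, C](shuffled: Seq[(K, Any)], agg: Agg[V, C]): Map[K, C] = {
    val out = mutable.HashMap.empty[K, C]
    for ((k, x) <- shuffled) {
      if (agg.mergeCombiners != null) {
        val c = x.asInstanceOf[C]
        out(k) = out.get(k).map(agg.mergeCombiners(_, c)).getOrElse(c)
      } else {
        val v = x.asInstanceOf[V]
        out(k) = out.get(k).map(agg.mergeValue(_, v)).getOrElse(agg.createCombiner(v))
      }
    }
    out.toMap
  }

  def main(args: Array[String]): Unit = {
    val data = Seq("a" -> 1, "b" -> 2, "a" -> 3)
    val withMapSide    = Agg[Int, Int](v => v, _ + _, _ + _)
    val withoutMapSide = Agg[Int, Int](v => v, _ + _, null)
    // Both paths produce the same totals (a -> 4, b -> 2); they differ only
    // in where the aggregation work is done. Map entry order may vary.
    println(reduceSide(mapSide(data, withMapSide), withMapSide))
    println(reduceSide(mapSide(data, withoutMapSide), withoutMapSide))
  }
}

Either way the final result is the same; skipping map-side combining when no
mergeCombiners function is given trades more shuffled records for not building
per-bucket hash maps on the map task, which is the trade-off the patch makes.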