Commit 940869df authored by Reynold Xin

Disable running combiners on map tasks when the mergeCombiners function is not specified by the user.
parent 3a6a95dc
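Editor's note: the aggregator referenced throughout this diff bundles three functions, createCombiner, mergeValue, and mergeCombiners, and the new behavior keys entirely off whether mergeCombiners was supplied. Below is a minimal standalone sketch of that decision, using a hypothetical SimpleAggregator case class that only mirrors the fields the diff reads; it is not the project's actual Aggregator definition.

// Hypothetical stand-in for the aggregator fields referenced in this diff.
// A null mergeCombiners signals that map-side combining should be skipped.
case class SimpleAggregator[K, V, C](
  createCombiner: V => C,
  mergeValue: (C, V) => C,
  mergeCombiners: (C, C) => C)

object CombinerChoice {
  // Word-count style: partial sums can be merged, so map-side combining is usable.
  val sumAgg = SimpleAggregator[String, Int, Int](v => v, _ + _, _ + _)

  // Group-by-key style: no mergeCombiners supplied, so map tasks only partition.
  val groupAgg = SimpleAggregator[String, Int, List[Int]](v => List(v), (c, v) => v :: c, null)

  def mapSideCombine(agg: SimpleAggregator[_, _, _]): Boolean =
    agg.mergeCombiners != null

  def main(args: Array[String]): Unit = {
    println(mapSideCombine(sumAgg))   // true  -> combiners run on map tasks
    println(mapSideCombine(groupAgg)) // false -> shuffle raw (K, V) pairs instead
  }
}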
@@ -27,16 +27,36 @@ class ShuffledRDD[K, V, C](
   override def compute(split: Split): Iterator[(K, C)] = {
     val combiners = new JHashMap[K, C]
-    def mergePair(k: K, c: C) {
-      val oldC = combiners.get(k)
-      if (oldC == null) {
-        combiners.put(k, c)
-      } else {
-        combiners.put(k, aggregator.mergeCombiners(oldC, c))
-      }
-    }
-    val fetcher = SparkEnv.get.shuffleFetcher
-    fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
+    val fetcher = SparkEnv.get.shuffleFetcher
+    if (aggregator.mergeCombiners != null) {
+      // If mergeCombiners is specified, combiners are applied on the map
+      // partitions. In this case, post-shuffle we get a list of outputs from
+      // the combiners and merge them using mergeCombiners.
+      def mergePairWithMapSideCombiners(k: K, c: C) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, c)
+        } else {
+          combiners.put(k, aggregator.mergeCombiners(oldC, c))
+        }
+      }
+      fetcher.fetch[K, C](dep.shuffleId, split.index, mergePairWithMapSideCombiners)
+    } else {
+      // If mergeCombiners is not specified, no combiner is applied on the map
+      // partitions (i.e. map side aggregation is turned off). Post-shuffle we
+      // get a list of values and we use mergeValue to merge them.
+      def mergePairWithoutMapSideCombiners(k: K, v: V) {
+        val oldC = combiners.get(k)
+        if (oldC == null) {
+          combiners.put(k, aggregator.createCombiner(v))
+        } else {
+          combiners.put(k, aggregator.mergeValue(oldC, v))
+        }
+      }
+      fetcher.fetch[K, V](dep.shuffleId, split.index, mergePairWithoutMapSideCombiners)
+    }
     return new Iterator[(K, C)] {
       var iter = combiners.entrySet().iterator()
......
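As a reading aid for the hunk above (not part of the commit), here is a minimal standalone sketch of the two reduce-side merge paths that compute now chooses between. It uses a plain scala mutable.HashMap and hypothetical word-count functions in place of the JHashMap, the Aggregator, and the shuffle fetcher callback used in the real code.

import scala.collection.mutable

object ReduceSideMergeSketch {
  // Hypothetical aggregator functions for a word count (stand-ins only).
  val createCombiner: Int => Int = v => v
  val mergeValue: (Int, Int) => Int = _ + _
  val mergeCombiners: (Int, Int) => Int = _ + _

  val combiners = mutable.HashMap[String, Int]()

  // Map-side combining was done: fetched records are (key, partial combiner),
  // so they are folded together with mergeCombiners.
  def mergePairWithMapSideCombiners(k: String, c: Int): Unit =
    combiners(k) = combiners.get(k) match {
      case Some(oldC) => mergeCombiners(oldC, c)
      case None       => c
    }

  // Map-side combining was skipped: fetched records are raw (key, value),
  // so createCombiner/mergeValue build the combiner on the reduce side.
  def mergePairWithoutMapSideCombiners(k: String, v: Int): Unit =
    combiners(k) = combiners.get(k) match {
      case Some(oldC) => mergeValue(oldC, v)
      case None       => createCombiner(v)
    }

  def main(args: Array[String]): Unit = {
    // Pretend these (key, partial sum) pairs arrived from the shuffle fetcher.
    Seq("a" -> 2, "b" -> 1, "a" -> 3).foreach { case (k, c) => mergePairWithMapSideCombiners(k, c) }
    println(combiners) // expected contents: a -> 5, b -> 1
  }
}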
@@ -104,27 +104,44 @@ class ShuffleMapTask(
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
-    for (elem <- rdd.iterator(split)) {
-      val (k, v) = elem.asInstanceOf[(Any, Any)]
-      var bucketId = partitioner.getPartition(k)
-      val bucket = buckets(bucketId)
-      var existing = bucket.get(k)
-      if (existing == null) {
-        bucket.put(k, aggregator.createCombiner(v))
-      } else {
-        bucket.put(k, aggregator.mergeValue(existing, v))
-      }
-    }
+    val bucketIterators =
+      if (aggregator.mergeCombiners != null) {
+        // Apply combiners (map-side aggregation) to the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+        for (elem <- rdd.iterator(split)) {
+          val (k, v) = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(k)
+          val bucket = buckets(bucketId)
+          val existing = bucket.get(k)
+          if (existing == null) {
+            bucket.put(k, aggregator.createCombiner(v))
+          } else {
+            bucket.put(k, aggregator.mergeValue(existing, v))
+          }
+        }
+        buckets.map(_.iterator)
+      } else {
+        // No combiners (no map-side aggregation). Simply partition the map output.
+        val buckets = Array.tabulate(numOutputSplits)(_ => new ArrayBuffer[(Any, Any)])
+        for (elem <- rdd.iterator(split)) {
+          val pair = elem.asInstanceOf[(Any, Any)]
+          val bucketId = partitioner.getPartition(pair._1)
+          buckets(bucketId) += pair
+        }
+        buckets.map(_.iterator)
+      }
     val ser = SparkEnv.get.serializer.newInstance()
     val blockManager = SparkEnv.get.blockManager
     for (i <- 0 until numOutputSplits) {
       val blockId = "shuffleid_" + dep.shuffleId + "_" + partition + "_" + i
-      // Get a scala iterator from java map
-      val iter: Iterator[(Any, Any)] = buckets(i).iterator
+      val iter: Iterator[(Any, Any)] = bucketIterators(i)
       // TODO: This should probably be DISK_ONLY
       blockManager.put(blockId, iter, StorageLevel.MEMORY_ONLY, false)
     }
     return SparkEnv.get.blockManager.blockManagerId
   }
......
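Similarly, for the ShuffleMapTask hunk above, this is a minimal standalone sketch of the two map-side bucketing strategies: aggregate into per-bucket hash maps when mergeCombiners exists, otherwise just append the raw pairs to per-bucket buffers. The partitioner, record source, and aggregator functions here are illustrative placeholders, not the project's API.

import scala.collection.mutable.{ArrayBuffer, HashMap}

object MapSideBucketingSketch {
  val numOutputSplits = 2
  // Hypothetical hash partitioner standing in for the task's Partitioner.
  def getPartition(k: Any): Int = math.abs(k.hashCode % numOutputSplits)

  // Hypothetical word-count aggregator functions (stand-ins only).
  val createCombiner: Int => Int = v => v
  val mergeValue: (Int, Int) => Int = _ + _

  // With mergeCombiners: aggregate each bucket into (key, combiner) pairs.
  def bucketsWithCombining(records: Iterator[(String, Int)]): Array[Iterator[(String, Int)]] = {
    val buckets = Array.fill(numOutputSplits)(HashMap[String, Int]())
    for ((k, v) <- records) {
      val bucket = buckets(getPartition(k))
      bucket(k) = bucket.get(k) match {
        case Some(existing) => mergeValue(existing, v)
        case None           => createCombiner(v)
      }
    }
    buckets.map(_.iterator)
  }

  // Without mergeCombiners: just partition the raw (key, value) pairs.
  def bucketsWithoutCombining(records: Iterator[(String, Int)]): Array[Iterator[(String, Int)]] = {
    val buckets = Array.fill(numOutputSplits)(ArrayBuffer[(String, Int)]())
    for (pair <- records) buckets(getPartition(pair._1)) += pair
    buckets.map(_.iterator)
  }

  def main(args: Array[String]): Unit = {
    val data = Seq("a" -> 1, "b" -> 1, "a" -> 1)
    println(bucketsWithCombining(data.iterator).map(_.toList).toList)    // combined per bucket
    println(bucketsWithoutCombining(data.iterator).map(_.toList).toList) // raw pairs per bucket
  }
}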