Commit c40e7663 authored by Matei Zaharia

Use java.util.HashMap in shuffles

parent d6ec664b
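Context (not stated in the commit message, but implied by the change): scala.collection.mutable.HashMap's get returns an Option, so every lookup on a hot path allocates a Some wrapper, while java.util.HashMap.get returns the value directly, or null when the key is absent. A minimal sketch of the two update patterns this diff trades between, using a hypothetical word-count example rather than code from the commit:

import java.util.{HashMap => JHashMap}
import scala.collection.mutable.{HashMap => SHashMap}

object UpdatePatterns {
  // Scala Map API: get wraps every hit in Some, allocating on each lookup.
  def scalaStyle(counts: SHashMap[String, Int], k: String) {
    counts(k) = counts.get(k) match {
      case Some(c) => c + 1
      case None => 1
    }
  }

  // java.util.HashMap: get returns the value or null, with no wrapper.
  // The values must be a reference type for the null check to be meaningful.
  def javaStyle(counts: JHashMap[String, Integer], k: String) {
    val existing = counts.get(k)
    if (existing == null) {
      counts.put(k, 1)
    } else {
      counts.put(k, existing + 1)
    }
  }
}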
ShuffleMapTask.scala
@@ -3,7 +3,7 @@ package spark
 import java.io.BufferedOutputStream
 import java.io.FileOutputStream
 import java.io.ObjectOutputStream
-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}

 class ShuffleMapTask(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_,_], val partition: Int, locs: Seq[String])
@@ -14,21 +14,27 @@ extends DAGTask[String](stageId) with Logging {
     val numOutputSplits = dep.partitioner.numPartitions
     val aggregator = dep.aggregator.asInstanceOf[Aggregator[Any, Any, Any]]
     val partitioner = dep.partitioner.asInstanceOf[Partitioner]
-    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[Any, Any])
+    val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
     for (elem <- rdd.iterator(split)) {
       val (k, v) = elem.asInstanceOf[(Any, Any)]
       var bucketId = partitioner.getPartition(k)
       val bucket = buckets(bucketId)
-      bucket(k) = bucket.get(k) match {
-        case Some(c) => aggregator.mergeValue(c, v)
-        case None => aggregator.createCombiner(v)
+      var existing = bucket.get(k)
+      if (existing == null) {
+        bucket.put(k, aggregator.createCombiner(v))
+      } else {
+        bucket.put(k, aggregator.mergeValue(existing, v))
       }
     }
     val ser = SparkEnv.get.serializer.newInstance()
     for (i <- 0 until numOutputSplits) {
       val file = LocalFileShuffle.getOutputFile(dep.shuffleId, partition, i)
       val out = ser.outputStream(new BufferedOutputStream(new FileOutputStream(file)))
-      buckets(i).foreach(pair => out.writeObject(pair))
+      val iter = buckets(i).entrySet().iterator()
+      while (iter.hasNext()) {
+        val entry = iter.next()
+        out.writeObject((entry.getKey, entry.getValue))
+      }
       // TODO: have some kind of EOF marker
       out.close()
     }
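One subtlety in the hunk above: bucket.get(k) == null stands in for "key absent". That is safe here only because the buckets are JHashMap[Any, Any] (reference values) and the combiners produced by createCombiner and mergeValue are never null. A self-contained sketch of the same combine loop, with a hypothetical running-count aggregator in place of Spark's:

import java.util.{HashMap => JHashMap}

object CombineSketch {
  // Hypothetical aggregator: the combiner is a running count.
  def createCombiner(v: Int): Integer = v
  def mergeValue(c: Integer, v: Int): Integer = c + v

  def main(args: Array[String]) {
    val bucket = new JHashMap[String, Integer]
    for ((k, v) <- Seq(("a", 1), ("b", 1), ("a", 1))) {
      val existing = bucket.get(k) // null when k is absent
      if (existing == null) {
        bucket.put(k, createCombiner(v))
      } else {
        bucket.put(k, mergeValue(existing, v))
      }
    }
    println(bucket) // e.g. {a=2, b=1}
  }
}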
ShuffledRDD.scala
@@ ... @@
 package spark

-import scala.collection.mutable.HashMap
+import java.util.{HashMap => JHashMap}

 class ShuffledRDDSplit(val idx: Int) extends Split {
@@ -27,15 +27,26 @@ extends RDD[(K, C)](parent.context) {
   override val dependencies = List(dep)

   override def compute(split: Split): Iterator[(K, C)] = {
-    val combiners = new HashMap[K, C]
+    val combiners = new JHashMap[K, C]
     def mergePair(k: K, c: C) {
-      combiners(k) = combiners.get(k) match {
-        case Some(oldC) => aggregator.mergeCombiners(oldC, c)
-        case None => c
+      val oldC = combiners.get(k)
+      if (oldC == null) {
+        combiners.put(k, c)
+      } else {
+        combiners.put(k, aggregator.mergeCombiners(oldC, c))
       }
     }
     val fetcher = SparkEnv.get.shuffleFetcher
     fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
-    combiners.iterator
+    return new Iterator[(K, C)] {
+      var iter = combiners.entrySet().iterator()
+      def hasNext(): Boolean = iter.hasNext()
+      def next(): (K, C) = {
+        val entry = iter.next()
+        (entry.getKey, entry.getValue)
+      }
+    }
   }
 }
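The anonymous Iterator returned by compute is the usual adapter from a java.util.Iterator over Map.Entry[K, C] to a Scala Iterator[(K, C)]; it streams pairs out of the map without materializing a second collection. A generic version of the same adapter (a hypothetical helper, not part of this commit):

import java.util.{HashMap => JHashMap, Map => JMap}

object EntryIterator {
  // Expose a java.util.Map's entries as a Scala Iterator of (key, value) pairs.
  def tuples[K, V](m: JMap[K, V]): Iterator[(K, V)] = new Iterator[(K, V)] {
    private val iter = m.entrySet().iterator()
    def hasNext: Boolean = iter.hasNext
    def next(): (K, V) = {
      val entry = iter.next()
      (entry.getKey, entry.getValue)
    }
  }

  def main(args: Array[String]) {
    val m = new JHashMap[String, Int]
    m.put("x", 1)
    m.put("y", 2)
    tuples(m).foreach(println) // (x,1) and (y,2), in hash order
  }
}

Streaming the entries this way keeps memory bounded by the map itself, which matters when the reduce side hands the iterator straight to downstream operators.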