diff --git a/src/scala/spark/DfsShuffle.scala b/src/scala/spark/DfsShuffle.scala
index 256bf4ea9c7d74eae5e4fc3e8bc6c073e912bb40..2ef0321a632c5d13ddef52964a303c808f2670e9 100644
--- a/src/scala/spark/DfsShuffle.scala
+++ b/src/scala/spark/DfsShuffle.scala
@@ -38,26 +38,25 @@ extends Logging
     numberedSplitRdd.foreach((pair: (Int, Iterator[(K, V)])) => {
       val myIndex = pair._1
       val myIterator = pair._2
-      val combiners = new HashMap[K, C]
+      val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[K, C])
       for ((k, v) <- myIterator) {
-        combiners(k) = combiners.get(k) match {
+        var bucketId = k.hashCode % numOutputSplits
+        if (bucketId < 0) { // Fix bucket ID if hash code was negative
+          bucketId += numOutputSplits
+        }
+        val bucket = buckets(bucketId)
+        bucket(k) = bucket.get(k) match {
           case Some(c) => mergeValue(c, v)
           case None => createCombiner(v)
         }
       }
       val fs = DfsShuffle.getFileSystem()
-      val outputStreams = (0 until numOutputSplits).map(i => {
+      for (i <- 0 until numOutputSplits) {
         val path = new Path(dir, "%d-to-%d".format(myIndex, i))
-        new ObjectOutputStream(fs.create(path, true))
-      }).toArray
-      for ((k, c) <- combiners) {
-        var bucket = k.hashCode % numOutputSplits
-        if (bucket < 0) {
-          bucket += numOutputSplits
-        }
-        outputStreams(bucket).writeObject((k, c))
+        val out = new ObjectOutputStream(fs.create(path, true))
+        buckets(i).foreach(pair => out.writeObject(pair))
+        out.close()
       }
-      outputStreams.foreach(_.close())
     })
 
     // Return an RDD that does each of the merges for a given partition
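
The heart of this change is that values are now combined into one HashMap per output bucket as the input is scanned, with the bucket chosen by a non-negative hashCode modulo, instead of combining everything into a single map and hashing keys again at write time. Below is a minimal, self-contained Scala sketch of that bucketed-combine pattern with the HDFS output left out; the sample input and the List-building createCombiner/mergeValue functions are illustrative stand-ins, not part of the original change.

import scala.collection.mutable.HashMap

object BucketedCombineSketch {
  def main(args: Array[String]): Unit = {
    val numOutputSplits = 4
    // Hypothetical input: (key, value) pairs as they would come from myIterator
    val input = Seq(("a", 1), ("b", 2), ("a", 3), ("c", 4), ("b", 5))

    // Stand-ins for the createCombiner/mergeValue functions of the shuffle
    val createCombiner: Int => List[Int] = v => List(v)
    val mergeValue: (List[Int], Int) => List[Int] = (c, v) => v :: c

    // One combiner map per output split, as in the new code
    val buckets = Array.tabulate(numOutputSplits)(_ => new HashMap[String, List[Int]])
    for ((k, v) <- input) {
      var bucketId = k.hashCode % numOutputSplits
      if (bucketId < 0) {          // hashCode can be negative; shift into [0, numOutputSplits)
        bucketId += numOutputSplits
      }
      val bucket = buckets(bucketId)
      bucket(k) = bucket.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None => createCombiner(v)
      }
    }

    // The real code writes each bucket to its own "<myIndex>-to-<i>" file on HDFS;
    // here we just print the combined contents per bucket.
    for (i <- 0 until numOutputSplits) {
      println(s"bucket $i: ${buckets(i).toMap}")
    }
  }
}

A side effect of restructuring the write loop as in the diff is that each ObjectOutputStream is opened and closed as its bucket is written, so only one output file is open at a time rather than numOutputSplits streams held open across the whole pass.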