Commit fefd22f4 authored by Andrew Or's avatar Andrew Or Committed by Patrick Wendell

[SPARK-1113] External spilling - fix Int.MaxValue hash code collision bug

The original poster of this bug is @guojc, who opened a PR that preceded this one at https://github.com/apache/incubator-spark/pull/612.

ExternalAppendOnlyMap uses key hash codes to order the buffer streams from which spilled files are read back into memory. When a buffer stream is empty, its reported minimum key hash defaults to Int.MaxValue. Int.MaxValue is, however, a perfectly legitimate key hash code. When reading from a spilled map that contains such a key, a hash collision can occur, in which case we attempt to read from an empty stream and throw a NoSuchElementException.

The fix is to maintain the invariant that empty buffer streams are never added to the merge queue in the first place. This guarantees that we never attempt to read from an empty buffer stream.
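To make the failure mode concrete, here is a minimal, self-contained sketch (simplified stand-ins, not the actual Spark classes) of why Int.MaxValue is an unsafe sentinel for an empty buffer, and how the non-empty-buffer invariant sidesteps it:

import scala.collection.mutable

object SentinelCollisionSketch extends App {
  case class Buffer(pairs: mutable.ArrayBuffer[(Int, String)]) {
    // Old behaviour: an empty buffer reports Int.MaxValue as its minimum key hash
    def minKeyHashWithSentinel: Int =
      if (pairs.nonEmpty) pairs.head._1.hashCode() else Int.MaxValue
  }

  val empty       = Buffer(mutable.ArrayBuffer.empty[(Int, String)])
  // Int.MaxValue is a perfectly legal key, and an Int hashes to itself...
  val maxValueKey = Buffer(mutable.ArrayBuffer((Int.MaxValue, "value")))

  // ...so an empty buffer is indistinguishable from one whose smallest key
  // hashes to Int.MaxValue. A merge ordered by this value may dequeue the
  // empty buffer and then fail with NoSuchElementException on pairs.head.
  assert(empty.minKeyHashWithSentinel == maxValueKey.minKeyHashWithSentinel)

  // The fix: never enqueue an empty buffer, so no sentinel is needed at all
  implicit val byMinKeyHash: Ordering[Buffer] =
    Ordering.by((b: Buffer) => -b.pairs.head._1.hashCode())
  val mergeHeap = new mutable.PriorityQueue[Buffer]
  Seq(empty, maxValueKey).foreach { b =>
    if (b.pairs.nonEmpty) mergeHeap.enqueue(b)  // empty buffers are dropped here
  }
  println(s"Buffers in merge heap: ${mergeHeap.size}")  // 1
}

With the sentinel, an empty buffer and a buffer holding the key Int.MaxValue are indistinguishable by hash; with the invariant, the question never arises because empty buffers are simply never in the heap.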

This PR also includes two new tests for hash collisions.

Author: Andrew Or <andrewor14@gmail.com>

Closes #624 from andrewor14/spilling-bug and squashes the following commits:

9e7263d [Andrew Or] Slightly optimize next()
2037ae2 [Andrew Or] Move a few comments around...
cf95942 [Andrew Or] Remove default value of Int.MaxValue for minKeyHash
c11f03b [Andrew Or] Fix Int.MaxValue hash collision bug in ExternalAppendOnlyMap
21c1a39 [Andrew Or] Add hash collision tests to ExternalAppendOnlyMapSuite
parent c8a4c9b1
ExternalAppendOnlyMap.scala:

@@ -148,7 +148,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
   }
 
   /**
-   * Sort the existing contents of the in-memory map and spill them to a temporary file on disk
+   * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
    */
   private def spill(mapSize: Long) {
     spillCount += 1
@@ -223,7 +223,8 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
    */
   private class ExternalIterator extends Iterator[(K, C)] {
 
-    // A fixed-size queue that maintains a buffer for each stream we are currently merging
+    // A queue that maintains a buffer for each stream we are currently merging
+    // This queue maintains the invariant that it only contains non-empty buffers
     private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
 
     // Input streams are derived both from the in-memory map and spilled maps on disk
@@ -233,7 +234,9 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     inputStreams.foreach { it =>
       val kcPairs = getMorePairs(it)
-      mergeHeap.enqueue(StreamBuffer(it, kcPairs))
+      if (kcPairs.length > 0) {
+        mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
+      }
     }
 
     /**
@@ -258,11 +261,11 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     /**
      * If the given buffer contains a value for the given key, merge that value into
-     * baseCombiner and remove the corresponding (K, C) pair from the buffer
+     * baseCombiner and remove the corresponding (K, C) pair from the buffer.
      */
     private def mergeIfKeyExists(key: K, baseCombiner: C, buffer: StreamBuffer): C = {
       var i = 0
-      while (i < buffer.pairs.size) {
+      while (i < buffer.pairs.length) {
         val (k, c) = buffer.pairs(i)
         if (k == key) {
           buffer.pairs.remove(i)
@@ -274,40 +277,41 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     }
 
     /**
-     * Return true if there exists an input stream that still has unvisited pairs
+     * Return true if there exists an input stream that still has unvisited pairs.
      */
-    override def hasNext: Boolean = mergeHeap.exists(!_.pairs.isEmpty)
+    override def hasNext: Boolean = mergeHeap.length > 0
 
     /**
      * Select a key with the minimum hash, then combine all values with the same key from all
      * input streams.
      */
     override def next(): (K, C) = {
+      if (mergeHeap.length == 0) {
+        throw new NoSuchElementException
+      }
       // Select a key from the StreamBuffer that holds the lowest key hash
       val minBuffer = mergeHeap.dequeue()
       val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash)
-      if (minPairs.length == 0) {
-        // Should only happen when no other stream buffers have any pairs left
-        throw new NoSuchElementException
-      }
       var (minKey, minCombiner) = minPairs.remove(0)
       assert(minKey.hashCode() == minHash)
 
       // For all other streams that may have this key (i.e. have the same minimum key hash),
       // merge in the corresponding value (if any) from that stream
       val mergedBuffers = ArrayBuffer[StreamBuffer](minBuffer)
-      while (!mergeHeap.isEmpty && mergeHeap.head.minKeyHash == minHash) {
+      while (mergeHeap.length > 0 && mergeHeap.head.minKeyHash == minHash) {
         val newBuffer = mergeHeap.dequeue()
         minCombiner = mergeIfKeyExists(minKey, minCombiner, newBuffer)
         mergedBuffers += newBuffer
       }
 
-      // Repopulate each visited stream buffer and add it back to the merge heap
+      // Repopulate each visited stream buffer and add it back to the queue if it is non-empty
       mergedBuffers.foreach { buffer =>
-        if (buffer.pairs.length == 0) {
+        if (buffer.isEmpty) {
           buffer.pairs ++= getMorePairs(buffer.iterator)
         }
-        mergeHeap.enqueue(buffer)
+        if (!buffer.isEmpty) {
+          mergeHeap.enqueue(buffer)
+        }
       }
 
       (minKey, minCombiner)
@@ -323,13 +327,12 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     private case class StreamBuffer(iterator: Iterator[(K, C)], pairs: ArrayBuffer[(K, C)])
      extends Comparable[StreamBuffer] {
 
-      def minKeyHash: Int = {
-        if (pairs.length > 0) {
-          // pairs are already sorted by key hash
-          pairs(0)._1.hashCode()
-        } else {
-          Int.MaxValue
-        }
-      }
+      def isEmpty = pairs.length == 0
+
+      // Invalid if there are no more pairs in this stream
+      def minKeyHash = {
+        assert(pairs.length > 0)
+        pairs.head._1.hashCode()
+      }
 
       override def compareTo(other: StreamBuffer): Int = {
@@ -356,7 +359,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C](
     private var objectsRead = 0
 
     /**
-     * Construct a stream that reads only from the next batch
+     * Construct a stream that reads only from the next batch.
      */
     private def nextBatchStream(): InputStream = {
       if (batchSizes.length > 0) {
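One detail worth noting from the iterator above: next() must always surface the StreamBuffer with the smallest minimum key hash, but scala.collection.mutable.PriorityQueue dequeues the greatest element under its Ordering, so a reversed comparison is the usual way to get min-heap behaviour. The sketch below illustrates only that general idea; Buf is a stand-in, not Spark's StreamBuffer:

import scala.collection.mutable

object MinHeapOrderingSketch extends App {
  case class Buf(minKeyHash: Int)

  // Reverse the natural ordering so dequeue() yields the minimum key hash
  implicit val byMinHash: Ordering[Buf] = Ordering.by[Buf, Int](_.minKeyHash).reverse

  val heap = mutable.PriorityQueue(Buf(42), Buf(7), Buf(Int.MaxValue))
  println(heap.dequeue())  // Buf(7) -- the lowest hash surfaces first
}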
ExternalAppendOnlyMapSuite.scala:

@@ -19,21 +19,16 @@ package org.apache.spark.util.collection
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.scalatest.FunSuite
 
 import org.apache.spark._
 import org.apache.spark.SparkContext._
 
-class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
 
-  private val createCombiner: (Int => ArrayBuffer[Int]) = i => ArrayBuffer[Int](i)
-  private val mergeValue: (ArrayBuffer[Int], Int) => ArrayBuffer[Int] = (buffer, i) => {
-    buffer += i
-  }
-  private val mergeCombiners: (ArrayBuffer[Int], ArrayBuffer[Int]) => ArrayBuffer[Int] =
-    (buf1, buf2) => {
-      buf1 ++= buf2
-    }
+  private def createCombiner(i: Int) = ArrayBuffer[Int](i)
+  private def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
+  private def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2
 
   test("simple insert") {
     val conf = new SparkConf(false)
@@ -203,13 +198,13 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
   }
 
   test("spilling") {
-    // TODO: Use SparkConf (which currently throws connection reset exception)
-    System.setProperty("spark.shuffle.memoryFraction", "0.001")
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
+    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
     // reduceByKey - should spill ~8 times
     val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
-    val resultA = rddA.reduceByKey(math.max(_, _)).collect()
+    val resultA = rddA.reduceByKey(math.max).collect()
     assert(resultA.length == 50000)
     resultA.foreach { case(k, v) =>
       k match {
@@ -252,7 +247,73 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
         case _ =>
       }
     }
-    System.clearProperty("spark.shuffle.memoryFraction")
+  }
+
+  test("spilling with hash collisions") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: String) = ArrayBuffer[String](i)
+    def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+    def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
+      buffer1 ++= buffer2
+
+    val map = new ExternalAppendOnlyMap[String, String, ArrayBuffer[String]](
+      createCombiner, mergeValue, mergeCombiners)
+
+    val collisionPairs = Seq(
+      ("Aa", "BB"),                   // 2112
+      ("to", "v1"),                   // 3707
+      ("variants", "gelato"),         // -1249574770
+      ("Teheran", "Siblings"),        // 231609873
+      ("misused", "horsemints"),      // 1069518484
+      ("isohel", "epistolaries"),     // -1179291542
+      ("righto", "buzzards"),         // -931102253
+      ("hierarch", "crinolines"),     // -1732884796
+      ("inwork", "hypercatalexes"),   // -1183663690
+      ("wainages", "presentencing"),  // 240183619
+      ("trichothecenes", "locular"),  // 339006536
+      ("pomatoes", "eructation")      // 568647356
+    )
+
+    (1 to 100000).map(_.toString).foreach { i => map.insert(i, i) }
+    collisionPairs.foreach { case (w1, w2) =>
+      map.insert(w1, w2)
+      map.insert(w2, w1)
+    }
+
+    // A map of collision pairs in both directions
+    val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
+
+    // Avoid map.size or map.iterator.length because this destructively sorts the underlying map
+    var count = 0
+    val it = map.iterator
+    while (it.hasNext) {
+      val kv = it.next()
+      val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1))
+      assert(kv._2.equals(expectedValue))
+      count += 1
+    }
+    assert(count == 100000 + collisionPairs.size * 2)
+  }
+
+  test("spilling with hash collisions using the Int.MaxValue key") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+      mergeValue, mergeCombiners)
+
+    (1 to 100000).foreach { i => map.insert(i, i) }
+    map.insert(Int.MaxValue, Int.MaxValue)
+
+    val it = map.iterator
+    while (it.hasNext) {
+      // Should not throw NoSuchElementException
+      it.next()
+    }
   }
 }
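The word pairs used in the new "spilling with hash collisions" test were chosen because the two words in each pair share a String.hashCode (the value noted in the trailing comments). A quick standalone sanity check, separate from the suite, that the pairs really do collide:

object CollisionPairCheck extends App {
  val collisionPairs = Seq(
    ("Aa", "BB"), ("to", "v1"), ("variants", "gelato"), ("Teheran", "Siblings"),
    ("misused", "horsemints"), ("isohel", "epistolaries"), ("righto", "buzzards"),
    ("hierarch", "crinolines"), ("inwork", "hypercatalexes"),
    ("wainages", "presentencing"), ("trichothecenes", "locular"),
    ("pomatoes", "eructation"))

  collisionPairs.foreach { case (w1, w2) =>
    assert(w1.hashCode == w2.hashCode, s"'$w1' and '$w2' do not collide")
    println(f"$w1%-15s $w2%-15s -> ${w1.hashCode}")
  }
}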