diff --git a/core/src/main/scala/org/apache/spark/util/hash/BitSet.scala b/core/src/main/scala/org/apache/spark/util/hash/BitSet.scala
new file mode 100644
index 0000000000000000000000000000000000000000..69b10566f33992150b80b5e2aac5db6992b2ac5c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/hash/BitSet.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.hash
+
+
+/**
+ * A simple, fixed-size bit set implementation. This implementation is fast because it avoids
+ * safety/bound checking.
+ */
+class BitSet(numBits: Int) {
+
+  private val words = new Array[Long](bit2words(numBits))
+  private val numWords = words.length
+
+  /**
+   * Compute the capacity (number of bits) that can be represented
+   * by this bitset.
+   */
+  def capacity: Int = numWords * 64
+
+
+  /**
+   * Set all the bits up to a given index (exclusive).
+   */
+  def setUntil(bitIndex: Int) {
+    val wordIndex = bitIndex >> 6 // divide by 64
+    var i = 0
+    while(i < wordIndex) { words(i) = -1; i += 1 }
+    if (wordIndex < numWords) {
+      // Set the remaining bits (the mask is zero when bitIndex is a multiple of 64)
+      val mask = ~(-1L << (bitIndex & 0x3f))
+      words(wordIndex) |= mask
+    }
+  }
+
+
+  /**
+   * Compute the bit-wise AND of the two sets returning the
+   * result.
+   */
+  def &(other: BitSet): BitSet = {
+    val newBS = new BitSet(math.max(capacity, other.capacity))
+    val smaller = math.min(numWords, other.numWords)
+    assert(newBS.numWords >= numWords)
+    assert(newBS.numWords >= other.numWords)
+    var ind = 0
+    while( ind < smaller ) {
+      newBS.words(ind) = words(ind) & other.words(ind)
+      ind += 1
+    }
+    newBS
+  }
+
+
+  /**
+   * Compute the bit-wise OR of the two sets returning the
+   * result.
+   */
+  def |(other: BitSet): BitSet = {
+    val newBS = new BitSet(math.max(capacity, other.capacity))
+    assert(newBS.numWords >= numWords)
+    assert(newBS.numWords >= other.numWords)
+    val smaller = math.min(numWords, other.numWords)
+    var ind = 0
+    while( ind < smaller ) {
+      newBS.words(ind) = words(ind) | other.words(ind)
+      ind += 1
+    }
+    while( ind < numWords ) {
+      newBS.words(ind) = words(ind)
+      ind += 1
+    }
+    while( ind < other.numWords ) {
+      newBS.words(ind) = other.words(ind)
+      ind += 1
+    }
+    newBS
+  }
+
+
+  /**
+   * Sets the bit at the specified index to true.
+   * @param index the bit index
+   */
+  def set(index: Int) {
+    val bitmask = 1L << (index & 0x3f) // mod 64 and shift
+    words(index >> 6) |= bitmask // div by 64 and mask
+  }
+
+
+  /**
+   * Return the value of the bit with the specified index. The value is true if the bit with
+   * the index is currently set in this BitSet; otherwise, the result is false.
+ * + * @param index the bit index + * @return the value of the bit with the specified index + */ + def get(index: Int): Boolean = { + val bitmask = 1L << (index & 0x3f) // mod 64 and shift + (words(index >>> 6) & bitmask) != 0 // div by 64 and mask + } + + + /** + * Get an iterator over the set bits. + */ + def iterator = new Iterator[Int] { + var ind = nextSetBit(0) + override def hasNext: Boolean = ind >= 0 + override def next() = { + val tmp = ind + ind = nextSetBit(ind+1) + tmp + } + } + + + /** Return the number of bits set to true in this BitSet. */ + def cardinality(): Int = { + var sum = 0 + var i = 0 + while (i < numWords) { + sum += java.lang.Long.bitCount(words(i)) + i += 1 + } + sum + } + + + /** + * Returns the index of the first bit that is set to true that occurs on or after the + * specified starting index. If no such bit exists then -1 is returned. + * + * To iterate over the true bits in a BitSet, use the following loop: + * + * for (int i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i+1)) { + * // operate on index i here + * } + * + * @param fromIndex the index to start checking from (inclusive) + * @return the index of the next set bit, or -1 if there is no such bit + */ + def nextSetBit(fromIndex: Int): Int = { + var wordIndex = fromIndex >> 6 + if (wordIndex >= numWords) { + return -1 + } + + // Try to find the next set bit in the current word + val subIndex = fromIndex & 0x3f + var word = words(wordIndex) >> subIndex + if (word != 0) { + return (wordIndex << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word) + } + + // Find the next set bit in the rest of the words + wordIndex += 1 + while (wordIndex < numWords) { + word = words(wordIndex) + if (word != 0) { + return (wordIndex << 6) + java.lang.Long.numberOfTrailingZeros(word) + } + wordIndex += 1 + } + + -1 + } + + + /** Return the number of longs it would take to hold numBits. */ + private def bit2words(numBits: Int) = ((numBits - 1) >>> 6) + 1 +} diff --git a/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala new file mode 100644 index 0000000000000000000000000000000000000000..a376d1015a314815ac3cd75e76b1629d544ed62c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.hash + + +/** + * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, + * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less + * space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. 
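+ *
+ * A brief usage sketch (illustrative only; the keys and values here are arbitrary):
+ * {{{
+ *   val counts = new OpenHashMap[String, Int]
+ *   counts("apple") = 1                      // insert or overwrite
+ *   counts.changeValue("apple", 1, _ + 1)    // existing key: merge the old value, yielding 2
+ *   counts.changeValue("pear", 1, _ + 1)     // missing key: stores the default value 1
+ *   counts.iterator.foreach { case (k, v) => println(k + " -> " + v) }
+ * }}}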
+ */ +private[spark] +class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + protected var _keySet = new OpenHashSet[K](initialCapacity) + private var _values = new Array[V](_keySet.capacity) + + @transient private var _oldValues: Array[V] = null + + // Treat the null key differently so we can use nulls in "data" to represent empty items. + private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + if (k == null) { + nullValue + } else { + val pos = _keySet.getPos(k) + if (pos < 0) { + null.asInstanceOf[V] + } else { + _values(pos) + } + } + } + + /** Set the value for a key */ + def update(k: K, v: V) { + if (k == null) { + haveNullValue = true + nullValue = v + } else { + val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + if (k == null) { + if (haveNullValue) { + nullValue = mergeValue(nullValue) + } else { + haveNullValue = true + nullValue = defaultValue + } + nullValue + } else { + val pos = _keySet.fastAdd(k) + if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = -1 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + pos += 1 + return (null.asInstanceOf[K], nullValue) + } + pos += 1 + } + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. 
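+  // Sketch of the callback flow (as driven by the underlying OpenHashSet when it rehashes):
+  // the key set first invokes grow(newCapacity) so we can allocate a value array of the new
+  // size, and then invokes move(oldPos, newPos) once per existing key so that each value
+  // follows its key to the key's new slot.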
+ protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/hash/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/hash/OpenHashSet.scala new file mode 100644 index 0000000000000000000000000000000000000000..7aa3f6220cee38642eb688875b564fb7132c1529 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/hash/OpenHashSet.scala @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.hash + + +/** + * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never + * removed. + * + * The underlying implementation uses Scala compiler's specialization to generate optimized + * storage for two primitive types (Long and Int). It is much faster than Java's standard HashSet + * while incurring much less memory overhead. This can serve as building blocks for higher level + * data structures such as an optimized HashMap. + * + * This OpenHashSet is designed to serve as building blocks for higher level data structures + * such as an optimized hash map. Compared with standard hash set implementations, this class + * provides its various callbacks interfaces (e.g. allocateFunc, moveFunc) and interfaces to + * retrieve the position of a key in the underlying array. + * + * It uses quadratic probing with a power-of-2 hash table size, which is guaranteed + * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). + */ +private[spark] +class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( + initialCapacity: Int, + loadFactor: Double) + extends Serializable { + + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") + + import OpenHashSet._ + + def this(initialCapacity: Int) = this(initialCapacity, 0.7) + + def this() = this(64) + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the non-specialized one and needs access + // to the "private" variables). + + protected val hasher: Hasher[T] = { + // It would've been more natural to write the following using pattern matching. 
But Scala 2.9.x
+    // compiler has a bug when specialization is used together with this pattern matching, and
+    // throws:
+    // scala.tools.nsc.symtab.Types$TypeError: type mismatch;
+    //  found   : scala.reflect.AnyValManifest[Long]
+    //  required: scala.reflect.ClassManifest[Int]
+    //         at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
+    //         at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
+    //         ...
+    val mt = classManifest[T]
+    if (mt == ClassManifest.Long) {
+      (new LongHasher).asInstanceOf[Hasher[T]]
+    } else if (mt == ClassManifest.Int) {
+      (new IntHasher).asInstanceOf[Hasher[T]]
+    } else {
+      new Hasher[T]
+    }
+  }
+
+  protected var _capacity = nextPowerOf2(initialCapacity)
+  protected var _mask = _capacity - 1
+  protected var _size = 0
+
+  protected var _data = classManifest[T].newArray(_capacity)
+  protected var _bitset = new BitSet(_capacity)
+
+  /** Number of elements in the set. */
+  def size: Int = _size
+
+  /** The capacity of the set (i.e. size of the underlying array). */
+  def capacity: Int = _capacity
+
+  /** Return true if this set contains the specified element. */
+  def contains(k: T): Boolean = getPos(k) != INVALID_POS
+
+  /**
+   * Add an element to the set. If the set is over capacity after the insertion, grow the set
+   * and rehash all elements.
+   */
+  def add(k: T) {
+    fastAdd(k)
+    rehashIfNeeded(k, grow, move)
+  }
+
+  /**
+   * Add an element to the set. This one differs from add in that it doesn't trigger rehashing.
+   * The caller is responsible for calling rehashIfNeeded.
+   *
+   * Use (retval & POSITION_MASK) to get the actual position, and
+   * (retval & EXISTENCE_MASK) != 0 to check whether the key is new (i.e. did not exist before).
+   *
+   * @return The position where the key is placed. In addition, the highest order bit is set if
+   *         the key did not exist previously and was inserted by this call.
+   */
+  def fastAdd(k: T): Int = putInto(_bitset, _data, k)
+
+  /**
+   * Rehash the set if it is overloaded.
+   * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+   *          this method.
+   * @param allocateFunc Closure invoked when we are allocating a new, larger array.
+   * @param moveFunc Closure invoked when we move the key from one position (in the old data array)
+   *                 to a new position (in the new data array).
+   */
+  def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+    if (_size > loadFactor * _capacity) {
+      rehash(k, allocateFunc, moveFunc)
+    }
+  }
+
+  /** Return the position of the element in the underlying array. */
+  def getPos(k: T): Int = {
+    var pos = hashcode(hasher.hash(k)) & _mask
+    var i = 1
+    while (true) {
+      if (!_bitset.get(pos)) {
+        return INVALID_POS
+      } else if (k == _data(pos)) {
+        return pos
+      } else {
+        val delta = i
+        pos = (pos + delta) & _mask
+        i += 1
+      }
+    }
+    // Never reached here
+    INVALID_POS
+  }
+
+  /** Return the value at the specified position. */
+  def getValue(pos: Int): T = _data(pos)
+
+  /**
+   * Return the next position with an element stored, starting from the given position inclusively.
+   */
+  def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
+
+  /**
+   * Put an entry into the set. Return the position where the key is placed. In addition, the
+   * highest bit of the returned position is set if the key did not exist before this put.
+   *
+   * This function assumes the data array has at least one empty slot.
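+   *
+   * Decoding the return value (sketch; this mirrors how OpenHashMap.changeValue consumes the
+   * result of fastAdd, which delegates to this method):
+   * {{{
+   *   val ret = fastAdd(k)
+   *   val pos = ret & POSITION_MASK               // slot in the data array
+   *   val isNewKey = (ret & EXISTENCE_MASK) != 0  // true iff k was inserted by this call
+   * }}}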
+   */
+  private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
+    val mask = data.length - 1
+    var pos = hashcode(hasher.hash(k)) & mask
+    var i = 1
+    while (true) {
+      if (!bitset.get(pos)) {
+        // This is a new key.
+        data(pos) = k
+        bitset.set(pos)
+        _size += 1
+        return pos | EXISTENCE_MASK
+      } else if (data(pos) == k) {
+        // Found an existing key.
+        return pos
+      } else {
+        val delta = i
+        pos = (pos + delta) & mask
+        i += 1
+      }
+    }
+    // Never reached here
+    assert(INVALID_POS != INVALID_POS)
+    INVALID_POS
+  }
+
+  /**
+   * Double the table's size and re-hash everything. We are not really using k, but it is declared
+   * so Scala compiler can specialize this method (which leads to calling the specialized version
+   * of putInto).
+   *
+   * @param k A parameter unused in the function, but to force the Scala compiler to specialize
+   *          this method.
+   * @param allocateFunc Closure invoked when we are allocating a new, larger array.
+   * @param moveFunc Closure invoked when we move the key from one position (in the old data array)
+   *                 to a new position (in the new data array).
+   */
+  private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
+    val newCapacity = _capacity * 2
+    require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+
+    allocateFunc(newCapacity)
+    val newData = classManifest[T].newArray(newCapacity)
+    val newBitset = new BitSet(newCapacity)
+    var pos = 0
+    _size = 0
+    while (pos < _capacity) {
+      if (_bitset.get(pos)) {
+        val newPos = putInto(newBitset, newData, _data(pos))
+        moveFunc(pos, newPos & POSITION_MASK)
+      }
+      pos += 1
+    }
+    _bitset = newBitset
+    _data = newData
+    _capacity = newCapacity
+    _mask = newCapacity - 1
+  }
+
+  /**
+   * Re-hash a value to deal better with hash functions that don't differ
+   * in the lower bits, similar to java.util.HashMap.
+   */
+  private def hashcode(h: Int): Int = {
+    val r = h ^ (h >>> 20) ^ (h >>> 12)
+    r ^ (r >>> 7) ^ (r >>> 4)
+  }
+
+  private def nextPowerOf2(n: Int): Int = {
+    val highBit = Integer.highestOneBit(n)
+    if (highBit == n) n else highBit << 1
+  }
+}
+
+
+private[spark]
+object OpenHashSet {
+
+  val INVALID_POS = -1
+
+  val EXISTENCE_MASK = 0x80000000
+
+  val POSITION_MASK = 0x7FFFFFFF  // mask off the existence bit (the highest bit) to get the position
+
+  /**
+   * A set of specialized hash function implementations to avoid boxing hash code computation
+   * in the specialized implementation of OpenHashSet.
+   */
+  sealed class Hasher[@specialized(Long, Int) T] {
+    def hash(o: T): Int = o.hashCode()
+  }
+
+  class LongHasher extends Hasher[Long] {
+    override def hash(o: Long): Int = (o ^ (o >>> 32)).toInt
+  }
+
+  class IntHasher extends Hasher[Int] {
+    override def hash(o: Int): Int = o
+  }
+
+  private def grow1(newSize: Int) {}
+  private def move1(oldPos: Int, newPos: Int) { }
+
+  private val grow = grow1 _
+  private val move = move1 _
+}
diff --git a/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala
new file mode 100644
index 0000000000000000000000000000000000000000..14c136720788aee3c64507f82cedae87de2a2c11
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.hash + + +/** + * A fast hash map implementation for primitive, non-null keys. This hash map supports + * insertions and updates, but not deletions. This map is about an order of magnitude + * faster than java.util.HashMap, while using much less space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. + */ +private[spark] +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, + @specialized(Long, Int, Double) V: ClassManifest]( + initialCapacity: Int) + extends Iterable[(K, V)] + with Serializable { + + def this() = this(64) + + require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int]) + + protected var _keySet = new OpenHashSet[K](initialCapacity) + private var _values = new Array[V](_keySet.capacity) + + private var _oldValues: Array[V] = null + + override def size = _keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + val pos = _keySet.getPos(k) + _values(pos) + } + + /** Set the value for a key */ + def update(k: K, v: V) { + val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + _keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + val pos = _keySet.fastAdd(k) + if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + _keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = 0 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + pos = _keySet.nextPos(pos) + if (pos >= 0) { + val ret = (_keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the unspecialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. 
+ protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/hash/BitSetSuite.scala b/core/src/test/scala/org/apache/spark/util/hash/BitSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..41ede860d2c79b5f083604c14e10a5a9e8de5372 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/hash/BitSetSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.hash + +import org.scalatest.FunSuite + + +class BitSetSuite extends FunSuite { + + test("basic set and get") { + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + + for (i <- 0 until 100) { + assert(!bitset.get(i)) + } + + setBits.foreach(i => bitset.set(i)) + + for (i <- 0 until 100) { + if (setBits.contains(i)) { + assert(bitset.get(i)) + } else { + assert(!bitset.get(i)) + } + } + assert(bitset.cardinality() === setBits.size) + } + + test("100% full bit set") { + val bitset = new BitSet(10000) + for (i <- 0 until 10000) { + assert(!bitset.get(i)) + bitset.set(i) + } + for (i <- 0 until 10000) { + assert(bitset.get(i)) + } + assert(bitset.cardinality() === 10000) + } + + test("nextSetBit") { + val setBits = Seq(0, 9, 1, 10, 90, 96) + val bitset = new BitSet(100) + setBits.foreach(i => bitset.set(i)) + + assert(bitset.nextSetBit(0) === 0) + assert(bitset.nextSetBit(1) === 1) + assert(bitset.nextSetBit(2) === 9) + assert(bitset.nextSetBit(9) === 9) + assert(bitset.nextSetBit(10) === 10) + assert(bitset.nextSetBit(11) === 90) + assert(bitset.nextSetBit(80) === 90) + assert(bitset.nextSetBit(91) === 96) + assert(bitset.nextSetBit(96) === 96) + assert(bitset.nextSetBit(97) === -1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/hash/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/hash/OpenHashMapSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..355784da32ecc0fd83a3e63a87c3a32fc4f3c17c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/hash/OpenHashMapSuite.scala @@ -0,0 +1,148 @@ +package org.apache.spark.util.hash + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite + +class OpenHashMapSuite extends FunSuite { + + test("initialization") { + val goodMap1 = new OpenHashMap[String, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new OpenHashMap[String, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new OpenHashMap[String, String](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new OpenHashMap[String, Int](1 << 30) // Invalid map size: bigger 
than 2^29 + } + intercept[IllegalArgumentException] { + new OpenHashMap[String, Int](-1) + } + intercept[IllegalArgumentException] { + new OpenHashMap[String, String](0) + } + } + + test("primitive value") { + val map = new OpenHashMap[String, Int] + + for (i <- 1 to 1000) { + map(i.toString) = i + assert(map(i.toString) === i) + } + + assert(map.size === 1000) + assert(map(null) === 0) + + map(null) = -1 + assert(map.size === 1001) + assert(map(null) === -1) + + for (i <- 1 to 1000) { + assert(map(i.toString) === i) + } + + // Test iterator + val set = new HashSet[(String, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + val expected = (1 to 1000).map(x => (x.toString, x)) :+ (null.asInstanceOf[String], -1) + assert(set === expected.toSet) + } + + test("non-primitive value") { + val map = new OpenHashMap[String, String] + + for (i <- 1 to 1000) { + map(i.toString) = i.toString + assert(map(i.toString) === i.toString) + } + + assert(map.size === 1000) + assert(map(null) === null) + + map(null) = "-1" + assert(map.size === 1001) + assert(map(null) === "-1") + + for (i <- 1 to 1000) { + assert(map(i.toString) === i.toString) + } + + // Test iterator + val set = new HashSet[(String, String)] + for ((k, v) <- map) { + set.add((k, v)) + } + val expected = (1 to 1000).map(_.toString).map(x => (x, x)) :+ (null.asInstanceOf[String], "-1") + assert(set === expected.toSet) + } + + test("null keys") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + assert(map(null) === null) + map(null) = "hello" + assert(map.size === 101) + assert(map(null) === "hello") + } + + test("null values") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = null + } + assert(map.size === 100) + assert(map("1") === null) + assert(map(null) === null) + assert(map.size === 100) + map(null) = null + assert(map.size === 101) + assert(map(null) === null) + } + + test("changeValue") { + val map = new OpenHashMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue("" + i, { assert(false); "" }, v => { + assert(v === "" + i) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue("" + i, { i + "!" }, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + assert(map(null) === null) + map.changeValue(null, { "null!" }, v => { assert(false); v }) + assert(map.size === 401) + map.changeValue(null, { assert(false); "" }, v => { + assert(v === "null!") + "null!!" 
+ }) + assert(map.size === 401) + } + + test("inserting in capacity-1 map") { + val map = new OpenHashMap[String, String](1) + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map("" + i) === "" + i) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/hash/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/hash/OpenHashSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b5b3a4abe1cdd4efaf99be4600aa5bb4625273c8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/hash/OpenHashSetSuite.scala @@ -0,0 +1,74 @@ +package org.apache.spark.util.hash + +import org.scalatest.FunSuite + + +class OpenHashSetSuite extends FunSuite { + + test("primitive int") { + val set = new OpenHashSet[Int] + assert(set.size === 0) + set.add(10) + assert(set.size === 1) + set.add(50) + assert(set.size === 2) + set.add(999) + assert(set.size === 3) + set.add(50) + assert(set.size === 3) + } + + test("primitive long") { + val set = new OpenHashSet[Long] + assert(set.size === 0) + set.add(10L) + assert(set.size === 1) + set.add(50L) + assert(set.size === 2) + set.add(999L) + assert(set.size === 3) + set.add(50L) + assert(set.size === 3) + } + + test("non-primitive") { + val set = new OpenHashSet[String] + assert(set.size === 0) + set.add(10.toString) + assert(set.size === 1) + set.add(50.toString) + assert(set.size === 2) + set.add(999.toString) + assert(set.size === 3) + set.add(50.toString) + assert(set.size === 3) + } + + test("non-primitive set growth") { + val set = new OpenHashSet[String] + for (i <- 1 to 1000) { + set.add(i.toString) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + for (i <- 1 to 100) { + set.add(i.toString) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + } + + test("primitive set growth") { + val set = new OpenHashSet[Long] + for (i <- 1 to 1000) { + set.add(i.toLong) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + for (i <- 1 to 100) { + set.add(i.toLong) + } + assert(set.size === 1000) + assert(set.capacity > 1000) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashSetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..b9a4b545447adc72c2b307961e59b77db24c79ba --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashSetSuite.scala @@ -0,0 +1,90 @@ +package org.apache.spark.util.hash + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite + +class PrimitiveKeyOpenHashSetSuite extends FunSuite { + + test("initialization") { + val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](-1) + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](0) + } + } + + test("basic operations") { + val longBase = 1000000L + val map = new PrimitiveKeyOpenHashMap[Long, Int] + + for (i <- 1 to 1000) { + map(i + longBase) = i + assert(map(i + longBase) === i) + } + + assert(map.size === 1000) + + for (i <- 1 
to 1000) { + assert(map(i + longBase) === i) + } + + // Test iterator + val set = new HashSet[(Long, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet) + } + + test("null values") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = null + } + assert(map.size === 100) + assert(map(1.toLong) === null) + } + + test("changeValue") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue(i.toLong, { assert(false); "" }, v => { + assert(v === "" + i) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + } + + test("inserting in capacity-1 map") { + val map = new PrimitiveKeyOpenHashMap[Long, String](1) + for (i <- 1 to 100) { + map(i.toLong) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i.toLong) === "" + i) + } + } +} diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala index 13edd8fd1af9f18d5a9b0a8c0a636a44ee22a73e..821063e1f811bd4f86a0a636381cbf6048f39e28 100644 --- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala +++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala @@ -1,5 +1,8 @@ package org.apache.spark.graph +import org.apache.spark.util.hash.BitSet + + import com.esotericsoftware.kryo.Kryo import org.apache.spark.graph.impl.MessageToPartition @@ -16,6 +19,7 @@ class GraphKryoRegistrator extends KryoRegistrator { kryo.register(classOf[(Vid, Object)]) kryo.register(classOf[EdgePartition[Object]]) kryo.register(classOf[BitSet]) + kryo.register(classOf[VertexIdToIndexMap]) // This avoids a large number of hash table lookups. 
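+    // (With reference tracking enabled, Kryo records every object it serializes in an identity
+    // map so it can emit back-references to already-written objects; disabling it skips that
+    // per-object lookup, at the cost of not supporting shared or cyclic object graphs.)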
kryo.setReferences(false) diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala index 9b7e0432726d55725fc3e3a608563cf4bba5b88b..8acc89a95b8f1f0daca39042a03ef07638f50a5e 100644 --- a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala +++ b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala @@ -20,12 +20,9 @@ package org.apache.spark.graph import java.nio.ByteBuffer -import java.util.{HashMap => JHashMap, BitSet => JBitSet, HashSet => JHashSet} import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.BitSet - import org.apache.spark._ import org.apache.spark.rdd._ @@ -33,7 +30,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.Partitioner._ import org.apache.spark.storage.StorageLevel - +import org.apache.spark.util.hash.BitSet @@ -167,7 +164,7 @@ class VertexSetRDD[@specialized V: ClassManifest]( // Walk the index to construct the key, value pairs indexMap.iterator // Extract rows with key value pairs and indicators - .map{ case (k, ind) => (bs(ind), k, ind) } + .map{ case (k, ind) => (bs.get(ind), k, ind) } // Remove tuples that aren't actually present in the array .filter( _._1 ) // Extract the pair (removing the indicator from the tuple) @@ -188,7 +185,7 @@ class VertexSetRDD[@specialized V: ClassManifest]( * modifies the bitmap index and so no new values are allocated. */ override def filter(pred: Tuple2[Vid,V] => Boolean): VertexSetRDD[V] = { - val cleanF = index.rdd.context.clean(pred) + val cleanPred = index.rdd.context.clean(pred) val newValues = index.rdd.zipPartitions(valuesRDD){ (keysIter: Iterator[VertexIdToIndexMap], valuesIter: Iterator[(IndexedSeq[V], BitSet)]) => @@ -200,7 +197,9 @@ class VertexSetRDD[@specialized V: ClassManifest]( val newBS = new BitSet(oldValues.size) // Populate the new Values for( (k,i) <- index ) { - newBS(i) = bs(i) && cleanF( (k, oldValues(i)) ) + if( bs.get(i) && cleanPred( (k, oldValues(i)) ) ) { + newBS.set(i) + } } Array((oldValues, newBS)).iterator } @@ -224,6 +223,7 @@ class VertexSetRDD[@specialized V: ClassManifest]( val newValuesRDD: RDD[ (IndexedSeq[U], BitSet) ] = valuesRDD.mapPartitions(iter => iter.map{ case (values, bs: BitSet) => + /** * @todo Consider using a view rather than creating a new * array. This is already being done for join operations. @@ -231,8 +231,17 @@ class VertexSetRDD[@specialized V: ClassManifest]( * recomputation. 
*/ val newValues = new Array[U](values.size) - for ( ind <- bs ) { - newValues(ind) = f(values(ind)) + var ind = bs.nextSetBit(0) + while(ind >= 0) { + // if(ind >= newValues.size) { + // println(ind) + // println(newValues.size) + // bs.iterator.foreach(print(_)) + // } + // assert(ind < newValues.size) + // assert(ind < values.size) + newValues(ind) = cleanF(values(ind)) + ind = bs.nextSetBit(ind+1) } (newValues.toIndexedSeq, bs) }, preservesPartitioning = true) @@ -271,7 +280,7 @@ class VertexSetRDD[@specialized V: ClassManifest]( val newValues: Array[U] = new Array[U](oldValues.size) // Populate the new Values for( (k,i) <- index ) { - if (bs(i)) { newValues(i) = f(k, oldValues(i)) } + if (bs.get(i)) { newValues(i) = cleanF(k, oldValues(i)) } } Array((newValues.toIndexedSeq, bs)).iterator } @@ -304,7 +313,7 @@ class VertexSetRDD[@specialized V: ClassManifest]( assert(!thisIter.hasNext) val (otherValues, otherBS: BitSet) = otherIter.next() assert(!otherIter.hasNext) - val newBS = thisBS & otherBS + val newBS: BitSet = thisBS & otherBS val newValues = thisValues.view.zip(otherValues) Iterator((newValues.toIndexedSeq, newBS)) } @@ -340,7 +349,7 @@ class VertexSetRDD[@specialized V: ClassManifest]( val (otherValues, otherBS: BitSet) = otherIter.next() assert(!otherIter.hasNext) val otherOption = otherValues.view.zipWithIndex - .map{ (x: (W, Int)) => if(otherBS(x._2)) Option(x._1) else None } + .map{ (x: (W, Int)) => if(otherBS.get(x._2)) Option(x._1) else None } val newValues = thisValues.view.zip(otherOption) Iterator((newValues.toIndexedSeq, thisBS)) } @@ -406,19 +415,19 @@ class VertexSetRDD[@specialized V: ClassManifest]( val ind = index.get(k) // Not all the vertex ids in the index are in this VertexSet. // If there is a vertex in this set then record the other value - if(thisBS(ind)) { - if(wBS(ind)) { + if(thisBS.get(ind)) { + if(wBS.get(ind)) { newW(ind) = cleanMerge(newW(ind), w) } else { newW(ind) = w - wBS(ind) = true + wBS.set(ind) } } } // end of for loop over tuples // Some vertices in this vertex set may not have a corresponding // tuple in the join and so a None value should be returned. 
val otherOption = newW.view.zipWithIndex - .map{ (x: (W, Int)) => if(wBS(x._2)) Option(x._1) else None } + .map{ (x: (W, Int)) => if(wBS.get(x._2)) Option(x._1) else None } // the final values is the zip of the values in this RDD along with // the values in the other val newValues = thisValues.view.zip(otherOption) @@ -456,10 +465,13 @@ class VertexSetRDD[@specialized V: ClassManifest]( */ val newValues = new Array[(Seq[V], Seq[W])](thisValues.size) val newBS = thisBS | otherBS - for( ind <- newBS ) { - val a = if (thisBS(ind)) Seq(thisValues(ind)) else Seq.empty[V] - val b = if (otherBS(ind)) Seq(otherValues(ind)) else Seq.empty[W] + + var ind = newBS.nextSetBit(0) + while(ind >= 0) { + val a = if (thisBS.get(ind)) Seq(thisValues(ind)) else Seq.empty[V] + val b = if (otherBS.get(ind)) Seq(otherValues(ind)) else Seq.empty[W] newValues(ind) = (a, b) + ind = newBS.nextSetBit(ind+1) } Iterator((newValues.toIndexedSeq, newBS)) } @@ -511,17 +523,17 @@ class VertexSetRDD[@specialized V: ClassManifest]( // Get the left key val a = if (thisIndex.contains(k)) { val ind = thisIndex.get(k) - if(thisBS(ind)) Seq(thisValues(ind)) else Seq.empty[V] + if(thisBS.get(ind)) Seq(thisValues(ind)) else Seq.empty[V] } else Seq.empty[V] // Get the right key val b = if (otherIndex.contains(k)) { val ind = otherIndex.get(k) - if (otherBS(ind)) Seq(otherValues(ind)) else Seq.empty[W] + if (otherBS.get(ind)) Seq(otherValues(ind)) else Seq.empty[W] } else Seq.empty[W] // If at least one key was present then we generate a tuple. if (!a.isEmpty || !b.isEmpty) { newValues(ind) = (a, b) - newBS(ind) = true + newBS.set(ind) } } Iterator((newValues.toIndexedSeq, newBS)) @@ -554,28 +566,28 @@ class VertexSetRDD[@specialized V: ClassManifest]( val newBS = new BitSet(thisValues.size) // populate the newValues with the values in this VertexSetRDD for ((k,i) <- thisIndex) { - if (thisBS(i)) { + if (thisBS.get(i)) { newValues(i) = (Seq(thisValues(i)), ArrayBuffer.empty[W]) - newBS(i) = true + newBS.set(i) } } // Now iterate through the other tuples updating the map for ((k,w) <- otherTuplesIter){ if (newIndex.contains(k)) { val ind = newIndex.get(k) - if(newBS(ind)) { + if(newBS.get(ind)) { newValues(ind)._2.asInstanceOf[ArrayBuffer[W]].append(w) } else { // If the other key was in the index but not in the values // of this indexed RDD then create a new values entry for it - newBS(ind) = true + newBS.set(ind) newValues(ind) = (Seq.empty[V], ArrayBuffer(w)) } } else { // update the index val ind = newIndex.size newIndex.put(k, ind) - newBS(ind) = true + newBS.set(ind) // Update the values newValues.append( (Seq.empty[V], ArrayBuffer(w) ) ) } @@ -592,6 +604,8 @@ class VertexSetRDD[@specialized V: ClassManifest]( } } } // end of cogroup + + } // End of VertexSetRDD @@ -637,18 +651,18 @@ object VertexSetRDD { val groups = preAgg.mapPartitions( iter => { val indexMap = new VertexIdToIndexMap() val values = new ArrayBuffer[V] - val bs = new BitSet for ((k,v) <- iter) { if(!indexMap.contains(k)) { val ind = indexMap.size indexMap.put(k, ind) values.append(v) - bs(ind) = true } else { val ind = indexMap.get(k) values(ind) = reduceFunc(values(ind), v) } } + val bs = new BitSet(indexMap.size) + bs.setUntil(indexMap.size) Iterator( (indexMap, (values.toIndexedSeq, bs)) ) }, true).cache // extract the index and the values @@ -736,16 +750,17 @@ object VertexSetRDD { val values = new Array[C](index.size) val bs = new BitSet(index.size) for ((k,c) <- tblIter) { + // @todo this extra check may be costing us a lot! 
if (!index.contains(k)) { throw new SparkException("Error: Trying to bind an external index " + "to an RDD which contains keys that are not in the index.") } val ind = index(k) - if (bs(ind)) { + if (bs.get(ind)) { values(ind) = mergeCombiners(values(ind), c) } else { values(ind) = c - bs(ind) = true + bs.set(ind) } } Iterator((values, bs))