Commit 41d6586e authored by Ankur Dave

Revert changes to Spark's (PrimitiveKey)OpenHashMap; copy PKOHM to graphx

parent 85a6645d
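
The motivation for the copy: this commit restores Spark's OpenHashMap and PrimitiveKeyOpenHashMap to their pre-GraphX form, dropping the setMerge method and the constructors that accept an existing OpenHashSet, while GraphX code still depends on them. A minimal sketch of the kind of call site the GraphX-local copy keeps working (the map class and setMerge are from the diff below; the surrounding names like msgSums are illustrative only):

    import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap

    // Accumulate one value per vertex id, merging on repeated keys.
    val msgSums = new PrimitiveKeyOpenHashMap[Long, Double](64)
    msgSums.setMerge(5L, 1.0, _ + _)  // first add: stores 1.0
    msgSums.setMerge(5L, 2.0, _ + _)  // merge: stores 1.0 + 2.0 = 3.0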
org/apache/spark/util/collection/OpenHashMap.scala

@@ -28,20 +28,18 @@ import scala.reflect.ClassTag
  */
 private[spark]
 class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
-    val keySet: OpenHashSet[K], var _values: Array[V])
+    initialCapacity: Int)
   extends Iterable[(K, V)]
   with Serializable {
 
-  /**
-   * Allocate an OpenHashMap with a fixed initial capacity
-   */
-  def this(initialCapacity: Int = 64) =
-    this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
+  def this() = this(64)
 
-  /**
-   * Allocate an OpenHashMap with a fixed initial capacity
-   */
-  def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
+  protected var _keySet = new OpenHashSet[K](initialCapacity)
+
+  // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+  // bug that would generate two arrays (one for Object and one for specialized T).
+  private var _values: Array[V] = _
+  _values = new Array[V](_keySet.capacity)
 
   @transient private var _oldValues: Array[V] = null
@@ -49,14 +47,14 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
   private var haveNullValue = false
   private var nullValue: V = null.asInstanceOf[V]
 
-  override def size: Int = if (haveNullValue) keySet.size + 1 else keySet.size
+  override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
 
   /** Get the value for a given key */
   def apply(k: K): V = {
     if (k == null) {
       nullValue
     } else {
-      val pos = keySet.getPos(k)
+      val pos = _keySet.getPos(k)
       if (pos < 0) {
         null.asInstanceOf[V]
       } else {
@@ -71,26 +69,9 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
       haveNullValue = true
       nullValue = v
     } else {
-      val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+      val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
       _values(pos) = v
-      keySet.rehashIfNeeded(k, grow, move)
-      _oldValues = null
-    }
-  }
-
-  /** Set the value for a key */
-  def setMerge(k: K, v: V, mergeF: (V,V) => V) {
-    if (k == null) {
-      if(haveNullValue) {
-        nullValue = mergeF(nullValue, v)
-      } else {
-        haveNullValue = true
-        nullValue = v
-      }
-    } else {
-      val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
-      _values(pos) = mergeF(_values(pos), v)
-      keySet.rehashIfNeeded(k, grow, move)
+      _keySet.rehashIfNeeded(k, grow, move)
       _oldValues = null
     }
   }
@@ -111,11 +92,11 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
       }
       nullValue
     } else {
-      val pos = keySet.addWithoutResize(k)
+      val pos = _keySet.addWithoutResize(k)
       if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
         val newValue = defaultValue
         _values(pos & OpenHashSet.POSITION_MASK) = newValue
-        keySet.rehashIfNeeded(k, grow, move)
+        _keySet.rehashIfNeeded(k, grow, move)
         newValue
       } else {
         _values(pos) = mergeValue(_values(pos))
@@ -137,9 +118,9 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
         }
         pos += 1
       }
-      pos = keySet.nextPos(pos)
+      pos = _keySet.nextPos(pos)
       if (pos >= 0) {
-        val ret = (keySet.getValue(pos), _values(pos))
+        val ret = (_keySet.getValue(pos), _values(pos))
         pos += 1
         ret
       } else {
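The "Init in constructor" comment restored above refers to a Scala 2.x specialization pitfall: initializing a @specialized field at its declaration can make the compiler emit the initializer in both the generic class and the specialized subclass, allocating two backing arrays. A minimal sketch of the workaround pattern (hypothetical class, not part of the diff):

    import scala.reflect.ClassTag

    class ValueBuffer[@specialized(Long, Int, Double) V: ClassTag](capacity: Int) {
      // Declare uninitialized here, then assign in the constructor body, so the
      // backing array is allocated exactly once.
      private var data: Array[V] = _
      data = new Array[V](capacity)
    }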
org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala

@@ -29,68 +29,45 @@ import scala.reflect._
 private[spark]
 class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
                               @specialized(Long, Int, Double) V: ClassTag](
-    val keySet: OpenHashSet[K], var _values: Array[V])
+    initialCapacity: Int)
   extends Iterable[(K, V)]
   with Serializable {
 
-  /**
-   * Allocate an OpenHashMap with a fixed initial capacity
-   */
-  def this(initialCapacity: Int) =
-    this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
-
-  /**
-   * Allocate an OpenHashMap with a default initial capacity, providing a true
-   * no-argument constructor.
-   */
   def this() = this(64)
 
-  /**
-   * Allocate an OpenHashMap with a fixed initial capacity
-   */
-  def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
-
   require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
 
+  // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
+  // bug that would generate two arrays (one for Object and one for specialized T).
+  protected var _keySet: OpenHashSet[K] = _
+  private var _values: Array[V] = _
+  _keySet = new OpenHashSet[K](initialCapacity)
+  _values = new Array[V](_keySet.capacity)
+
   private var _oldValues: Array[V] = null
 
-  override def size = keySet.size
+  override def size = _keySet.size
 
   /** Get the value for a given key */
   def apply(k: K): V = {
-    val pos = keySet.getPos(k)
+    val pos = _keySet.getPos(k)
     _values(pos)
   }
 
   /** Get the value for a given key, or returns elseValue if it doesn't exist. */
   def getOrElse(k: K, elseValue: V): V = {
-    val pos = keySet.getPos(k)
+    val pos = _keySet.getPos(k)
     if (pos >= 0) _values(pos) else elseValue
   }
 
   /** Set the value for a key */
   def update(k: K, v: V) {
-    val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
+    val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
     _values(pos) = v
-    keySet.rehashIfNeeded(k, grow, move)
+    _keySet.rehashIfNeeded(k, grow, move)
     _oldValues = null
   }
 
-  /** Set the value for a key */
-  def setMerge(k: K, v: V, mergeF: (V, V) => V) {
-    val pos = keySet.addWithoutResize(k)
-    val ind = pos & OpenHashSet.POSITION_MASK
-    if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add
-      _values(ind) = v
-    } else {
-      _values(ind) = mergeF(_values(ind), v)
-    }
-    keySet.rehashIfNeeded(k, grow, move)
-    _oldValues = null
-  }
-
   /**
    * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
    * set its value to mergeValue(oldValue).
@@ -98,11 +75,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
    * @return the newly updated value.
    */
   def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
-    val pos = keySet.addWithoutResize(k)
+    val pos = _keySet.addWithoutResize(k)
     if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
       val newValue = defaultValue
       _values(pos & OpenHashSet.POSITION_MASK) = newValue
-      keySet.rehashIfNeeded(k, grow, move)
+      _keySet.rehashIfNeeded(k, grow, move)
       newValue
     } else {
       _values(pos) = mergeValue(_values(pos))
@@ -116,9 +93,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
 
     /** Get the next value we should return from next(), or null if we're finished iterating */
     def computeNextPair(): (K, V) = {
-      pos = keySet.nextPos(pos)
+      pos = _keySet.nextPos(pos)
       if (pos >= 0) {
-        val ret = (keySet.getValue(pos), _values(pos))
+        val ret = (_keySet.getValue(pos), _values(pos))
         pos += 1
         ret
       } else {
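changeValue is the single-probe upsert used by both map variants: one hash lookup decides between the insert path (store defaultValue, then rehash if needed) and the update path (apply mergeValue in place). A short usage sketch, assuming the class above:

    val counts = new PrimitiveKeyOpenHashMap[Long, Int](64)
    for (id <- Seq(1L, 2L, 1L)) {
      counts.changeValue(id, 1, _ + 1)  // 1 on first sight of the key, else increment
    }
    // counts(1L) == 2, counts(2L) == 1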
org/apache/spark/graphx/impl/EdgePartition.scala

@@ -3,7 +3,7 @@ package org.apache.spark.graphx.impl
 import scala.reflect.ClassTag
 
 import org.apache.spark.graphx._
-import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
 
 /**
  * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). The arrays are
org/apache/spark/graphx/impl/EdgePartitionBuilder.scala

@@ -4,7 +4,8 @@ import scala.reflect.ClassTag
 import scala.util.Sorting
 
 import org.apache.spark.graphx._
-import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.PrimitiveVector
 
 class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: Int = 64) {
   var edges = new PrimitiveVector[Edge[ED]](size)
org/apache/spark/graphx/impl/EdgeTripletIterator.scala

@@ -3,7 +3,7 @@ package org.apache.spark.graphx.impl
 import scala.reflect.ClassTag
 
 import org.apache.spark.graphx._
-import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
 
 /**
  * The Iterator type returned when constructing edge triplets. This class technically could be
org/apache/spark/graphx/impl/GraphImpl.scala

@@ -173,9 +173,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
 
   override def mapTriplets[ED2: ClassTag](
       f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
-    // Use an explicit manifest in PrimitiveKeyOpenHashMap init so we don't pull in the implicit
-    // manifest from GraphImpl (which would require serializing GraphImpl).
-    val vdTag = classTag[VD]
     val newEdgePartitions =
       edges.partitionsRDD.zipPartitions(replicatedVertexView.get(true, true), true) {
         (ePartIter, vTableReplicatedIter) =>
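The deleted comment documents the pattern the reverted code needed: referencing GraphImpl's implicit ClassTag inside the zipPartitions closure would capture `this` and force serializing GraphImpl, so the tag was first copied into a local val. A simplified sketch of that capture-avoidance pattern (hypothetical class, not the actual GraphImpl code):

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    class Pipeline[VD: ClassTag](data: RDD[Int]) {  // not Serializable
      def tagged(): RDD[String] = {
        val vdTag = implicitly[ClassTag[VD]]  // local copy: the closure captures only the tag
        data.map(i => s"${vdTag.runtimeClass.getSimpleName}-$i")  // no reference to `this`
      }
    }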
org/apache/spark/graphx/impl/VertexPartition.scala

@@ -2,11 +2,10 @@ package org.apache.spark.graphx.impl
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.util.collection.{BitSet, PrimitiveKeyOpenHashMap}
-
 import org.apache.spark.Logging
 import org.apache.spark.graphx._
+import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.BitSet
 
 private[graphx] object VertexPartition {
org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.graphx.util.collection

import org.apache.spark.util.collection.OpenHashSet

import scala.reflect._

/**
 * A fast hash map implementation for primitive, non-null keys. This hash map supports
 * insertions and updates, but not deletions. This map is about an order of magnitude
 * faster than java.util.HashMap, while using much less space overhead.
 *
 * Under the hood, it uses our OpenHashSet implementation.
 */
private[spark]
class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
                              @specialized(Long, Int, Double) V: ClassTag](
    val keySet: OpenHashSet[K], var _values: Array[V])
  extends Iterable[(K, V)]
  with Serializable {

  /**
   * Allocate an OpenHashMap with a fixed initial capacity
   */
  def this(initialCapacity: Int) =
    this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))

  /**
   * Allocate an OpenHashMap with a default initial capacity, providing a true
   * no-argument constructor.
   */
  def this() = this(64)

  /**
   * Allocate an OpenHashMap with a fixed initial capacity
   */
  def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))

  require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])

  private var _oldValues: Array[V] = null

  override def size = keySet.size

  /** Get the value for a given key */
  def apply(k: K): V = {
    val pos = keySet.getPos(k)
    _values(pos)
  }

  /** Get the value for a given key, or returns elseValue if it doesn't exist. */
  def getOrElse(k: K, elseValue: V): V = {
    val pos = keySet.getPos(k)
    if (pos >= 0) _values(pos) else elseValue
  }

  /** Set the value for a key */
  def update(k: K, v: V) {
    val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK
    _values(pos) = v
    keySet.rehashIfNeeded(k, grow, move)
    _oldValues = null
  }

  /** Set the value for a key */
  def setMerge(k: K, v: V, mergeF: (V, V) => V) {
    val pos = keySet.addWithoutResize(k)
    val ind = pos & OpenHashSet.POSITION_MASK
    if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add
      _values(ind) = v
    } else {
      _values(ind) = mergeF(_values(ind), v)
    }
    keySet.rehashIfNeeded(k, grow, move)
    _oldValues = null
  }

  /**
   * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
   * set its value to mergeValue(oldValue).
   *
   * @return the newly updated value.
   */
  def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
    val pos = keySet.addWithoutResize(k)
    if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) {
      val newValue = defaultValue
      _values(pos & OpenHashSet.POSITION_MASK) = newValue
      keySet.rehashIfNeeded(k, grow, move)
      newValue
    } else {
      _values(pos) = mergeValue(_values(pos))
      _values(pos)
    }
  }

  override def iterator = new Iterator[(K, V)] {
    var pos = 0
    var nextPair: (K, V) = computeNextPair()

    /** Get the next value we should return from next(), or null if we're finished iterating */
    def computeNextPair(): (K, V) = {
      pos = keySet.nextPos(pos)
      if (pos >= 0) {
        val ret = (keySet.getValue(pos), _values(pos))
        pos += 1
        ret
      } else {
        null
      }
    }

    def hasNext = nextPair != null

    def next() = {
      val pair = nextPair
      nextPair = computeNextPair()
      pair
    }
  }

  // The following member variables are declared as protected instead of private for the
  // specialization to work (specialized class extends the unspecialized one and needs access
  // to the "private" variables).
  // They also should have been val's. We use var's because there is a Scala compiler bug that
  // would throw illegal access error at runtime if they are declared as val's.
  protected var grow = (newCapacity: Int) => {
    _oldValues = _values
    _values = new Array[V](newCapacity)
  }

  protected var move = (oldPos: Int, newPos: Int) => {
    _values(newPos) = _oldValues(oldPos)
  }
}
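
To tie the pieces together: grow and move are the callbacks OpenHashSet invokes while rehashing, so _values stays aligned with the key slots as the table resizes. A small usage sketch of the class above (a deliberately tiny capacity forces several rehashes):

    val m = new PrimitiveKeyOpenHashMap[Int, Long](4)
    for (k <- 0 until 100) {
      m.update(k, k.toLong * 2)  // addWithoutResize + rehashIfNeeded -> grow/move
    }
    assert(m.getOrElse(99, -1L) == 198L)
    assert(m.iterator.length == 100)  // iterates keySet.nextPos over occupied slots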