diff --git a/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala
index a376d1015a314815ac3cd75e76b1629d544ed62c..af282d5651292469051c678d7b49731ab6a30af4 100644
--- a/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/hash/OpenHashMap.scala
@@ -27,14 +27,21 @@ package org.apache.spark.util.hash
  */
 private[spark]
 class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest](
-    initialCapacity: Int)
+  var keySet: OpenHashSet[K], var _values: Array[V])
   extends Iterable[(K, V)]
   with Serializable {
 
-  def this() = this(64)
+  /**
+   * Allocate an OpenHashMap with a fixed initial capacity
+   */
+  def this(initialCapacity: Int = 64) = 
+    this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
+
+  /**
+   * Allocate an OpenHashMap backed by an existing key set, creating a value array sized to its capacity
+   */
+  def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
 
-  protected var _keySet = new OpenHashSet[K](initialCapacity)
-  private var _values = new Array[V](_keySet.capacity)
 
   @transient private var _oldValues: Array[V] = null
 
@@ -42,14 +49,14 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
   private var haveNullValue = false
   private var nullValue: V = null.asInstanceOf[V]
 
-  override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
+  override def size: Int = if (haveNullValue) keySet.size + 1 else keySet.size
 
   /** Get the value for a given key */
   def apply(k: K): V = {
     if (k == null) {
       nullValue
     } else {
-      val pos = _keySet.getPos(k)
+      val pos = keySet.getPos(k)
       if (pos < 0) {
         null.asInstanceOf[V]
       } else {
@@ -64,9 +71,31 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
       haveNullValue = true
       nullValue = v
     } else {
-      val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
+      val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
       _values(pos) = v
-      _keySet.rehashIfNeeded(k, grow, move)
+      keySet.rehashIfNeeded(k, grow, move)
+      _oldValues = null
+    }
+  }
+
+  /** Set the value for key k to v if k is new; otherwise set it to mergeF(oldValue, v). */
+  def update(k: K, v: V, mergeF: (V, V) => V) {
+    if (k == null) {
+      if (haveNullValue) {
+        nullValue = mergeF(nullValue, v)
+      } else {
+        haveNullValue = true
+        nullValue = v
+      }
+    } else {
+      val pos = keySet.fastAdd(k)
+      val ind = pos & OpenHashSet.POSITION_MASK
+      if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { // k was newly inserted
+        // First occurrence: store v directly rather than merging with the
+        // uninitialized slot default, matching PrimitiveKeyOpenHashMap.update.
+        _values(ind) = v
+      } else {
+        _values(ind) = mergeF(_values(ind), v)
+      }
+      keySet.rehashIfNeeded(k, grow, move)
       _oldValues = null
     }
   }
@@ -87,11 +111,11 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
       }
       nullValue
     } else {
-      val pos = _keySet.fastAdd(k)
+      val pos = keySet.fastAdd(k)
       if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
         val newValue = defaultValue
         _values(pos & OpenHashSet.POSITION_MASK) = newValue
-        _keySet.rehashIfNeeded(k, grow, move)
+        keySet.rehashIfNeeded(k, grow, move)
         newValue
       } else {
         _values(pos) = mergeValue(_values(pos))
@@ -113,9 +137,9 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
         }
         pos += 1
       }
-      pos = _keySet.nextPos(pos)
+      pos = keySet.nextPos(pos)
       if (pos >= 0) {
-        val ret = (_keySet.getValue(pos), _values(pos))
+        val ret = (keySet.getValue(pos), _values(pos))
         pos += 1
         ret
       } else {
@@ -146,3 +170,4 @@ class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V:
     _values(newPos) = _oldValues(oldPos)
   }
 }
+
diff --git a/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala
index 14c136720788aee3c64507f82cedae87de2a2c11..cbfb2361b419e04034393e67e8220c5f7b54620d 100644
--- a/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/hash/PrimitiveKeyOpenHashMap.scala
@@ -28,35 +28,56 @@ package org.apache.spark.util.hash
 private[spark]
 class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
                               @specialized(Long, Int, Double) V: ClassManifest](
-    initialCapacity: Int)
+    var keySet: OpenHashSet[K], var _values: Array[V])
   extends Iterable[(K, V)]
   with Serializable {
 
-  def this() = this(64)
+  /**
+   * Allocate a PrimitiveKeyOpenHashMap with a fixed initial capacity
+   */
+  def this(initialCapacity: Int = 64) = 
+    this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity))
 
-  require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
+  /**
+   * Allocate a PrimitiveKeyOpenHashMap backed by an existing key set, creating a value array sized to its capacity
+   */
+  def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity))
 
-  protected var _keySet = new OpenHashSet[K](initialCapacity)
-  private var _values = new Array[V](_keySet.capacity)
+  require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
 
   private var _oldValues: Array[V] = null
 
-  override def size = _keySet.size
+  override def size = keySet.size
 
   /** Get the value for a given key */
   def apply(k: K): V = {
-    val pos = _keySet.getPos(k)
+    val pos = keySet.getPos(k)
     _values(pos)
   }
 
   /** Set the value for a key */
   def update(k: K, v: V) {
-    val pos = _keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
+    val pos = keySet.fastAdd(k) & OpenHashSet.POSITION_MASK
     _values(pos) = v
-    _keySet.rehashIfNeeded(k, grow, move)
+    keySet.rehashIfNeeded(k, grow, move)
+    _oldValues = null
+  }
+
+
+  /** Set the value for a key */
+  def update(k: K, v: V, mergeF: (V,V) => V) {
+    val pos = keySet.fastAdd(k)
+    val ind = pos & OpenHashSet.POSITION_MASK
+    if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) { // if first add
+      _values(ind) = v
+    } else {
+      _values(ind) = mergeF(_values(ind), v)
+    }
+    keySet.rehashIfNeeded(k, grow, move)
     _oldValues = null
   }
 
+
   /**
    * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise,
    * set its value to mergeValue(oldValue).
@@ -64,11 +85,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
    * @return the newly updated value.
    */
   def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = {
-    val pos = _keySet.fastAdd(k)
+    val pos = keySet.fastAdd(k)
     if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
       val newValue = defaultValue
       _values(pos & OpenHashSet.POSITION_MASK) = newValue
-      _keySet.rehashIfNeeded(k, grow, move)
+      keySet.rehashIfNeeded(k, grow, move)
       newValue
     } else {
       _values(pos) = mergeValue(_values(pos))
@@ -82,9 +103,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
 
     /** Get the next value we should return from next(), or null if we're finished iterating */
     def computeNextPair(): (K, V) = {
-      pos = _keySet.nextPos(pos)
+      pos = keySet.nextPos(pos)
       if (pos >= 0) {
-        val ret = (_keySet.getValue(pos), _values(pos))
+        val ret = (keySet.getValue(pos), _values(pos))
         pos += 1
         ret
       } else {