diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
index a7a8625c9220474ce3cb1e373153806723ce0dd8..f60deafc6f32386bd1640fefaee233942936cf76 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -33,6 +33,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
   require(initialCapacity >= 1, "Invalid initial capacity")
 
   private var capacity = nextPowerOf2(initialCapacity)
+  private var mask = capacity - 1
   private var curSize = 0
 
   // Holds keys and values in the same array for memory locality; specifically, the order of
@@ -51,13 +52,14 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
     if (k.eq(null)) {
       return nullValue
     }
-    val mask = capacity - 1
     var pos = rehash(k.hashCode) & mask
     var i = 1
     while (true) {
       val curKey = data(2 * pos)
-      if (curKey.eq(k) || curKey.eq(null) || curKey == k) {
+      if (k.eq(curKey) || k == curKey) {
         return data(2 * pos + 1).asInstanceOf[V]
+      } else if (curKey.eq(null)) {
+        return null.asInstanceOf[V]
       } else {
         val delta = i
         pos = (pos + delta) & mask
@@ -68,7 +70,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
   }
 
   /** Set the value for a key */
-  def update(key: K, value: V) {
+  def update(key: K, value: V): Unit = {
     val k = key.asInstanceOf[AnyRef]
     if (k.eq(null)) {
       if (!haveNullValue) {
@@ -98,21 +100,20 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
       haveNullValue = true
       return nullValue
     }
-    val mask = capacity - 1
     var pos = rehash(k.hashCode) & mask
     var i = 1
     while (true) {
       val curKey = data(2 * pos)
-      if (curKey.eq(null)) {
+      if (k.eq(curKey) || k == curKey) {
+        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+        return newValue
+      } else if (curKey.eq(null)) {
         val newValue = updateFunc(false, null.asInstanceOf[V])
         data(2 * pos) = k
         data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
         incrementSize()
         return newValue
-      } else if (curKey.eq(k) || curKey == k) {
-        val newValue = updateFunc(true, data(2*pos + 1).asInstanceOf[V])
-        data(2*pos + 1) = newValue.asInstanceOf[AnyRef]
-        return newValue
       } else {
         val delta = i
         pos = (pos + delta) & mask
@@ -219,6 +220,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
     }
     data = newData
     capacity = newCapacity
+    mask = newCapacity - 1
   }
 
   private def nextPowerOf2(n: Int): Int = {