diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index 6b74a29aceda999af0d6aa4844fcd15c663400d9..bcb95b416dd25e0de50151d010dfe1168ba457f3 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -140,16 +140,16 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64)
     var i = 1
     while (true) {
       val curKey = data(2 * pos)
-      if (k.eq(curKey) || k.equals(curKey)) {
-        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
-        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
-        return newValue
-      } else if (curKey.eq(null)) {
+      if (curKey.eq(null)) {
         val newValue = updateFunc(false, null.asInstanceOf[V])
         data(2 * pos) = k
         data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
         incrementSize()
         return newValue
+      } else if (k.eq(curKey) || k.equals(curKey)) {
+        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
+        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
+        return newValue
       } else {
         val delta = i
         pos = (pos + delta) & mask