From 6930e965e26d39fa6c26ae67a08b4c4d0368d556 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Tue, 24 Mar 2015 17:06:22 -0700
Subject: [PATCH] [SPARK-6512] add contains to OpenHashMap

Add `contains` to test whether a key exists in an OpenHashMap. rxin

Author: Xiangrui Meng <meng@databricks.com>

Closes #5171 from mengxr/openhashmap-contains and squashes the following commits:

d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap
748a69b [Xiangrui Meng] add contains to OpenHashMap
---
 .../org/apache/spark/util/collection/OpenHashMap.scala |  9 +++++++++
 .../util/collection/PrimitiveKeyOpenHashMap.scala      |  5 +++++
 .../spark/util/collection/OpenHashMapSuite.scala       | 10 ++++++++++
 .../util/collection/PrimitiveKeyOpenHashMapSuite.scala |  7 +++++++
 4 files changed, 31 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index c52591b352..efc2482c74 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -53,6 +53,15 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
 
   override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size
 
+  /** Tests whether this map contains a binding for a key. */
+  def contains(k: K): Boolean = {
+    if (k == null) {
+      haveNullValue
+    } else {
+      _keySet.getPos(k) != OpenHashSet.INVALID_POS
+    }
+  }
+
   /** Get the value for a given key */
   def apply(k: K): V = {
     if (k == null) {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index 61e2264276..b4ec4ea521 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -48,6 +48,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
 
   override def size: Int = _keySet.size
 
+  /** Tests whether this map contains a binding for a key. */
+  def contains(k: K): Boolean = {
+    _keySet.getPos(k) != OpenHashSet.INVALID_POS
+  }
+
   /** Get the value for a given key */
   def apply(k: K): V = {
     val pos = _keySet.getPos(k)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index 6a70877356..ef890d2ba6 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -176,4 +176,14 @@ class OpenHashMapSuite extends FunSuite with Matchers {
       assert(map(i.toString) === i.toString)
     }
   }
+
+  test("contains") {
+    val map = new OpenHashMap[String, Int](2)
+    map("a") = 1
+    assert(map.contains("a"))
+    assert(!map.contains("b"))
+    assert(!map.contains(null))
+    map(null) = 0
+    assert(map.contains(null))
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index 8c7df7d73d..caf378fec8 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -118,4 +118,11 @@ class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers {
       assert(map(i.toLong) === i.toString)
     }
   }
+
+  test("contains") {
+    val map = new PrimitiveKeyOpenHashMap[Int, Int](1)
+    map(0) = 0
+    assert(map.contains(0))
+    assert(!map.contains(1))
+  }
 }
-- 
GitLab