From 41d6586e8e0df94fee66a386c967b56c535e3c28 Mon Sep 17 00:00:00 2001 From: Ankur Dave <ankurdave@gmail.com> Date: Fri, 10 Jan 2014 18:00:54 -0800 Subject: [PATCH] Revert changes to Spark's (PrimitiveKey)OpenHashMap; copy PKOHM to graphx --- .../spark/util/collection/OpenHashMap.scala | 51 ++---- .../collection/PrimitiveKeyOpenHashMap.scala | 57 ++----- .../spark/graphx/impl/EdgePartition.scala | 2 +- .../graphx/impl/EdgePartitionBuilder.scala | 3 +- .../graphx/impl/EdgeTripletIterator.scala | 2 +- .../apache/spark/graphx/impl/GraphImpl.scala | 3 - .../spark/graphx/impl/VertexPartition.scala | 5 +- .../collection/PrimitiveKeyOpenHashMap.scala | 153 ++++++++++++++++++ 8 files changed, 192 insertions(+), 84 deletions(-) create mode 100644 graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index a7a6635dec..c26f23d500 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -28,20 +28,18 @@ import scala.reflect.ClassTag */ private[spark] class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( - val keySet: OpenHashSet[K], var _values: Array[V]) + initialCapacity: Int) extends Iterable[(K, V)] with Serializable { - /** - * Allocate an OpenHashMap with a fixed initial capacity - */ - def this(initialCapacity: Int = 64) = - this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity)) + def this() = this(64) - /** - * Allocate an OpenHashMap with a fixed initial capacity - */ - def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity)) + protected var _keySet = new OpenHashSet[K](initialCapacity) + + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). 
+ private var _values: Array[V] = _ + _values = new Array[V](_keySet.capacity) @transient private var _oldValues: Array[V] = null @@ -49,14 +47,14 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: Class private var haveNullValue = false private var nullValue: V = null.asInstanceOf[V] - override def size: Int = if (haveNullValue) keySet.size + 1 else keySet.size + override def size: Int = if (haveNullValue) _keySet.size + 1 else _keySet.size /** Get the value for a given key */ def apply(k: K): V = { if (k == null) { nullValue } else { - val pos = keySet.getPos(k) + val pos = _keySet.getPos(k) if (pos < 0) { null.asInstanceOf[V] } else { @@ -71,26 +69,9 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: Class haveNullValue = true nullValue = v } else { - val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK _values(pos) = v - keySet.rehashIfNeeded(k, grow, move) - _oldValues = null - } - } - - /** Set the value for a key */ - def setMerge(k: K, v: V, mergeF: (V,V) => V) { - if (k == null) { - if(haveNullValue) { - nullValue = mergeF(nullValue, v) - } else { - haveNullValue = true - nullValue = v - } - } else { - val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK - _values(pos) = mergeF(_values(pos), v) - keySet.rehashIfNeeded(k, grow, move) + _keySet.rehashIfNeeded(k, grow, move) _oldValues = null } } @@ -111,11 +92,11 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: Class } nullValue } else { - val pos = keySet.addWithoutResize(k) + val pos = _keySet.addWithoutResize(k) if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { val newValue = defaultValue _values(pos & OpenHashSet.POSITION_MASK) = newValue - keySet.rehashIfNeeded(k, grow, move) + _keySet.rehashIfNeeded(k, grow, move) newValue } else { _values(pos) = mergeValue(_values(pos)) @@ -137,9 +118,9 @@ class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: Class } pos += 1 } - pos = keySet.nextPos(pos) + pos = _keySet.nextPos(pos) if (pos >= 0) { - val ret = (keySet.getValue(pos), _values(pos)) + val ret = (_keySet.getValue(pos), _values(pos)) pos += 1 ret } else { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index 1dc9f744e1..2e1ef06cbc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -29,68 +29,45 @@ import scala.reflect._ private[spark] class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, @specialized(Long, Int, Double) V: ClassTag]( - val keySet: OpenHashSet[K], var _values: Array[V]) + initialCapacity: Int) extends Iterable[(K, V)] with Serializable { - /** - * Allocate an OpenHashMap with a fixed initial capacity - */ - def this(initialCapacity: Int) = - this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity)) - - /** - * Allocate an OpenHashMap with a default initial capacity, providing a true - * no-argument constructor. 
- */ def this() = this(64) - /** - * Allocate an OpenHashMap with a fixed initial capacity - */ - def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity)) - require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int]) + // Init in constructor (instead of in declaration) to work around a Scala compiler specialization + // bug that would generate two arrays (one for Object and one for specialized T). + protected var _keySet: OpenHashSet[K] = _ + private var _values: Array[V] = _ + _keySet = new OpenHashSet[K](initialCapacity) + _values = new Array[V](_keySet.capacity) + private var _oldValues: Array[V] = null - override def size = keySet.size + override def size = _keySet.size /** Get the value for a given key */ def apply(k: K): V = { - val pos = keySet.getPos(k) + val pos = _keySet.getPos(k) _values(pos) } /** Get the value for a given key, or returns elseValue if it doesn't exist. */ def getOrElse(k: K, elseValue: V): V = { - val pos = keySet.getPos(k) + val pos = _keySet.getPos(k) if (pos >= 0) _values(pos) else elseValue } /** Set the value for a key */ def update(k: K, v: V) { - val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + val pos = _keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK _values(pos) = v - keySet.rehashIfNeeded(k, grow, move) + _keySet.rehashIfNeeded(k, grow, move) _oldValues = null } - - /** Set the value for a key */ - def setMerge(k: K, v: V, mergeF: (V, V) => V) { - val pos = keySet.addWithoutResize(k) - val ind = pos & OpenHashSet.POSITION_MASK - if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add - _values(ind) = v - } else { - _values(ind) = mergeF(_values(ind), v) - } - keySet.rehashIfNeeded(k, grow, move) - _oldValues = null - } - - /** * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, * set its value to mergeValue(oldValue). @@ -98,11 +75,11 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, * @return the newly updated value. */ def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { - val pos = keySet.addWithoutResize(k) + val pos = _keySet.addWithoutResize(k) if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { val newValue = defaultValue _values(pos & OpenHashSet.POSITION_MASK) = newValue - keySet.rehashIfNeeded(k, grow, move) + _keySet.rehashIfNeeded(k, grow, move) newValue } else { _values(pos) = mergeValue(_values(pos)) @@ -116,9 +93,9 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, /** Get the next value we should return from next(), or null if we're finished iterating */ def computeNextPair(): (K, V) = { - pos = keySet.nextPos(pos) + pos = _keySet.nextPos(pos) if (pos >= 0) { - val ret = (keySet.getValue(pos), _values(pos)) + val ret = (_keySet.getValue(pos), _values(pos)) pos += 1 ret } else { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 4176563d22..a03e73ee79 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -3,7 +3,7 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import org.apache.spark.graphx._ -import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap /** * A collection of edges stored in 3 large columnar arrays (src, dst, attribute). 
The arrays are diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index ca64e9af66..fbc29409b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -4,7 +4,8 @@ import scala.reflect.ClassTag import scala.util.Sorting import org.apache.spark.graphx._ -import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector} +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.PrimitiveVector class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag](size: Int = 64) { var edges = new PrimitiveVector[Edge[ED]](size) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala index c5258360da..bad840f1cd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala @@ -3,7 +3,7 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import org.apache.spark.graphx._ -import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap /** * The Iterator type returned when constructing edge triplets. This class technically could be diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 987a646c0c..c66b8c804f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -173,9 +173,6 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def mapTriplets[ED2: ClassTag]( f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = { - // Use an explicit manifest in PrimitiveKeyOpenHashMap init so we don't pull in the implicit - // manifest from GraphImpl (which would require serializing GraphImpl). 
- val vdTag = classTag[VD] val newEdgePartitions = edges.partitionsRDD.zipPartitions(replicatedVertexView.get(true, true), true) { (ePartIter, vTableReplicatedIter) => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index 7c83497ca9..f97ff75fb2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -2,11 +2,10 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.{BitSet, PrimitiveKeyOpenHashMap} - import org.apache.spark.Logging import org.apache.spark.graphx._ - +import org.apache.spark.graphx.util.collection.PrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartition { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala new file mode 100644 index 0000000000..1088944cd3 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/PrimitiveKeyOpenHashMap.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.util.collection + +import org.apache.spark.util.collection.OpenHashSet + +import scala.reflect._ + +/** + * A fast hash map implementation for primitive, non-null keys. This hash map supports + * insertions and updates, but not deletions. This map is about an order of magnitude + * faster than java.util.HashMap, while using much less space overhead. + * + * Under the hood, it uses our OpenHashSet implementation. + */ +private[spark] +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, + @specialized(Long, Int, Double) V: ClassTag]( + val keySet: OpenHashSet[K], var _values: Array[V]) + extends Iterable[(K, V)] + with Serializable { + + /** + * Allocate an OpenHashMap with a fixed initial capacity + */ + def this(initialCapacity: Int) = + this(new OpenHashSet[K](initialCapacity), new Array[V](initialCapacity)) + + /** + * Allocate an OpenHashMap with a default initial capacity, providing a true + * no-argument constructor. 
+ */ + def this() = this(64) + + /** + * Allocate an OpenHashMap with a fixed initial capacity + */ + def this(keySet: OpenHashSet[K]) = this(keySet, new Array[V](keySet.capacity)) + + require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int]) + + private var _oldValues: Array[V] = null + + override def size = keySet.size + + /** Get the value for a given key */ + def apply(k: K): V = { + val pos = keySet.getPos(k) + _values(pos) + } + + /** Get the value for a given key, or returns elseValue if it doesn't exist. */ + def getOrElse(k: K, elseValue: V): V = { + val pos = keySet.getPos(k) + if (pos >= 0) _values(pos) else elseValue + } + + /** Set the value for a key */ + def update(k: K, v: V) { + val pos = keySet.addWithoutResize(k) & OpenHashSet.POSITION_MASK + _values(pos) = v + keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + + /** Set the value for a key */ + def setMerge(k: K, v: V, mergeF: (V, V) => V) { + val pos = keySet.addWithoutResize(k) + val ind = pos & OpenHashSet.POSITION_MASK + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add + _values(ind) = v + } else { + _values(ind) = mergeF(_values(ind), v) + } + keySet.rehashIfNeeded(k, grow, move) + _oldValues = null + } + + + /** + * If the key doesn't exist yet in the hash map, set its value to defaultValue; otherwise, + * set its value to mergeValue(oldValue). + * + * @return the newly updated value. + */ + def changeValue(k: K, defaultValue: => V, mergeValue: (V) => V): V = { + val pos = keySet.addWithoutResize(k) + if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { + val newValue = defaultValue + _values(pos & OpenHashSet.POSITION_MASK) = newValue + keySet.rehashIfNeeded(k, grow, move) + newValue + } else { + _values(pos) = mergeValue(_values(pos)) + _values(pos) + } + } + + override def iterator = new Iterator[(K, V)] { + var pos = 0 + var nextPair: (K, V) = computeNextPair() + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def computeNextPair(): (K, V) = { + pos = keySet.nextPos(pos) + if (pos >= 0) { + val ret = (keySet.getValue(pos), _values(pos)) + pos += 1 + ret + } else { + null + } + } + + def hasNext = nextPair != null + + def next() = { + val pair = nextPair + nextPair = computeNextPair() + pair + } + } + + // The following member variables are declared as protected instead of private for the + // specialization to work (specialized class extends the unspecialized one and needs access + // to the "private" variables). + // They also should have been val's. We use var's because there is a Scala compiler bug that + // would throw illegal access error at runtime if they are declared as val's. + protected var grow = (newCapacity: Int) => { + _oldValues = _values + _values = new Array[V](newCapacity) + } + + protected var move = (oldPos: Int, newPos: Int) => { + _values(newPos) = _oldValues(oldPos) + } +} -- GitLab