From 8381aeffb34fecba53b943763ef65f35ef52289a Mon Sep 17 00:00:00 2001
From: "Joseph E. Gonzalez" <joseph.e.gonzalez@gmail.com>
Date: Thu, 31 Oct 2013 18:13:02 -0700
Subject: [PATCH] Introduce the OpenHashSet and PrimitiveKeyOpenHashMap as
 indexing primitives.

Large parts of VertexSetRDD were restructured to take advantage of:

  1) the OpenHashSet as an index map (see the lookup sketch below)
  2) view-based lazy mapValues and mapValuesWithKeys

In addition, the cogroup code is disabled for now, since it is not
exercised by any of the tests.
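
A minimal sketch of the index-lookup pattern behind 1), assuming only the
add/getPos API and the EXISTENCE_MASK/POSITION_MASK constants that this
patch relies on:

    import org.apache.spark.util.hash.OpenHashSet

    object IndexLookupSketch extends App {
      // An OpenHashSet[Long] doubles as a Vid -> dense-slot index.
      val index = new OpenHashSet[Long]
      index.add(42L)
      val pos = index.getPos(42L)
      if ((pos & OpenHashSet.EXISTENCE_MASK) == 0) { // the key is present
        val slot = pos & OpenHashSet.POSITION_MASK   // usable as an array offset
        println("vertex 42 lives in slot " + slot)
      }
    }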

GraphImpl was updated to also use the OpenHashSet and PrimitiveKeyOpenHashMap
wherever possible:

  1) the LocalVidMaps (used to track replicated vertices) are now implemented
     using the OpenHashSet
  2) a PrimitiveKeyOpenHashMap is constructed on the fly to combine the local
     OpenHashSet with the local (replicated) vertex attribute arrays (see the
     sketch below)
  3) because the OpenHashSet constructor requires a class manifest, all
     operations that construct OpenHashSets have been moved to the GraphImpl
     singleton object, preventing implicit capture of the enclosing instance
     within closures
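
A hedged sketch of 2), assuming the two-argument PrimitiveKeyOpenHashMap
constructor that this patch uses to wrap an existing key set and a parallel
value array (attrOf is a hypothetical helper):

    import org.apache.spark.util.hash.{OpenHashSet, PrimitiveKeyOpenHashMap}

    // vertexArray.size is expected to equal vidToIndex.capacity, so the
    // map can address the array directly through the hash-slot positions.
    def attrOf[VD: ClassManifest](vidToIndex: OpenHashSet[Long],
                                  vertexArray: Array[VD],
                                  vid: Long): VD = {
      val vmap = new PrimitiveKeyOpenHashMap[Long, VD](vidToIndex, vertexArray)
      vmap(vid) // one O(1) probe instead of a hand-rolled getPos/mask dance
    }

Because constructing such a map summons class manifests implicitly, building
it inside an anonymous function would capture the enclosing instance; hence 3).
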
---
 .../spark/graph/GraphKryoRegistrator.scala    |   6 +-
 .../org/apache/spark/graph/VertexSetRDD.scala | 212 ++++++++---------
 .../apache/spark/graph/impl/GraphImpl.scala   | 217 +++++++++---------
 .../org/apache/spark/graph/package.scala      |   9 +-
 .../org/apache/spark/graph/GraphSuite.scala   |  11 +-
 5 files changed, 222 insertions(+), 233 deletions(-)

diff --git a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
index 821063e1f8..62f445127c 100644
--- a/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/GraphKryoRegistrator.scala
@@ -1,14 +1,11 @@
 package org.apache.spark.graph
 
-import org.apache.spark.util.hash.BitSet
-
-
 import com.esotericsoftware.kryo.Kryo
 
 import org.apache.spark.graph.impl.MessageToPartition
 import org.apache.spark.serializer.KryoRegistrator
 import org.apache.spark.graph.impl._
-import scala.collection.mutable.BitSet
+import org.apache.spark.util.hash.BitSet
 
 class GraphKryoRegistrator extends KryoRegistrator {
 
@@ -20,7 +17,6 @@ class GraphKryoRegistrator extends KryoRegistrator {
     kryo.register(classOf[EdgePartition[Object]])
     kryo.register(classOf[BitSet])
     kryo.register(classOf[VertexIdToIndexMap])
-
     // This avoids a large number of hash table lookups.
     kryo.setReferences(false)
   }
diff --git a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
index 8acc89a95b..b3647c083e 100644
--- a/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/VertexSetRDD.scala
@@ -31,6 +31,8 @@ import org.apache.spark.Partitioner._
 
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.hash.BitSet
+import org.apache.spark.util.hash.OpenHashSet
+import org.apache.spark.util.hash.PrimitiveKeyOpenHashMap
 
 
 
@@ -160,15 +162,8 @@ class VertexSetRDD[@specialized V: ClassManifest](
    * Provide the RDD[(K,V)] equivalent output. 
    */
   override def compute(part: Partition, context: TaskContext): Iterator[(Vid, V)] = {
-    tuples.compute(part, context).flatMap { case (indexMap, (values, bs) ) => 
-      // Walk the index to construct the key, value pairs
-      indexMap.iterator 
-        // Extract rows with key value pairs and indicators
-        .map{ case (k, ind) => (bs.get(ind), k, ind)  }
-        // Remove tuples that aren't actually present in the array
-        .filter( _._1 )
-        // Extract the pair (removing the indicator from the tuple)
-        .map( x => (x._2, values(x._3) ) )
+    tuples.compute(part, context).flatMap { case (indexMap, (values, bs) ) =>
+      bs.iterator.map(ind => (indexMap.getValueSafe(ind), values(ind)))
     }
   } // end of compute
 
@@ -195,11 +190,15 @@ class VertexSetRDD[@specialized V: ClassManifest](
       assert(valuesIter.hasNext() == false)
       // Allocate the array to store the results into
       val newBS = new BitSet(oldValues.size)
-      // Populate the new Values
-      for( (k,i) <- index ) {
-        if( bs.get(i) && cleanPred( (k, oldValues(i)) ) ) {
-          newBS.set(i)
+      // Iterate over the active bits in the old bitset and 
+      // evaluate the predicate
+      var ind = bs.nextSetBit(0)
+      while(ind >= 0) {
+        val k = index.getValueSafe(ind)
+        if( cleanPred( (k, oldValues(ind)) ) ) {
+          newBS.set(ind)
         }
+        ind = bs.nextSetBit(ind+1)
       }
       Array((oldValues, newBS)).iterator
     }
@@ -223,27 +222,10 @@ class VertexSetRDD[@specialized V: ClassManifest](
     val newValuesRDD: RDD[ (IndexedSeq[U], BitSet) ] = 
       valuesRDD.mapPartitions(iter => iter.map{ 
         case (values, bs: BitSet) => 
-
-          /** 
-           * @todo Consider using a view rather than creating a new
-           * array.  This is already being done for join operations.
-           * It could reduce memory overhead but require additional
-           * recomputation.
-           */
-          val newValues = new Array[U](values.size)
-          var ind = bs.nextSetBit(0)
-          while(ind >= 0) {
-            // if(ind >= newValues.size) {
-            //   println(ind)
-            //   println(newValues.size)
-            //   bs.iterator.foreach(print(_))
-            // }
-            // assert(ind < newValues.size)
-            // assert(ind < values.size)
-            newValues(ind) = cleanF(values(ind))
-            ind = bs.nextSetBit(ind+1)
-          }
-          (newValues.toIndexedSeq, bs)
+          val newValues: IndexedSeq[U] = values.view.zipWithIndex.map{ 
+            (x: (V, Int)) => if(bs.get(x._2)) cleanF(x._1) else null.asInstanceOf[U]
+          }.toIndexedSeq // @todo verify that toIndexedSeq does not force a copy
+          (newValues, bs)
           }, preservesPartitioning = true)   
     new VertexSetRDD[U](index, newValuesRDD)
   } // end of mapValues
@@ -271,18 +253,14 @@ class VertexSetRDD[@specialized V: ClassManifest](
         assert(keysIter.hasNext() == false)
         val (oldValues, bs: BitSet) = valuesIter.next()
         assert(valuesIter.hasNext() == false)
-        /** 
-         * @todo Consider using a view rather than creating a new array. 
-         * This is already being done for join operations.  It could reduce
-         * memory overhead but require additional recomputation.  
-         */
-        // Allocate the array to store the results into
-        val newValues: Array[U] = new Array[U](oldValues.size)
-        // Populate the new Values
-        for( (k,i) <- index ) {
-          if (bs.get(i)) { newValues(i) = cleanF(k, oldValues(i)) }      
-        }
-        Array((newValues.toIndexedSeq, bs)).iterator
+        // Construct a view of the map transformation
+        val newValues: IndexedSeq[U] = oldValues.view.zipWithIndex.map{ 
+          (x: (V, Int)) => 
+          if(bs.get(x._2)) {
+            cleanF(index.getValueSafe(x._2), x._1)
+          } else null.asInstanceOf[U]
+        }.toIndexedSeq // @todo verify that toIndexedSeq does not force a copy
+        Iterator((newValues, bs))
       }
     new VertexSetRDD[U](index, newValues)
   } // end of mapValuesWithKeys
@@ -314,8 +292,9 @@ class VertexSetRDD[@specialized V: ClassManifest](
         val (otherValues, otherBS: BitSet) = otherIter.next()
         assert(!otherIter.hasNext)
         val newBS: BitSet = thisBS & otherBS
-        val newValues = thisValues.view.zip(otherValues)
-        Iterator((newValues.toIndexedSeq, newBS))
+        val newValues: IndexedSeq[(V,W)] = 
+          thisValues.view.zip(otherValues).toIndexedSeq // @todo verify that toIndexedSeq does not force a copy
+        Iterator((newValues, newBS))
       }
     new VertexSetRDD(index, newValuesRDD)
   }
@@ -348,10 +328,14 @@ class VertexSetRDD[@specialized V: ClassManifest](
       assert(!thisIter.hasNext)
       val (otherValues, otherBS: BitSet) = otherIter.next()
       assert(!otherIter.hasNext)
-      val otherOption = otherValues.view.zipWithIndex
-        .map{ (x: (W, Int)) => if(otherBS.get(x._2)) Option(x._1) else None }
-      val newValues = thisValues.view.zip(otherOption)
-      Iterator((newValues.toIndexedSeq, thisBS))
+      val newValues: IndexedSeq[(V, Option[W])] = thisValues.view.zip(otherValues)
+        .zipWithIndex.map {
+          // @todo the efficiency of this case statement is unverified,
+          // though the resulting tuples are assumed to be short-lived
+          case ((v, w), ind) => (v, if (otherBS.get(ind)) Option(w) else None) 
+        }
+        .toIndexedSeq // @todo verify that toIndexedSeq does not force a copy
+      Iterator((newValues, thisBS))
     }
     new VertexSetRDD(index, newValuesRDD)
   } // end of leftZipJoin
@@ -378,7 +363,6 @@ class VertexSetRDD[@specialized V: ClassManifest](
   def leftJoin[W: ClassManifest](
     other: RDD[(Vid,W)], merge: (W,W) => W = (a:W, b:W) => a):
     VertexSetRDD[(V, Option[W]) ] = {
-    val cleanMerge = index.rdd.context.clean(merge)
     // Test if the other vertex is a VertexSetRDD to choose the optimal
     // join strategy
     other match {
@@ -396,7 +380,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
           if (other.partitioner == partitioner) other 
           else other.partitionBy(partitioner.get)
         // Compute the new values RDD
-        val newValues: RDD[ (IndexedSeq[(V,Option[W])], BitSet) ] = 
+        val newValuesRDD: RDD[ (IndexedSeq[(V,Option[W])], BitSet) ] = 
           index.rdd.zipPartitions(valuesRDD, otherShuffled) {
           (thisIndexIter: Iterator[VertexIdToIndexMap], 
             thisIter: Iterator[(IndexedSeq[V], BitSet)], 
@@ -407,33 +391,37 @@ class VertexSetRDD[@specialized V: ClassManifest](
           val (thisValues, thisBS) = thisIter.next()
           assert(!thisIter.hasNext)
           // Create a new array to store the values in the resulting VertexSet
-          val newW = new Array[W](thisValues.size)
+          val otherValues = new Array[W](thisValues.size)
           // track which values are matched with values in other
-          val wBS = new BitSet(thisValues.size)
-          // Loop over all the tuples that have vertices in this VertexSet.  
-          for( (k, w) <- tuplesIter if index.contains(k) ) {
-            val ind = index.get(k)
-            // Not all the vertex ids in the index are in this VertexSet. 
-            // If there is a vertex in this set then record the other value
-            if(thisBS.get(ind)) {
-              if(wBS.get(ind)) {
-                newW(ind) = cleanMerge(newW(ind), w) 
-              } else {
-                newW(ind) = w
-                wBS.set(ind) 
+          val otherBS = new BitSet(thisValues.size)
+          for ((k,w) <- tuplesIter) {
+            // Get the location of the key in the index
+            val pos = index.getPos(k)
+            // Only if the key is already in the index
+            if ((pos & OpenHashSet.EXISTENCE_MASK) == 0) {
+              // Get the actual index
+              val ind = pos & OpenHashSet.POSITION_MASK
+              // If this value has already been seen then merge
+              if (otherBS.get(ind)) {
+                otherValues(ind) = merge(otherValues(ind), w)
+              } else { // otherwise just store the new value
+                otherBS.set(ind)
+                otherValues(ind) = w
               }
             }
-          } // end of for loop over tuples
+          }
           // Some vertices in this vertex set may not have a corresponding
           // tuple in the join and so a None value should be returned. 
-          val otherOption = newW.view.zipWithIndex
-            .map{ (x: (W, Int)) => if(wBS.get(x._2)) Option(x._1) else None }
-          // the final values is the zip of the values in this RDD along with
-          // the values in the other
-          val newValues = thisValues.view.zip(otherOption)
-          Iterator((newValues.toIndexedSeq, thisBS))
+          val newValues: IndexedSeq[(V, Option[W])] = thisValues.view.zip(otherValues)
+            .zipWithIndex.map {
+            // @todo the efficiency of this case statement is unverified,
+            // though the resulting tuples are assumed to be short-lived
+            case ((v, w), ind) => (v, if (otherBS.get(ind)) Option(w) else None) 
+            }
+            .toIndexedSeq // @todo verify that toIndexedSeq does not force a copy
+          Iterator((newValues, thisBS))
         } // end of newValues
-        new VertexSetRDD(index, newValues) 
+        new VertexSetRDD(index, newValuesRDD) 
       }
     }
   } // end of leftJoin
@@ -443,6 +431,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
    * For each key k in `this` or `other`, return a resulting RDD that contains a 
    * tuple with the list of values for that key in `this` as well as `other`.
    */
+   /*
   def cogroup[W: ClassManifest](other: RDD[(Vid, W)], partitioner: Partitioner): 
   VertexSetRDD[(Seq[V], Seq[W])] = {
     //RDD[(K, (Seq[V], Seq[W]))] = {
@@ -489,16 +478,17 @@ class VertexSetRDD[@specialized V: ClassManifest](
             assert(!thisIter.hasNext)
             val otherIndex = otherIter.next()
             assert(!otherIter.hasNext)
-            val newIndex = new VertexIdToIndexMap()
-            // @todo Merge only the keys that correspond to non-null values
             // Merge the keys
-            newIndex.putAll(thisIndex)
-            newIndex.putAll(otherIndex)
-            // We need to rekey the index
-            var ctr = 0
-            for (e <- newIndex.entrySet) {
-              e.setValue(ctr)
-              ctr += 1
+            val newIndex = new VertexIdToIndexMap(thisIndex.capacity + otherIndex.capacity)
+            var ind = thisIndex.nextPos(0)
+            while(ind >= 0) {
+              newIndex.fastAdd(thisIndex.getValue(ind))
+              ind = thisIndex.nextPos(ind+1)
+            }
+            ind = otherIndex.nextPos(0)
+            while(ind >= 0) {
+              newIndex.fastAdd(otherIndex.getValue(ind))
+              ind = otherIndex.nextPos(ind+1)
             }
             List(newIndex).iterator
           }).cache()
@@ -604,7 +594,7 @@ class VertexSetRDD[@specialized V: ClassManifest](
       }
     }
   } // end of cogroup
-
+ */
 
 } // End of VertexSetRDD
 
@@ -649,21 +639,14 @@ object VertexSetRDD {
     } 
 
     val groups = preAgg.mapPartitions( iter => {
-      val indexMap = new VertexIdToIndexMap()
-      val values = new ArrayBuffer[V]
+      val hashMap = new PrimitiveKeyOpenHashMap[Vid, V]
       for ((k,v) <- iter) {
-        if(!indexMap.contains(k)) {
-          val ind = indexMap.size
-          indexMap.put(k, ind)
-          values.append(v)
-        } else {
-          val ind = indexMap.get(k)
-          values(ind) = reduceFunc(values(ind), v)
-        }
+        hashMap.update(k, v, reduceFunc)
       }
-      val bs = new BitSet(indexMap.size)
-      bs.setUntil(indexMap.size)
-      Iterator( (indexMap, (values.toIndexedSeq, bs)) )
+      val index = hashMap.keySet
+      val values: IndexedSeq[V] = hashMap._values
+      val bs = index.getBitSet
+      Iterator( (index, (values, bs)) )
       }, true).cache
     // extract the index and the values
     val index = groups.mapPartitions(_.map{ case (kMap, vAr) => kMap }, true)
@@ -747,20 +730,24 @@ object VertexSetRDD {
       // There is only one map
       val index = indexIter.next()
       assert(!indexIter.hasNext())
-      val values = new Array[C](index.size)
-      val bs = new BitSet(index.size)
+      val values = new Array[C](index.capacity)
+      val bs = new BitSet(index.capacity)
       for ((k,c) <- tblIter) {
-        // @todo this extra check may be costing us a lot!
-        if (!index.contains(k)) {
+        // Get the location of the key in the index
+        val pos = index.getPos(k)
+        if ((pos & OpenHashSet.EXISTENCE_MASK) != 0) {
           throw new SparkException("Error: Trying to bind an external index " +
             "to an RDD which contains keys that are not in the index.")
-        }
-        val ind = index(k)
-        if (bs.get(ind)) { 
-          values(ind) = mergeCombiners(values(ind), c) 
         } else {
-          values(ind) = c
-          bs.set(ind)
+          // Get the actual index
+          val ind = pos & OpenHashSet.POSITION_MASK
+          // If this value has already been seen then merge
+          if (bs.get(ind)) {
+            values(ind) = mergeCombiners(values(ind), c)
+          } else { // otherwise just store the new value
+            bs.set(ind)
+            values(ind) = c
+          }
         }
       }
       Iterator((values, bs))
@@ -792,14 +779,9 @@ object VertexSetRDD {
     }
 
     val index = shuffledTbl.mapPartitions( iter => {
-      val indexMap = new VertexIdToIndexMap()
-      for ( (k,_) <- iter ){
-        if(!indexMap.contains(k)){
-          val ind = indexMap.size
-          indexMap.put(k, ind)   
-        }
-      }
-      Iterator(indexMap)
+      val index = new VertexIdToIndexMap
+      for ( (k,_) <- iter ){ index.add(k) }
+      Iterator(index)
       }, true).cache
     new VertexSetIndex(index)
   }
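
The mapValues/mapValuesWithKeys rewrite above trades eager array copies for a
BitSet-masked view. A minimal sketch of that pattern (mapLive is a
hypothetical helper; BitSet is the hash-based one imported above):

    import org.apache.spark.util.hash.BitSet

    // Transform only the live slots lazily; dead slots keep a null filler.
    def mapLive[V, U](values: IndexedSeq[V], bs: BitSet)(f: V => U): IndexedSeq[U] =
      values.view.zipWithIndex.map { case (v, i) =>
        if (bs.get(i)) f(v) else null.asInstanceOf[U]
      }.toIndexedSeq // @todo as in the patch: verify this does not force a copy
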
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 016811db36..b80713dbf4 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -5,7 +5,6 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.ArrayBuilder
-import scala.collection.mutable.BitSet
 
 
 import org.apache.spark.SparkContext._
@@ -21,6 +20,12 @@ import org.apache.spark.graph._
 import org.apache.spark.graph.impl.GraphImpl._
 import org.apache.spark.graph.impl.MessageToPartitionRDDFunctions._
 
+import org.apache.spark.util.hash.BitSet
+import org.apache.spark.util.hash.OpenHashSet
+import org.apache.spark.util.hash.PrimitiveKeyOpenHashMap
+
+
+
 /**
  * The Iterator type returned when constructing edge triplets
  */
@@ -31,15 +36,16 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
 
   private var pos = 0
   private val et = new EdgeTriplet[VD, ED]
+  private val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
   
   override def hasNext: Boolean = pos < edgePartition.size
   override def next() = {
     et.srcId = edgePartition.srcIds(pos)
     // assert(vmap.containsKey(e.src.id))
-    et.srcAttr = vertexArray(vidToIndex(et.srcId))
+    et.srcAttr = vmap(et.srcId)
     et.dstId = edgePartition.dstIds(pos)
     // assert(vmap.containsKey(e.dst.id))
-    et.dstAttr = vertexArray(vidToIndex(et.dstId))
+    et.dstAttr = vmap(et.dstId)
     et.attr = edgePartition.data(pos)
     pos += 1
     et
@@ -51,10 +57,10 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
     for (i <- (0 until edgePartition.size)) {
       currentEdge.srcId = edgePartition.srcIds(i)
       // assert(vmap.containsKey(e.src.id))
-      currentEdge.srcAttr = vertexArray(vidToIndex(currentEdge.srcId))
+      currentEdge.srcAttr = vmap(currentEdge.srcId)
       currentEdge.dstId = edgePartition.dstIds(i)
       // assert(vmap.containsKey(e.dst.id))
-      currentEdge.dstAttr = vertexArray(vidToIndex(currentEdge.dstId))
+      currentEdge.dstAttr = vmap(currentEdge.dstId)
       currentEdge.attr = edgePartition.data(i)
       lb += currentEdge
     }
@@ -63,23 +69,6 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
 } // end of Edge Triplet Iterator
 
 
-
-object EdgeTripletBuilder {
-  def makeTriplets[VD: ClassManifest, ED: ClassManifest]( 
-    localVidMap: RDD[(Pid, VertexIdToIndexMap)],
-    vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
-    eTable: RDD[(Pid, EdgePartition[ED])]): RDD[EdgeTriplet[VD, ED]] = {
-    localVidMap.zipPartitions(vTableReplicatedValues, eTable) {
-      (vidMapIter, replicatedValuesIter, eTableIter) =>
-      val (_, vidToIndex) = vidMapIter.next()
-      val (_, vertexArray) = replicatedValuesIter.next()
-      val (_, edgePartition) = eTableIter.next()
-      new EdgeTripletIterator(vidToIndex, vertexArray, edgePartition)
-    }
-  }
-}
-
-
 /**
  * A Graph RDD that supports computation on graphs.
  */
@@ -90,6 +79,10 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     @transient val eTable: RDD[(Pid, EdgePartition[ED])] )
   extends Graph[VD, ED] {
 
+  def this() = this(null, null, null, null)
+
+
+
   /**
    * (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
    * vertex data after it is replicated. Within each partition, it holds a map
@@ -115,7 +108,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   /** Return a RDD that brings edges with its source and destination vertices together. */
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
-    EdgeTripletBuilder.makeTriplets(localVidMap, vTableReplicatedValues, eTable)
+    makeTriplets(localVidMap, vTableReplicatedValues, eTable)
 
 
   override def cache(): Graph[VD, ED] = {
@@ -219,24 +212,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   }
 
 
-  override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2):
-    Graph[VD, ED2] = {
-    val newETable = eTable.zipPartitions(localVidMap, vTableReplicatedValues){ 
-      (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
-      val (pid, edgePartition) = edgePartitionIter.next()
-      val (_, vidToIndex) = vidToIndexIter.next()
-      val (_, vertexArray) = vertexArrayIter.next()
-      val et = new EdgeTriplet[VD, ED]
-      val newEdgePartition = edgePartition.map{e =>
-        et.set(e)
-        et.srcAttr = vertexArray(vidToIndex(e.srcId))
-        et.dstAttr = vertexArray(vidToIndex(e.dstId))
-        f(et)
-      }
-      Iterator((pid, newEdgePartition))
-    }
-    new GraphImpl(vTable, vid2pid, localVidMap, newETable)
-  }
+  override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] =
+    GraphImpl.mapTriplets(this, f)
 
 
   override def subgraph(epred: EdgeTriplet[VD,ED] => Boolean = (x => true), 
@@ -330,57 +307,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   override def mapReduceTriplets[A: ClassManifest](
       mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
       reduceFunc: (A, A) => A)
-    : VertexSetRDD[A] = {
-
-    ClosureCleaner.clean(mapFunc)
-    ClosureCleaner.clean(reduceFunc)
-
-    // Map and preaggregate 
-    val preAgg = eTable.zipPartitions(localVidMap, vTableReplicatedValues){ 
-      (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
-      val (pid, edgePartition) = edgePartitionIter.next()
-      val (_, vidToIndex) = vidToIndexIter.next()
-      val (_, vertexArray) = vertexArrayIter.next()
-      // We can reuse the vidToIndex map for aggregation here as well.
-      /** @todo Since this has the downside of not allowing "messages" to arbitrary
-       * vertices we should consider just using a fresh map.
-       */
-      val msgArray = new Array[A](vertexArray.size)
-      val msgBS = new BitSet(vertexArray.size)
-      // Iterate over the partition
-      val et = new EdgeTriplet[VD, ED]
-      edgePartition.foreach{e => 
-        et.set(e)
-        et.srcAttr = vertexArray(vidToIndex(e.srcId))
-        et.dstAttr = vertexArray(vidToIndex(e.dstId))
-        mapFunc(et).foreach{ case (vid, msg) =>
-          // verify that the vid is valid
-          assert(vid == et.srcId || vid == et.dstId)
-          val ind = vidToIndex(vid)
-          // Populate the aggregator map
-          if(msgBS(ind)) {
-            msgArray(ind) = reduceFunc(msgArray(ind), msg)
-          } else { 
-            msgArray(ind) = msg
-            msgBS(ind) = true
-          }
-        }
-      }
-      // Return the aggregate map
-      vidToIndex.long2IntEntrySet().fastIterator()
-      // Remove the entries that did not receive a message
-      .filter{ entry => msgBS(entry.getValue()) }
-      // Construct the actual pairs
-      .map{ entry => 
-        val vid = entry.getLongKey()
-        val ind = entry.getValue()
-        val msg = msgArray(ind)
-        (vid, msg)
-      }
-      }.partitionBy(vTable.index.rdd.partitioner.get)
-    // do the final reduction reusing the index map
-    VertexSetRDD(preAgg, vTable.index, reduceFunc)
-  }
+    : VertexSetRDD[A] = 
+    GraphImpl.mapReduceTriplets(this, mapFunc, reduceFunc)
 
 
   override def outerJoinVertices[U: ClassManifest, VD2: ClassManifest]
@@ -436,7 +364,6 @@ object GraphImpl {
   }
 
 
-
   /**
    * Create the edge table RDD, which is much more efficient for Java heap storage than the
    * normal edges data structure (RDD[(Vid, Vid, ED)]).
@@ -494,16 +421,9 @@ object GraphImpl {
     RDD[(Pid, VertexIdToIndexMap)] = {
     eTable.mapPartitions( _.map{ case (pid, epart) =>
       val vidToIndex = new VertexIdToIndexMap
-      var i = 0
       epart.foreach{ e => 
-        if(!vidToIndex.contains(e.srcId)) {
-          vidToIndex.put(e.srcId, i)
-          i += 1
-        }
-        if(!vidToIndex.contains(e.dstId)) {
-          vidToIndex.put(e.dstId, i)
-          i += 1
-        }
+        vidToIndex.add(e.srcId)
+        vidToIndex.add(e.dstId)
       }
       (pid, vidToIndex)
     }, preservesPartitioning = true).cache()
@@ -528,9 +448,9 @@ object GraphImpl {
       val (pid, vidToIndex) = mapIter.next()
       assert(!mapIter.hasNext)
       // Populate the vertex array using the vidToIndex map
-      val vertexArray = new Array[VD](vidToIndex.size)
+      val vertexArray = new Array[VD](vidToIndex.capacity)
       for (msg <- msgsIter) {
-        val ind = vidToIndex(msg.data._1)
+        val ind = vidToIndex.getPos(msg.data._1) & OpenHashSet.POSITION_MASK
         vertexArray(ind) = msg.data._2
       }
       Iterator((pid, vertexArray))
@@ -540,6 +460,95 @@ object GraphImpl {
   }
 
 
+  def makeTriplets[VD: ClassManifest, ED: ClassManifest]( 
+    localVidMap: RDD[(Pid, VertexIdToIndexMap)],
+    vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
+    eTable: RDD[(Pid, EdgePartition[ED])]): RDD[EdgeTriplet[VD, ED]] = {
+    localVidMap.zipPartitions(vTableReplicatedValues, eTable) {
+      (vidMapIter, replicatedValuesIter, eTableIter) =>
+      val (_, vidToIndex) = vidMapIter.next()
+      val (_, vertexArray) = replicatedValuesIter.next()
+      val (_, edgePartition) = eTableIter.next()
+      new EdgeTripletIterator(vidToIndex, vertexArray, edgePartition)
+    }
+  }
+
+
+  def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
+    g: GraphImpl[VD, ED],   
+    f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+    val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ 
+      (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
+      val (pid, edgePartition) = edgePartitionIter.next()
+      val (_, vidToIndex) = vidToIndexIter.next()
+      val (_, vertexArray) = vertexArrayIter.next()
+      val et = new EdgeTriplet[VD, ED]
+      val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
+      val newEdgePartition = edgePartition.map{e =>
+        et.set(e)
+        et.srcAttr = vmap(e.srcId)
+        et.dstAttr = vmap(e.dstId)
+        f(et)
+      }
+      Iterator((pid, newEdgePartition))
+    }
+    new GraphImpl(g.vTable, g.vid2pid, g.localVidMap, newETable)
+  }
+
+
+  def mapReduceTriplets[VD: ClassManifest, ED: ClassManifest, A: ClassManifest](
+    g: GraphImpl[VD, ED],
+    mapFunc: EdgeTriplet[VD, ED] => Array[(Vid, A)],
+    reduceFunc: (A, A) => A): VertexSetRDD[A] = {
+
+    ClosureCleaner.clean(mapFunc)
+    ClosureCleaner.clean(reduceFunc)
+
+    // Map and preaggregate 
+    val preAgg = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues){ 
+      (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
+      val (pid, edgePartition) = edgePartitionIter.next()
+      val (_, vidToIndex) = vidToIndexIter.next()
+      val (_, vertexArray) = vertexArrayIter.next()
+      assert(!edgePartitionIter.hasNext)
+      assert(!vidToIndexIter.hasNext)
+      assert(!vertexArrayIter.hasNext)
+      assert(vidToIndex.capacity == vertexArray.size)
+      val vmap = new PrimitiveKeyOpenHashMap[Vid, VD](vidToIndex, vertexArray)
+      // We can reuse the vidToIndex map for aggregation here as well.
+      /** @todo Since this has the downside of not allowing "messages" to arbitrary
+       * vertices we should consider just using a fresh map.
+       */
+      val msgArray = new Array[A](vertexArray.size)
+      val msgBS = new BitSet(vertexArray.size)
+      // Iterate over the partition
+      val et = new EdgeTriplet[VD, ED]
+      edgePartition.foreach{e => 
+        et.set(e)
+        et.srcAttr = vmap(e.srcId)
+        et.dstAttr = vmap(e.dstId)
+        mapFunc(et).foreach{ case (vid, msg) =>
+          // verify that the vid is valid
+          assert(vid == et.srcId || vid == et.dstId)
+          // Get the index of the key
+          val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
+          // Populate the aggregator map
+          if(msgBS.get(ind)) {
+            msgArray(ind) = reduceFunc(msgArray(ind), msg)
+          } else { 
+            msgArray(ind) = msg
+            msgBS.set(ind)
+          }
+        }
+      }
+      // construct an iterator of tuples Iterator[(Vid, A)]
+      msgBS.iterator.map( ind => (vidToIndex.getValue(ind), msgArray(ind)) )
+    }.partitionBy(g.vTable.index.rdd.partitioner.get)
+    // do the final reduction reusing the index map
+    VertexSetRDD(preAgg, g.vTable.index, reduceFunc)
+  }
+
+
   protected def edgePartitionFunction1D(src: Vid, dst: Vid, numParts: Pid): Pid = {
     val mixingPrime: Vid = 1125899906842597L 
     (math.abs(src) * mixingPrime).toInt % numParts
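
A condensed sketch of the per-partition aggregation now used by
mapReduceTriplets: messages land in a dense array addressed by the vidToIndex
slot, and a BitSet records which slots were hit (accumulate is a hypothetical
helper; the names mirror the patch):

    import org.apache.spark.util.hash.{BitSet, OpenHashSet}

    // Fold one message into the dense per-partition aggregation buffers.
    def accumulate[A](vidToIndex: OpenHashSet[Long], msgArray: Array[A],
                      msgBS: BitSet, vid: Long, msg: A, reduce: (A, A) => A) {
      val ind = vidToIndex.getPos(vid) & OpenHashSet.POSITION_MASK
      if (msgBS.get(ind)) {
        msgArray(ind) = reduce(msgArray(ind), msg)
      } else {
        msgArray(ind) = msg
        msgBS.set(ind)
      }
    }
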
diff --git a/graph/src/main/scala/org/apache/spark/graph/package.scala b/graph/src/main/scala/org/apache/spark/graph/package.scala
index 4627c3566c..37a4fb4a5e 100644
--- a/graph/src/main/scala/org/apache/spark/graph/package.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/package.scala
@@ -1,5 +1,10 @@
 package org.apache.spark
 
+import org.apache.spark.util.hash.BitSet
+import org.apache.spark.util.hash.OpenHashSet
+import org.apache.spark.util.hash.PrimitiveKeyOpenHashMap
+
+
 package object graph {
 
   type Vid = Long
@@ -8,8 +13,8 @@ package object graph {
   type VertexHashMap[T] = it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap[T]
   type VertexSet = it.unimi.dsi.fastutil.longs.LongOpenHashSet
   type VertexArrayList = it.unimi.dsi.fastutil.longs.LongArrayList
-  // @todo replace with rxin's fast hashmap
-  type VertexIdToIndexMap = it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap
+  
+  type VertexIdToIndexMap = OpenHashSet[Vid]
 
   /**
    * Return the default null-like value for a data type T.
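
With VertexIdToIndexMap now an alias for OpenHashSet[Vid], the reverse
mapping (slot -> vertex id) comes from getValueSafe, exactly as compute and
mapReduceTriplets use it above. A sketch (pairs is a hypothetical helper):

    import org.apache.spark.util.hash.{BitSet, OpenHashSet}

    // Rebuild the (vid, value) pairs from an index, a parallel value
    // array, and the BitSet marking which slots are live.
    def pairs[V](index: OpenHashSet[Long], values: IndexedSeq[V],
                 bs: BitSet): Iterator[(Long, V)] =
      bs.iterator.map(ind => (index.getValueSafe(ind), values(ind)))
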
diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index f2b3d5bdfe..2067b1613e 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -77,16 +77,15 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     withSpark(new SparkContext("local", "test")) { sc =>
       val a = sc.parallelize((0 to 100).map(x => (x.toLong, x.toLong)), 5)
       val b = VertexSetRDD(a).mapValues(x => -x)
-      assert(b.leftJoin(a)
-        .mapValues(x => x._1 + x._2.get).map(x=> x._2).reduce(_+_) === 0)
+      assert(b.count === 101)
+      assert(b.leftJoin(a).mapValues(x => x._1 + x._2.get).map(x=> x._2).reduce(_+_) === 0)
       val c = VertexSetRDD(a, b.index)
-      assert(b.leftJoin(c)
-        .mapValues(x => x._1 + x._2.get).map(x=> x._2).reduce(_+_) === 0)
+      assert(b.leftJoin(c).mapValues(x => x._1 + x._2.get).map(x=> x._2).reduce(_+_) === 0)
       val d = c.filter(q => ((q._2 % 2) == 0))
       val e = a.filter(q => ((q._2 % 2) == 0))
       assert(d.count === e.count)
-      assert(b.zipJoin(c).mapValues(x => x._1 + x._2)
-        .map(x => x._2).reduce(_+_) === 0)
+      assert(b.zipJoin(c).mapValues(x => x._1 + x._2).map(x => x._2).reduce(_+_) === 0)
+
     }
   } 
   
-- 
GitLab