diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 7c0b9e23f28dc960483e953ba500d862d3f55a8a..ae1ea715e2badf3b79416c55c5fa1e94dbcb45fc 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -63,6 +63,13 @@ class EdgeTripletIterator[VD: ClassManifest, ED: ClassManifest](
 
 /**
  * A Graph RDD that supports computation on graphs.
+ *
+ * @param localVidMap Stores the location of vertex attributes after they are
+ * replicated. Within each partition, localVidMap holds a map from vertex ID to
+ * the index where that vertex's attribute is stored. This index refers to the
+ * arrays in the same partition in the variants of
+ * [[VTableReplicatedValues]]. Therefore, localVidMap can be reused across
+ * changes to the vertex attributes.
  */
 class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     @transient val vTable: VertexSetRDD[VD],
@@ -73,27 +80,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   def this() = this(null, null, null, null)
 
-  /**
-   * (localVidMap: VertexSetRDD[Pid, VertexIdToIndexMap]) is a version of the
-   * vertex data after it is replicated. Within each partition, it holds a map
-   * from vertex ID to the index where that vertex's attribute is stored. This
-   * index refers to an array in the same partition in vTableReplicatedValues.
-   *
-   * (vTableReplicatedValues: VertexSetRDD[Pid, Array[VD]]) holds the vertex data
-   * and is arranged as described above.
-   */
-  @transient val vTableReplicatedValuesBothAttrs: RDD[(Pid, Array[VD])] =
-    createVTableReplicated(vTable, vid2pid.bothAttrs, localVidMap)
-
-  @transient val vTableReplicatedValuesSrcAttrOnly: RDD[(Pid, Array[VD])] =
-    createVTableReplicated(vTable, vid2pid.srcAttrOnly, localVidMap)
-
-  @transient val vTableReplicatedValuesDstAttrOnly: RDD[(Pid, Array[VD])] =
-    createVTableReplicated(vTable, vid2pid.dstAttrOnly, localVidMap)
-
-  // TODO(ankurdave): create this more efficiently
-  @transient val vTableReplicatedValuesNoAttrs: RDD[(Pid, Array[VD])] =
-    createVTableReplicated(vTable, vid2pid.noAttrs, localVidMap)
+  @transient val vTableReplicatedValues: VTableReplicatedValues[VD] =
+    new VTableReplicatedValues(vTable, vid2pid, localVidMap)
 
   /** Return a RDD of vertices. */
   @transient override val vertices = vTable
@@ -105,7 +93,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   /** Return a RDD that brings edges with its source and destination vertices together. */
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] =
-    makeTriplets(localVidMap, vTableReplicatedValuesBothAttrs, eTable)
+    makeTriplets(localVidMap, vTableReplicatedValues.bothAttrs, eTable)
 
   override def persist(newLevel: StorageLevel): Graph[VD, ED] = {
     eTable.persist(newLevel)
@@ -188,9 +176,9 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     traverseLineage(localVidMap, "  ", visited)
     visited += (localVidMap.id -> "localVidMap")
 
-    println("\n\nvTableReplicatedValuesBothAttrs -----------------")
-    traverseLineage(vTableReplicatedValuesBothAttrs, "  ", visited)
-    visited += (vTableReplicatedValuesBothAttrs.id -> "vTableReplicatedValuesBothAttrs")
+    println("\n\nvTableReplicatedValues.bothAttrs ----------------")
+    traverseLineage(vTableReplicatedValues.bothAttrs, "  ", visited)
+    visited += (vTableReplicatedValues.bothAttrs.id -> "vTableReplicatedValues.bothAttrs")
 
     println("\n\ntriplets ----------------------------------------")
     traverseLineage(triplets, "  ", visited)
@@ -386,8 +374,9 @@ object GraphImpl {
     }, preservesPartitioning = true).cache()
   }
 
-  protected def createLocalVidMap[ED: ClassManifest](eTable: RDD[(Pid, EdgePartition[ED])]):
-    RDD[(Pid, VertexIdToIndexMap)] = {
+  private def createLocalVidMap(
+      eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED }
+    ): RDD[(Pid, VertexIdToIndexMap)] = {
     eTable.mapPartitions( _.map{ case (pid, epart) =>
       val vidToIndex = new VertexIdToIndexMap
       epart.foreach{ e =>
@@ -398,36 +387,6 @@ object GraphImpl {
     }, preservesPartitioning = true).cache()
   }
 
-  protected def createVTableReplicated[VD: ClassManifest](
-      vTable: VertexSetRDD[VD],
-      vid2pid: VertexSetRDD[Array[Pid]],
-      replicationMap: RDD[(Pid, VertexIdToIndexMap)]):
-    RDD[(Pid, Array[VD])] = {
-    // Join vid2pid and vTable, generate a shuffle dependency on the joined
-    // result, and get the shuffle id so we can use it on the slave.
-    val msgsByPartition = vTable.zipJoinFlatMap(vid2pid) { (vid, vdata, pids) =>
-      // TODO(rxin): reuse VertexBroadcastMessage
-      pids.iterator.map { pid =>
-        new VertexBroadcastMsg[VD](pid, vid, vdata)
-      }
-    }.partitionBy(replicationMap.partitioner.get).cache()
-
-    replicationMap.zipPartitions(msgsByPartition){
-      (mapIter, msgsIter) =>
-      val (pid, vidToIndex) = mapIter.next()
-      assert(!mapIter.hasNext)
-      // Populate the vertex array using the vidToIndex map
-      val vertexArray = new Array[VD](vidToIndex.capacity)
-      for (msg <- msgsIter) {
-        val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
-        vertexArray(ind) = msg.data
-      }
-      Iterator((pid, vertexArray))
-    }.cache()
-
-    // @todo assert edge table has partitioner
-  }
-
   def makeTriplets[VD: ClassManifest, ED: ClassManifest](
     localVidMap: RDD[(Pid, VertexIdToIndexMap)],
     vTableReplicatedValues: RDD[(Pid, Array[VD]) ],
@@ -444,7 +403,7 @@ object GraphImpl {
   def mapTriplets[VD: ClassManifest, ED: ClassManifest, ED2: ClassManifest](
     g: GraphImpl[VD, ED],
     f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
-    val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValuesBothAttrs){
+    val newETable = g.eTable.zipPartitions(g.localVidMap, g.vTableReplicatedValues.bothAttrs){
       (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
       val (pid, edgePartition) = edgePartitionIter.next()
       val (_, vidToIndex) = vidToIndexIter.next()
@@ -476,15 +435,12 @@ object GraphImpl {
       BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "srcAttr")
     val mapFuncUsesDstAttr =
       BytecodeUtils.invokedMethod(mapFunc, classOf[EdgeTriplet[VD, ED]], "dstAttr")
-    val vTableReplicatedValues = (mapFuncUsesSrcAttr, mapFuncUsesDstAttr) match {
-      case (true, true) => g.vTableReplicatedValuesBothAttrs
-      case (true, false) => g.vTableReplicatedValuesSrcAttrOnly
-      case (false, true) => g.vTableReplicatedValuesDstAttrOnly
-      case (false, false) => g.vTableReplicatedValuesNoAttrs
-    }
 
     // Map and preaggregate
-    val preAgg = g.eTable.zipPartitions(g.localVidMap, vTableReplicatedValues){
+    val preAgg = g.eTable.zipPartitions(
+      g.localVidMap,
+      g.vTableReplicatedValues.get(mapFuncUsesSrcAttr, mapFuncUsesDstAttr)
+    ){
       (edgePartitionIter, vidToIndexIter, vertexArrayIter) =>
       val (_, edgePartition) = edgePartitionIter.next()
       val (_, vidToIndex) = vidToIndexIter.next()
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
new file mode 100644
index 0000000000000000000000000000000000000000..a9ab6255fa3c482c0d43db3b4ad11dfe1ef41554
--- /dev/null
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/VTableReplicatedValues.scala
@@ -0,0 +1,72 @@
+package org.apache.spark.graph.impl
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.collection.OpenHashSet
+
+import org.apache.spark.graph._
+import org.apache.spark.graph.impl.MsgRDDFunctions._
+
+/**
+ * Stores the vertex attribute values after they are replicated. See
+ * the description of localVidMap in [[GraphImpl]].
+ */
+class VTableReplicatedValues[VD: ClassManifest](
+    vTable: VertexSetRDD[VD],
+    vid2pid: Vid2Pid,
+    localVidMap: RDD[(Pid, VertexIdToIndexMap)]) {
+
+  val bothAttrs: RDD[(Pid, Array[VD])] =
+    VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, true, true)
+  val srcAttrOnly: RDD[(Pid, Array[VD])] =
+    VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, true, false)
+  val dstAttrOnly: RDD[(Pid, Array[VD])] =
+    VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, false, true)
+  val noAttrs: RDD[(Pid, Array[VD])] =
+    VTableReplicatedValues.createVTableReplicated(vTable, vid2pid, localVidMap, false, false)
+
+
+  def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): RDD[(Pid, Array[VD])] =
+    (includeSrcAttr, includeDstAttr) match {
+      case (true, true) => bothAttrs
+      case (true, false) => srcAttrOnly
+      case (false, true) => dstAttrOnly
+      case (false, false) => noAttrs
+    }
+}
+
+
+
+object VTableReplicatedValues {
+  protected def createVTableReplicated[VD: ClassManifest](
+      vTable: VertexSetRDD[VD],
+      vid2pid: Vid2Pid,
+      localVidMap: RDD[(Pid, VertexIdToIndexMap)],
+      includeSrcAttr: Boolean,
+      includeDstAttr: Boolean): RDD[(Pid, Array[VD])] = {
+
+    // Join vid2pid and vTable, generate a shuffle dependency on the joined
+    // result, and get the shuffle id so we can use it on the slave.
+    val msgsByPartition = vTable.zipJoinFlatMap(vid2pid.get(includeSrcAttr, includeDstAttr)) {
+      // TODO(rxin): reuse VertexBroadcastMessage
+      (vid, vdata, pids) => pids.iterator.map { pid =>
+        new VertexBroadcastMsg[VD](pid, vid, vdata)
+      }
+    }.partitionBy(localVidMap.partitioner.get).cache()
+
+    localVidMap.zipPartitions(msgsByPartition){
+      (mapIter, msgsIter) =>
+      val (pid, vidToIndex) = mapIter.next()
+      assert(!mapIter.hasNext)
+      // Populate the vertex array using the vidToIndex map
+      val vertexArray = new Array[VD](vidToIndex.capacity)
+      for (msg <- msgsIter) {
+        val ind = vidToIndex.getPos(msg.vid) & OpenHashSet.POSITION_MASK
+        vertexArray(ind) = msg.data
+      }
+      Iterator((pid, vertexArray))
+    }.cache()
+
+    // @todo assert edge table has partitioner
+  }
+
+}
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
index d8c8d35ee10c05818c75125fbde44a60565394e4..9bdca7f40763092e60838e8e22cb2aa7b1fc630e 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/Vid2Pid.scala
@@ -3,12 +3,13 @@ package org.apache.spark.graph.impl
 import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.graph._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
+import org.apache.spark.graph._
+
 /**
- * Stores the layout of vertex attributes.
+ * Stores the layout of vertex attributes for GraphImpl.
  */
 class Vid2Pid(
     eTable: RDD[(Pid, EdgePartition[ED])] forSome { type ED },
@@ -17,9 +18,16 @@ class Vid2Pid(
   val bothAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(true, true)
   val srcAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(true, false)
   val dstAttrOnly: VertexSetRDD[Array[Pid]] = createVid2Pid(false, true)
-  // TODO(ankurdave): create this more efficiently
   val noAttrs: VertexSetRDD[Array[Pid]] = createVid2Pid(false, false)
 
+  def get(includeSrcAttr: Boolean, includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] =
+    (includeSrcAttr, includeDstAttr) match {
+      case (true, true) => bothAttrs
+      case (true, false) => srcAttrOnly
+      case (false, true) => dstAttrOnly
+      case (false, false) => noAttrs
+    }
+
   def persist(newLevel: StorageLevel) {
     bothAttrs.persist(newLevel)
     srcAttrOnly.persist(newLevel)
@@ -28,15 +36,17 @@ class Vid2Pid(
   }
 
   private def createVid2Pid(
-    includeSrcAttr: Boolean,
-    includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] = {
+      includeSrcAttr: Boolean,
+      includeDstAttr: Boolean): VertexSetRDD[Array[Pid]] = {
     val preAgg = eTable.mapPartitions { iter =>
       val (pid, edgePartition) = iter.next()
       val vSet = new VertexSet
-      edgePartition.foreach(e => {
-        if (includeSrcAttr) vSet.add(e.srcId)
-        if (includeDstAttr) vSet.add(e.dstId)
-      })
+      if (includeSrcAttr || includeDstAttr) {
+        edgePartition.foreach(e => {
+          if (includeSrcAttr) vSet.add(e.srcId)
+          if (includeDstAttr) vSet.add(e.dstId)
+        })
+      }
       vSet.iterator.map { vid => (vid.toLong, pid) }
     }
     VertexSetRDD[Pid, ArrayBuffer[Pid]](preAgg, vTableIndex,