diff --git a/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala b/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala
index 9aa76c93945f93e4491bc5812f8ec73ada029962..3dda5c7c604e534df93e8860238d343f128830bb 100644
--- a/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/EdgeRDD.scala
@@ -42,32 +42,32 @@ class EdgeRDD[@specialized ED: ClassManifest](
   /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
   override def cache(): EdgeRDD[ED] = persist()
 
-  def mapEdgePartitions[ED2: ClassManifest](f: EdgePartition[ED] => EdgePartition[ED2])
+  def mapEdgePartitions[ED2: ClassManifest](f: (Pid, EdgePartition[ED]) => EdgePartition[ED2])
     : EdgeRDD[ED2] = {
 //       iter => iter.map { case (pid, ep) => (pid, f(ep)) }
     new EdgeRDD[ED2](partitionsRDD.mapPartitions({ iter =>
       val (pid, ep) = iter.next()
-      Iterator(Tuple2(pid, f(ep)))
+      Iterator(Tuple2(pid, f(pid, ep)))
     }, preservesPartitioning = true))
   }
 
   def zipEdgePartitions[T: ClassManifest, U: ClassManifest]
       (other: RDD[T])
-      (f: (EdgePartition[ED], Iterator[T]) => Iterator[U]): RDD[U] = {
+      (f: (Pid, EdgePartition[ED], Iterator[T]) => Iterator[U]): RDD[U] = {
     partitionsRDD.zipPartitions(other, preservesPartitioning = true) { (ePartIter, otherIter) =>
-      val (_, edgePartition) = ePartIter.next()
-      f(edgePartition, otherIter)
+      val (pid, edgePartition) = ePartIter.next()
+      f(pid, edgePartition, otherIter)
     }
   }
 
   def zipEdgePartitions[ED2: ClassManifest, ED3: ClassManifest]
       (other: EdgeRDD[ED2])
-      (f: (EdgePartition[ED], EdgePartition[ED2]) => EdgePartition[ED3]): EdgeRDD[ED3] = {
+      (f: (Pid, EdgePartition[ED], EdgePartition[ED2]) => EdgePartition[ED3]): EdgeRDD[ED3] = {
     new EdgeRDD[ED3](partitionsRDD.zipPartitions(other.partitionsRDD, preservesPartitioning = true) {
       (thisIter, otherIter) =>
         val (pid, thisEPart) = thisIter.next()
         val (_, otherEPart) = otherIter.next()
-      Iterator(Tuple2(pid, f(thisEPart, otherEPart)))
+      Iterator(Tuple2(pid, f(pid, thisEPart, otherEPart)))
     })
   }
 
@@ -76,7 +76,7 @@ class EdgeRDD[@specialized ED: ClassManifest](
       (f: (Vid, Vid, ED, ED2) => ED3): EdgeRDD[ED3] = {
     val ed2Manifest = classManifest[ED2]
     val ed3Manifest = classManifest[ED3]
-    zipEdgePartitions(other) { (thisEPart, otherEPart) =>
+    zipEdgePartitions(other) { (pid, thisEPart, otherEPart) =>
       thisEPart.innerJoin(otherEPart)(f)(ed2Manifest, ed3Manifest)
     }
   }
diff --git a/graph/src/main/scala/org/apache/spark/graph/Graph.scala b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
index e8fa8e611c9cb0b43a2c0736fd5050195b12d5dc..b725b2a15584bcbde21f1b81cdce037cee2492af 100644
--- a/graph/src/main/scala/org/apache/spark/graph/Graph.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/Graph.scala
@@ -72,9 +72,15 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    */
   val triplets: RDD[EdgeTriplet[VD, ED]]
 
+  /**
+   * Cache the vertices and edges associated with this graph.
+   *
+   * @param newLevel the level at which to cache the graph.
 
-
-  def persist(newLevel: StorageLevel): Graph[VD, ED]
+   * @return A reference to this graph for convenience.
+   *
+   */
+  def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED]
 
   /**
    * Return a graph that is cached when first created. This is used to
@@ -120,7 +126,7 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
   def mapVertices[VD2: ClassManifest](map: (Vid, VD) => VD2): Graph[VD2, ED]
 
   /**
-   * Construct a new graph where each the value of each edge is
+   * Construct a new graph where the value of each edge is
    * transformed by the map operation.  This function is not passed
    * the vertex value for the vertices adjacent to the edge.  If
    * vertex values are desired use the mapTriplets function.
@@ -137,18 +143,44 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * attributes.
    *
    */
-  def mapEdges[ED2: ClassManifest](map: Edge[ED] => ED2): Graph[VD, ED2]
+  def mapEdges[ED2: ClassManifest](map: Edge[ED] => ED2): Graph[VD, ED2] = {
+    mapEdges((pid, iter) => iter.map(map))
+  }
+
+  /**
+   * Construct a new graph transforming the value of each edge using
+   * the user defined iterator transform.  The iterator transform is
+   * given an iterator over edge triplets within a logical partition
+   * and should yield a new iterator over the new values of each edge
+   * in the order in which they are provided to the iterator transform
+   * If adjacent vertex values are not required, consider using the
+   * mapEdges function instead.
+   *
+   * @note This that this does not change the structure of the
+   * graph or modify the values of this graph.  As a consequence
+   * the underlying index structures can be reused.
+   *
+   * @param map the function which takes a partition id and an iterator
+   * over all the edges in the partition and must return an iterator over
+   * the new values for each edge in the order of the input iterator.
+   *
+   * @tparam ED2 the new edge data type
+   *
+   */
+  def mapEdges[ED2: ClassManifest](
+      map: (Pid, Iterator[Edge[ED]]) => Iterator[ED2]):
+    Graph[VD, ED2]
 
   /**
-   * Construct a new graph where each the value of each edge is
+   * Construct a new graph where the value of each edge is
    * transformed by the map operation.  This function passes vertex
    * values for the adjacent vertices to the map function.  If
    * adjacent vertex values are not required, consider using the
    * mapEdges function instead.
    *
-   * @note This graph is not changed and that the new graph has the
-   * same structure.  As a consequence the underlying index structures
-   * can be reused.
+   * @note This that this does not change the structure of the
+   * graph or modify the values of this graph.  As a consequence
+   * the underlying index structures can be reused.
    *
    * @param map the function from an edge object to a new edge value.
    *
@@ -163,7 +195,33 @@ abstract class Graph[VD: ClassManifest, ED: ClassManifest] {
    * }}}
    *
    */
-  def mapTriplets[ED2: ClassManifest](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2]
+  def mapTriplets[ED2: ClassManifest](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+    mapTriplets((pid, iter) => iter.map(map))
+  }
+
+  /**
+   * Construct a new graph transforming the value of each edge using
+   * the user defined iterator transform.  The iterator transform is
+   * given an iterator over edge triplets within a logical partition
+   * and should yield a new iterator over the new values of each edge
+   * in the order in which they are provided to the iterator transform
+   * If adjacent vertex values are not required, consider using the
+   * mapEdges function instead.
+   *
+   * @note This that this does not change the structure of the
+   * graph or modify the values of this graph.  As a consequence
+   * the underlying index structures can be reused.
+   *
+   * @param map the function which takes a partition id and an iterator
+   * over all the edges in the partition and must return an iterator over
+   * the new values for each edge in the order of the input iterator.
+   *
+   * @tparam ED2 the new edge data type
+   *
+   */
+  def mapTriplets[ED2: ClassManifest](
+      map: (Pid, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]):
+    Graph[VD, ED2]
 
   /**
    * Construct a new graph with all the edges reversed.  If this graph
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala b/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
index e97522feaeaf2381a02ec9cdf5b152f4f44181c6..4fcf08efce38293d0380445c649175d4b47c89a1 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/EdgePartition.scala
@@ -56,6 +56,30 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double)
     new EdgePartition(srcIds, dstIds, newData, index)
   }
 
+  /**
+   * Construct a new edge partition by using the edge attributes
+   * contained in the iterator.
+   *
+   * @note The input iterator should return edge attributes in the
+   * order of the edges returned by `EdgePartition.iterator` and
+   * should return attributes equal to the number of edges.
+   *
+   * @param f a function from an edge to a new attribute
+   * @tparam ED2 the type of the new attribute
+   * @return a new edge partition with the result of the function `f`
+   *         applied to each edge
+   */
+  def map[ED2: ClassManifest](iter: Iterator[ED2]): EdgePartition[ED2] = {
+    val newData = new Array[ED2](data.size)
+    var i = 0
+    while (iter.hasNext) {
+      newData(i) = iter.next()
+      i += 1
+    }
+    assert(newData.size == i)
+    new EdgePartition(srcIds, dstIds, newData, index)
+  }
+
   /**
    * Apply the function f to all edges in this partition.
    *
diff --git a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
index 16d73820f07df560163db2bb2c48cf83e57291b2..79c11c780a69a97b3873c9036f30c9be5d756c22 100644
--- a/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
+++ b/graph/src/main/scala/org/apache/spark/graph/impl/GraphImpl.scala
@@ -47,8 +47,7 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
   @transient override val triplets: RDD[EdgeTriplet[VD, ED]] = {
     val vdManifest = classManifest[VD]
     val edManifest = classManifest[ED]
-
-    edges.zipEdgePartitions(replicatedVertexView.get(true, true)) { (ePart, vPartIter) =>
+    edges.zipEdgePartitions(replicatedVertexView.get(true, true)) { (pid, ePart, vPartIter) =>
       val (_, vPart) = vPartIter.next()
       new EdgeTripletIterator(vPart.index, vPart.values, ePart)(vdManifest, edManifest)
     }
@@ -149,8 +148,10 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     println(visited)
   } // end of printLineage
 
-  override def reverse: Graph[VD, ED] =
-    new GraphImpl(vertices, edges.mapEdgePartitions(_.reverse), routingTable, replicatedVertexView)
+  override def reverse: Graph[VD, ED] = {
+    val newETable = edges.mapEdgePartitions((pid, part) => part.reverse)
+    new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
+  }
 
   override def mapVertices[VD2: ClassManifest](f: (Vid, VD) => VD2): Graph[VD2, ED] = {
     if (classManifest[VD] equals classManifest[VD2]) {
@@ -167,25 +168,36 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     }
   }
 
-  override def mapEdges[ED2: ClassManifest](f: Edge[ED] => ED2): Graph[VD, ED2] =
-    new GraphImpl(vertices, edges.mapEdgePartitions(_.map(f)), routingTable, replicatedVertexView)
+  override def mapEdges[ED2: ClassManifest](
+      f: (Pid, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+    val newETable = edges.mapEdgePartitions((pid, part) => part.map(f(pid, part.iterator)))
+    new GraphImpl(vertices, newETable , routingTable, replicatedVertexView)
+  }
 
-  override def mapTriplets[ED2: ClassManifest](f: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+  override def mapTriplets[ED2: ClassManifest](
+      f: (Pid, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
     // Use an explicit manifest in PrimitiveKeyOpenHashMap init so we don't pull in the implicit
     // manifest from GraphImpl (which would require serializing GraphImpl).
     val vdManifest = classManifest[VD]
     val newEdgePartitions =
-      edges.zipEdgePartitions(replicatedVertexView.get(true, true)) { (edgePartition, vPartIter) =>
-        val (pid, vPart) = vPartIter.next()
+      edges.zipEdgePartitions(replicatedVertexView.get(true, true)) {
+        (ePid, edgePartition, vTableReplicatedIter) =>
+        val (vPid, vPart) = vTableReplicatedIter.next()
+        assert(!vTableReplicatedIter.hasNext)
+        assert(ePid == vPid)
         val et = new EdgeTriplet[VD, ED]
-        val newEdgePartition = edgePartition.map { e =>
+        val inputIterator = edgePartition.iterator.map { e =>
           et.set(e)
           et.srcAttr = vPart(e.srcId)
           et.dstAttr = vPart(e.dstId)
-          f(et)
+          et
         }
-        Iterator((pid, newEdgePartition))
-    }
+        // Apply the user function to the vertex partition
+        val outputIter = f(ePid, inputIterator)
+        // Consume the iterator to update the edge attributes
+        val newEdgePartition = edgePartition.map(outputIter)
+        Iterator((ePid, newEdgePartition))
+      }
     new GraphImpl(vertices, new EdgeRDD(newEdgePartitions), routingTable, replicatedVertexView)
   }
 
@@ -224,8 +236,8 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
 
   override def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] = {
     ClosureCleaner.clean(merge)
-    val newEdges = edges.mapEdgePartitions(_.groupEdges(merge))
-    new GraphImpl(vertices, newEdges, routingTable, replicatedVertexView)
+    val newETable = edges.mapEdgePartitions((pid, part) => part.groupEdges(merge))
+    new GraphImpl(vertices, newETable, routingTable, replicatedVertexView)
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////////
@@ -253,9 +265,10 @@ class GraphImpl[VD: ClassManifest, ED: ClassManifest] protected (
     val activeDirectionOpt = activeSetOpt.map(_._2)
 
     // Map and combine.
-    val preAgg = edges.zipEdgePartitions(vs) { (edgePartition, vPartIter) =>
-      val (_, vPart) = vPartIter.next()
-
+    val preAgg = edges.zipEdgePartitions(vs) { (ePid, edgePartition, vPartIter) =>
+      val (vPid, vPart) = vPartIter.next()
+      assert(!vPartIter.hasNext)
+      assert(ePid == vPid)
       // Choose scan method
       val activeFraction = vPart.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
       val edgeIter = activeDirectionOpt match {