diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
new file mode 100644
index 0000000000000000000000000000000000000000..f70715fca6eea564f1512a9e675f5325551f19ec
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+/**
+ * Represents an edge along with its neighboring vertices and allows sending messages along the
+ * edge. Used in [[Graph#aggregateMessages]].
+ */
+abstract class EdgeContext[VD, ED, A] {
+  /** The vertex id of the edge's source vertex. */
+  def srcId: VertexId
+  /** The vertex id of the edge's destination vertex. */
+  def dstId: VertexId
+  /** The vertex attribute of the edge's source vertex. */
+  def srcAttr: VD
+  /** The vertex attribute of the edge's destination vertex. */
+  def dstAttr: VD
+  /** The attribute associated with the edge. */
+  def attr: ED
+
+  /** Sends a message to the source vertex. */
+  def sendToSrc(msg: A): Unit
+  /** Sends a message to the destination vertex. */
+  def sendToDst(msg: A): Unit
+
+  /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */
+  def toEdgeTriplet: EdgeTriplet[VD, ED] = {
+    val et = new EdgeTriplet[VD, ED]
+    et.srcId = srcId
+    et.srcAttr = srcAttr
+    et.dstId = dstId
+    et.dstAttr = dstAttr
+    et.attr = attr
+    et
+  }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
index fa4b891754c40fa053239dd759953b79778dfdf9..e0ba9403ba75b8103e64f8134c58154ca8479ac0 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala
@@ -207,8 +207,39 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * }}}
    *
    */
-  def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
-    mapTriplets((pid, iter) => iter.map(map))
+  def mapTriplets[ED2: ClassTag](
+      map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = {
+    mapTriplets((pid, iter) => iter.map(map), TripletFields.All)
+  }
+
+  /**
+   * Transforms each edge attribute using the map function, passing it the adjacent vertex
+   * attributes as well. If adjacent vertex values are not required,
+   * consider using `mapEdges` instead.
+   *
+   * @note This does not change the structure of the
+   * graph or modify the values of this graph.  As a consequence
+   * the underlying index structures can be reused.
+   *
+   * @param map the function from an edge object to a new edge value.
+   * @param tripletFields which fields should be included in the edge triplet passed to the map
+   *   function. If not all fields are needed, specifying this can improve performance.
+   *
+   * @tparam ED2 the new edge data type
+   *
+   * @example This function might be used to initialize edge
+   * attributes based on the attributes associated with each vertex.
+   * {{{
+   * val rawGraph: Graph[Int, Int] = someLoadFunction()
+   * val graph = rawGraph.mapTriplets[Int]( edge =>
+   *   edge.src.data - edge.dst.data)
+   * }}}
+   *
+   */
+  def mapTriplets[ED2: ClassTag](
+      map: EdgeTriplet[VD, ED] => ED2,
+      tripletFields: TripletFields): Graph[VD, ED2] = {
+    mapTriplets((pid, iter) => iter.map(map), tripletFields)
   }
 
   /**
@@ -223,12 +254,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * the underlying index structures can be reused.
    *
    * @param map the iterator transform
+   * @param tripletFields which fields should be included in the edge triplet passed to the map
+   *   function. If not all fields are needed, specifying this can improve performance.
    *
    * @tparam ED2 the new edge data type
    *
    */
-  def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2])
-    : Graph[VD, ED2]
+  def mapTriplets[ED2: ClassTag](
+      map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
+      tripletFields: TripletFields): Graph[VD, ED2]
 
   /**
    * Reverses all edges in the graph.  If this graph contains an edge from a to b then the returned
@@ -287,6 +321,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * "sent" to either vertex in the edge.  The `reduceFunc` is then used to combine the output of
    * the map phase destined to each vertex.
    *
+   * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead.
+   *
    * @tparam A the type of "message" to be sent to each vertex
    *
    * @param mapFunc the user defined map function which returns 0 or
@@ -296,13 +332,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * be commutative and associative and is used to combine the output
    * of the map phase
    *
-   * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to
-   * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on
-   * edges with destination in the active set.  If the direction is `Out`,
-   * `mapFunc` will only be run on edges originating from vertices in the active set. If the
-   * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set
-   * . If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the
-   * active set. The active set must have the same index as the graph's vertices.
+   * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+   * desired. This is done by specifying a set of "active" vertices and an edge direction. The
+   * `sendMsg` function will then run only on edges connected to active vertices by edges in the
+   * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+   * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
+   * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
+   * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
+   * will be run on edges with *both* vertices in the active set. The active set must have the
+   * same index as the graph's vertices.
    *
    * @example We can use this function to compute the in-degree of each
    * vertex
@@ -319,6 +357,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    * predicate or implement PageRank.
    *
    */
+  @deprecated("use aggregateMessages", "1.2.0")
   def mapReduceTriplets[A: ClassTag](
       mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       reduceFunc: (A, A) => A,
@@ -326,8 +365,80 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
     : VertexRDD[A]
 
   /**
-   * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.  The
-   * input table should contain at most one entry for each vertex.  If no entry in `other` is
+   * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
+   * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
+   * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
+   * destined to the same vertex.
+   *
+   * @tparam A the type of message to be sent to each vertex
+   *
+   * @param sendMsg runs on each edge, sending messages to neighboring vertices using the
+   *   [[EdgeContext]].
+   * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
+   *   combiner should be commutative and associative.
+   * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
+   *   `sendMsg` function. If not all fields are needed, specifying this can improve performance.
+   *
+   * @example We can use this function to compute the in-degree of each
+   * vertex
+   * {{{
+   * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph")
+   * val inDeg: RDD[(VertexId, Int)] =
+   *   aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _)
+   * }}}
+   *
+   * @note By expressing computation at the edge level we achieve
+   * maximum parallelism.  This is one of the core functions in the
+   * Graph API in that enables neighborhood level computation. For
+   * example this function can be used to count neighbors satisfying a
+   * predicate or implement PageRank.
+   *
+   */
+  def aggregateMessages[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields = TripletFields.All)
+    : VertexRDD[A] = {
+    aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None)
+  }
+
+  /**
+   * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied
+   * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be
+   * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages
+   * destined to the same vertex.
+   *
+   * This variant can take an active set to restrict the computation and is intended for internal
+   * use only.
+   *
+   * @tparam A the type of message to be sent to each vertex
+   *
+   * @param sendMsg runs on each edge, sending messages to neighboring vertices using the
+   *   [[EdgeContext]].
+   * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This
+   *   combiner should be commutative and associative.
+   * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the
+   *   `sendMsg` function. If not all fields are needed, specifying this can improve performance.
+   * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if
+   *   desired. This is done by specifying a set of "active" vertices and an edge direction. The
+   *   `sendMsg` function will then run on only edges connected to active vertices by edges in the
+   *   specified direction. If the direction is `In`, `sendMsg` will only be run on edges with
+   *   destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges
+   *   originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be
+   *   run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg`
+   *   will be run on edges with *both* vertices in the active set. The active set must have the
+   *   same index as the graph's vertices.
+   */
+  private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields,
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)])
+    : VertexRDD[A]
+
+  /**
+   * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`.
+   * The input table should contain at most one entry for each vertex.  If no entry in `other` is
    * provided for a particular vertex in the graph, the map function receives `None`.
    *
    * @tparam U the type of entry in the table of updates
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index d0dd45dba618ee155995dae22964807abc132063..d5150382d599bed70b47e9700b9f418ee89f3dba 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
    */
   private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = {
     if (edgeDirection == EdgeDirection.In) {
-      graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _)
+      graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None)
     } else if (edgeDirection == EdgeDirection.Out) {
-      graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _)
+      graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None)
     } else { // EdgeDirection.Either
-      graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _)
+      graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _,
+        TripletFields.None)
     }
   }
 
@@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
   def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = {
     val nbrs =
       if (edgeDirection == EdgeDirection.Either) {
-        graph.mapReduceTriplets[Array[VertexId]](
-          mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))),
-          reduceFunc = _ ++ _
-        )
+        graph.aggregateMessages[Array[VertexId]](
+          ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) },
+          _ ++ _, TripletFields.None)
       } else if (edgeDirection == EdgeDirection.Out) {
-        graph.mapReduceTriplets[Array[VertexId]](
-          mapFunc = et => Iterator((et.srcId, Array(et.dstId))),
-          reduceFunc = _ ++ _)
+        graph.aggregateMessages[Array[VertexId]](
+          ctx => ctx.sendToSrc(Array(ctx.dstId)),
+          _ ++ _, TripletFields.None)
       } else if (edgeDirection == EdgeDirection.In) {
-        graph.mapReduceTriplets[Array[VertexId]](
-          mapFunc = et => Iterator((et.dstId, Array(et.srcId))),
-          reduceFunc = _ ++ _)
+        graph.aggregateMessages[Array[VertexId]](
+          ctx => ctx.sendToDst(Array(ctx.srcId)),
+          _ ++ _, TripletFields.None)
       } else {
         throw new SparkException("It doesn't make sense to collect neighbor ids without a " +
           "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)")
@@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
    * @return the vertex set of neighboring vertex attributes for each vertex
    */
   def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = {
-    val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]](
-      edge => {
-        val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr)))
-        val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr)))
-        edgeDirection match {
-          case EdgeDirection.Either => Iterator(msgToSrc, msgToDst)
-          case EdgeDirection.In => Iterator(msgToDst)
-          case EdgeDirection.Out => Iterator(msgToSrc)
-          case EdgeDirection.Both =>
-            throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" +
-              "EdgeDirection.Either instead.")
-        }
-      },
-      (a, b) => a ++ b)
-
-    graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) =>
+    val nbrs = edgeDirection match {
+      case EdgeDirection.Either =>
+        graph.aggregateMessages[Array[(VertexId,VD)]](
+          ctx => {
+            ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr)))
+            ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr)))
+          },
+          (a, b) => a ++ b, TripletFields.SrcDstOnly)
+      case EdgeDirection.In =>
+        graph.aggregateMessages[Array[(VertexId,VD)]](
+          ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))),
+          (a, b) => a ++ b, TripletFields.SrcOnly)
+      case EdgeDirection.Out =>
+        graph.aggregateMessages[Array[(VertexId,VD)]](
+          ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))),
+          (a, b) => a ++ b, TripletFields.DstOnly)
+      case EdgeDirection.Both =>
+        throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
+          "EdgeDirection.Either instead.")
+    }
+    graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) =>
       nbrsOpt.getOrElse(Array.empty[(VertexId, VD)])
     }
   } // end of collectNeighbor
@@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
   def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = {
     edgeDirection match {
       case EdgeDirection.Either =>
-        graph.mapReduceTriplets[Array[Edge[ED]]](
-          edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))),
-                           (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
-          (a, b) => a ++ b)
+        graph.aggregateMessages[Array[Edge[ED]]](
+          ctx => {
+            ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
+            ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr)))
+          },
+          (a, b) => a ++ b, TripletFields.EdgeOnly)
       case EdgeDirection.In =>
-        graph.mapReduceTriplets[Array[Edge[ED]]](
-          edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
-          (a, b) => a ++ b)
+        graph.aggregateMessages[Array[Edge[ED]]](
+          ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
+          (a, b) => a ++ b, TripletFields.EdgeOnly)
       case EdgeDirection.Out =>
-        graph.mapReduceTriplets[Array[Edge[ED]]](
-          edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))),
-          (a, b) => a ++ b)
+        graph.aggregateMessages[Array[Edge[ED]]](
+          ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))),
+          (a, b) => a ++ b, TripletFields.EdgeOnly)
       case EdgeDirection.Both =>
         throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" +
           "EdgeDirection.Either instead.")
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java
new file mode 100644
index 0000000000000000000000000000000000000000..34df4b7ee7a0615ff01e33d32b7d5570b7c6c46c
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx;
+
+import java.io.Serializable;
+
+/**
+ * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the
+ * system to populate only those fields for efficiency.
+ */
+public class TripletFields implements Serializable {
+  public final boolean useSrc;
+  public final boolean useDst;
+  public final boolean useEdge;
+
+  public TripletFields() {
+    this(true, true, true);
+  }
+
+  public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) {
+    this.useSrc = useSrc;
+    this.useDst = useDst;
+    this.useEdge = useEdge;
+  }
+
+  public static final TripletFields None = new TripletFields(false, false, false);
+  public static final TripletFields EdgeOnly = new TripletFields(false, false, true);
+  public static final TripletFields SrcOnly = new TripletFields(true, false, false);
+  public static final TripletFields DstOnly = new TripletFields(false, true, false);
+  public static final TripletFields SrcDstOnly = new TripletFields(true, true, false);
+  public static final TripletFields SrcAndEdge = new TripletFields(true, false, true);
+  public static final TripletFields Src = SrcAndEdge;
+  public static final TripletFields DstAndEdge = new TripletFields(false, true, true);
+  public static final TripletFields Dst = DstAndEdge;
+  public static final TripletFields All = new TripletFields(true, true, true);
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
index a5c9cd1f8b4e632f035a1cab909ddfd5cc11d80b..78d8ac24b5271b9bc9102d2c16dc830c204817e6 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala
@@ -21,63 +21,93 @@ import scala.reflect.{classTag, ClassTag}
 
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+import org.apache.spark.util.collection.BitSet
 
 /**
- * A collection of edges stored in columnar format, along with any vertex attributes referenced. The
- * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by
- * src. There is an optional active vertex set for filtering computation on the edges.
+ * A collection of edges, along with referenced vertex attributes and an optional active vertex set
+ * for filtering computation on the edges.
+ *
+ * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All
+ * referenced global vertex ids are mapped to a compact set of local vertex ids according to the
+ * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the
+ * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global
+ * vertex id. The global vertex ids that are active are optionally stored in `activeSet`.
+ *
+ * The edges are clustered by source vertex id, and the mapping from global vertex id to the index
+ * of the corresponding edge cluster is stored in `index`.
  *
  * @tparam ED the edge attribute type
  * @tparam VD the vertex attribute type
  *
- * @param srcIds the source vertex id of each edge
- * @param dstIds the destination vertex id of each edge
+ * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and
+ *   `vertexAttrs`
+ * @param localDstIds the local destination vertex id of each edge as an index into `local2global`
+ *   and `vertexAttrs`
  * @param data the attribute associated with each edge
- * @param index a clustered index on source vertex id
- * @param vertices a map from referenced vertex ids to their corresponding attributes. Must
- *   contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for
- *   those vertex ids. The mask is not used.
+ * @param index a clustered index on source vertex id as a map from each global source vertex id to
+ *   the offset in the edge arrays where the cluster for that vertex id begins
+ * @param global2local a map from referenced vertex ids to local ids which index into vertexAttrs
+ * @param local2global an array of global vertex ids where the offsets are local vertex ids
+ * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids
  * @param activeSet an optional active vertex set for filtering computation on the edges
  */
 private[graphx]
 class EdgePartition[
     @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag](
-    val srcIds: Array[VertexId] = null,
-    val dstIds: Array[VertexId] = null,
-    val data: Array[ED] = null,
-    val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null,
-    val vertices: VertexPartition[VD] = null,
-    val activeSet: Option[VertexSet] = None
-  ) extends Serializable {
+    localSrcIds: Array[Int],
+    localDstIds: Array[Int],
+    data: Array[ED],
+    index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+    global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+    local2global: Array[VertexId],
+    vertexAttrs: Array[VD],
+    activeSet: Option[VertexSet])
+  extends Serializable {
 
-  /** Return a new `EdgePartition` with the specified edge data. */
-  def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = {
-    new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet)
-  }
+  private def this() = this(null, null, null, null, null, null, null, null)
 
-  /** Return a new `EdgePartition` with the specified vertex partition. */
-  def withVertices[VD2: ClassTag](
-      vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = {
-    new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet)
+  /** Return a new `EdgePartition` with the specified edge data. */
+  def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = {
+    new EdgePartition(
+      localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
   }
 
   /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */
   def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = {
-    val newActiveSet = new VertexSet
-    iter.foreach(newActiveSet.add(_))
-    new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet))
-  }
-
-  /** Return a new `EdgePartition` with the specified active set. */
-  def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = {
-    new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_)
+    val activeSet = new VertexSet
+    while (iter.hasNext) { activeSet.add(iter.next()) }
+    new EdgePartition(
+      localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs,
+      Some(activeSet))
   }
 
   /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */
   def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = {
-    this.withVertices(vertices.innerJoinKeepLeft(iter))
+    val newVertexAttrs = new Array[VD](vertexAttrs.length)
+    System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length)
+    while (iter.hasNext) {
+      val kv = iter.next()
+      newVertexAttrs(global2local(kv._1)) = kv._2
+    }
+    new EdgePartition(
+      localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
+      activeSet)
+  }
+
+  /** Return a new `EdgePartition` without any locally cached vertex attributes. */
+  def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = {
+    val newVertexAttrs = new Array[VD2](vertexAttrs.length)
+    new EdgePartition(
+      localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs,
+      activeSet)
   }
 
+  @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos))
+
+  @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos))
+
+  @inline private def attrs(pos: Int): ED = data(pos)
+
   /** Look up vid in activeSet, throwing an exception if it is None. */
   def isActive(vid: VertexId): Boolean = {
     activeSet.get.contains(vid)
@@ -92,11 +122,19 @@ class EdgePartition[
    * @return a new edge partition with all edges reversed.
    */
   def reverse: EdgePartition[ED, VD] = {
-    val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD])
-    for (e <- iterator) {
-      builder.add(e.dstId, e.srcId, e.attr)
+    val builder = new ExistingEdgePartitionBuilder[ED, VD](
+      global2local, local2global, vertexAttrs, activeSet, size)
+    var i = 0
+    while (i < size) {
+      val localSrcId = localSrcIds(i)
+      val localDstId = localDstIds(i)
+      val srcId = local2global(localSrcId)
+      val dstId = local2global(localDstId)
+      val attr = data(i)
+      builder.add(dstId, srcId, localDstId, localSrcId, attr)
+      i += 1
     }
-    builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+    builder.toEdgePartition
   }
 
   /**
@@ -157,13 +195,25 @@ class EdgePartition[
   def filter(
       epred: EdgeTriplet[VD, ED] => Boolean,
       vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = {
-    val filtered = tripletIterator().filter(et =>
-      vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et))
-    val builder = new EdgePartitionBuilder[ED, VD]
-    for (e <- filtered) {
-      builder.add(e.srcId, e.dstId, e.attr)
+    val builder = new ExistingEdgePartitionBuilder[ED, VD](
+      global2local, local2global, vertexAttrs, activeSet)
+    var i = 0
+    while (i < size) {
+      // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge.
+      val localSrcId = localSrcIds(i)
+      val localDstId = localDstIds(i)
+      val et = new EdgeTriplet[VD, ED]
+      et.srcId = local2global(localSrcId)
+      et.dstId = local2global(localDstId)
+      et.srcAttr = vertexAttrs(localSrcId)
+      et.dstAttr = vertexAttrs(localDstId)
+      et.attr = data(i)
+      if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) {
+        builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr)
+      }
+      i += 1
     }
-    builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+    builder.toEdgePartition
   }
 
   /**
@@ -183,28 +233,40 @@ class EdgePartition[
    * @return a new edge partition without duplicate edges
    */
   def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = {
-    val builder = new EdgePartitionBuilder[ED, VD]
+    val builder = new ExistingEdgePartitionBuilder[ED, VD](
+      global2local, local2global, vertexAttrs, activeSet)
     var currSrcId: VertexId = null.asInstanceOf[VertexId]
     var currDstId: VertexId = null.asInstanceOf[VertexId]
+    var currLocalSrcId = -1
+    var currLocalDstId = -1
     var currAttr: ED = null.asInstanceOf[ED]
+    // Iterate through the edges, accumulating runs of identical edges using the curr* variables and
+    // releasing them to the builder when we see the beginning of the next run
     var i = 0
     while (i < size) {
       if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) {
+        // This edge should be accumulated into the existing run
         currAttr = merge(currAttr, data(i))
       } else {
+        // This edge starts a new run of edges
         if (i > 0) {
-          builder.add(currSrcId, currDstId, currAttr)
+          // First release the existing run to the builder
+          builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr)
         }
+        // Then start accumulating for a new run
         currSrcId = srcIds(i)
         currDstId = dstIds(i)
+        currLocalSrcId = localSrcIds(i)
+        currLocalDstId = localDstIds(i)
         currAttr = data(i)
       }
       i += 1
     }
+    // Finally, release the last accumulated run
     if (size > 0) {
-      builder.add(currSrcId, currDstId, currAttr)
+      builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr)
     }
-    builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+    builder.toEdgePartition
   }
 
   /**
@@ -220,7 +282,8 @@ class EdgePartition[
   def innerJoin[ED2: ClassTag, ED3: ClassTag]
       (other: EdgePartition[ED2, _])
       (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = {
-    val builder = new EdgePartitionBuilder[ED3, VD]
+    val builder = new ExistingEdgePartitionBuilder[ED3, VD](
+      global2local, local2global, vertexAttrs, activeSet)
     var i = 0
     var j = 0
     // For i = index of each edge in `this`...
@@ -233,12 +296,13 @@ class EdgePartition[
         while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 }
         if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) {
           // ... run `f` on the matching edge
-          builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j)))
+          builder.add(srcId, dstId, localSrcIds(i), localDstIds(i),
+            f(srcId, dstId, this.data(i), other.attrs(j)))
         }
       }
       i += 1
     }
-    builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet)
+    builder.toEdgePartition
   }
 
   /**
@@ -246,7 +310,7 @@ class EdgePartition[
    *
    * @return size of the partition
    */
-  val size: Int = srcIds.size
+  val size: Int = localSrcIds.size
 
   /** The number of unique source vertices in the partition. */
   def indexSize: Int = index.size
@@ -280,55 +344,197 @@ class EdgePartition[
    * It is safe to keep references to the objects from this iterator.
    */
   def tripletIterator(
-      includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = {
-    new EdgeTripletIterator(this, includeSrc, includeDst)
+      includeSrc: Boolean = true, includeDst: Boolean = true)
+      : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] {
+    private[this] var pos = 0
+
+    override def hasNext: Boolean = pos < EdgePartition.this.size
+
+    override def next() = {
+      val triplet = new EdgeTriplet[VD, ED]
+      val localSrcId = localSrcIds(pos)
+      val localDstId = localDstIds(pos)
+      triplet.srcId = local2global(localSrcId)
+      triplet.dstId = local2global(localDstId)
+      if (includeSrc) {
+        triplet.srcAttr = vertexAttrs(localSrcId)
+      }
+      if (includeDst) {
+        triplet.dstAttr = vertexAttrs(localDstId)
+      }
+      triplet.attr = data(pos)
+      pos += 1
+      triplet
+    }
   }
 
   /**
-   * Upgrade the given edge iterator into a triplet iterator.
+   * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning
+   * all edges sequentially.
+   *
+   * @param sendMsg generates messages to neighboring vertices of an edge
+   * @param mergeMsg the combiner applied to messages destined to the same vertex
+   * @param tripletFields which triplet fields `sendMsg` uses
+   * @param srcMustBeActive if true, edges will only be considered if their source vertex is in the
+   *   active set
+   * @param dstMustBeActive if true, edges will only be considered if their destination vertex is in
+   *   the active set
+   * @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be
+   *   considered
    *
-   * Be careful not to keep references to the objects from this iterator. To improve GC performance
-   * the same object is re-used in `next()`.
+   * @return iterator aggregated messages keyed by the receiving vertex id
    */
-  def upgradeIterator(
-      edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true)
-    : Iterator[EdgeTriplet[VD, ED]] = {
-    new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst)
+  def aggregateMessagesEdgeScan[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields,
+      srcMustBeActive: Boolean,
+      dstMustBeActive: Boolean,
+      maySatisfyEither: Boolean): Iterator[(VertexId, A)] = {
+    val aggregates = new Array[A](vertexAttrs.length)
+    val bitset = new BitSet(vertexAttrs.length)
+
+    var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
+    var i = 0
+    while (i < size) {
+      val localSrcId = localSrcIds(i)
+      val srcId = local2global(localSrcId)
+      val localDstId = localDstIds(i)
+      val dstId = local2global(localDstId)
+      val srcIsActive = !srcMustBeActive || isActive(srcId)
+      val dstIsActive = !dstMustBeActive || isActive(dstId)
+      val edgeIsActive =
+        if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive
+      if (edgeIsActive) {
+        val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD]
+        val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
+        ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i))
+        sendMsg(ctx)
+      }
+      i += 1
+    }
+
+    bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
   }
 
   /**
-   * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The
-   * iterator is generated using an index scan, so it is efficient at skipping edges that don't
-   * match srcIdPred.
+   * Send messages along edges and aggregate them at the receiving vertices. Implemented by
+   * filtering the source vertex index, then scanning each edge cluster.
    *
-   * Be careful not to keep references to the objects from this iterator. To improve GC performance
-   * the same object is re-used in `next()`.
-   */
-  def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] =
-    index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator))
-
-  /**
-   * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The
-   * cluster must start at position `index`.
+   * @param sendMsg generates messages to neighboring vertices of an edge
+   * @param mergeMsg the combiner applied to messages destined to the same vertex
+   * @param tripletFields which triplet fields `sendMsg` uses
+   * @param srcMustBeActive if true, edges will only be considered if their source vertex is in the
+   *   active set
+   * @param dstMustBeActive if true, edges will only be considered if their destination vertex is in
+   *   the active set
+   * @param maySatisfyEither if true, only one vertex need be in the active set for an edge to be
+   *   considered
    *
-   * Be careful not to keep references to the objects from this iterator. To improve GC performance
-   * the same object is re-used in `next()`.
+   * @return iterator aggregated messages keyed by the receiving vertex id
    */
-  private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] {
-    private[this] val edge = new Edge[ED]
-    private[this] var pos = index
+  def aggregateMessagesIndexScan[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields,
+      srcMustBeActive: Boolean,
+      dstMustBeActive: Boolean,
+      maySatisfyEither: Boolean): Iterator[(VertexId, A)] = {
+    val aggregates = new Array[A](vertexAttrs.length)
+    val bitset = new BitSet(vertexAttrs.length)
 
-    override def hasNext: Boolean = {
-      pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId
+    var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset)
+    index.iterator.foreach { cluster =>
+      val clusterSrcId = cluster._1
+      val clusterPos = cluster._2
+      val clusterLocalSrcId = localSrcIds(clusterPos)
+      val srcIsActive = !srcMustBeActive || isActive(clusterSrcId)
+      if (srcIsActive || maySatisfyEither) {
+        var pos = clusterPos
+        val srcAttr =
+          if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD]
+        ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr)
+        while (pos < size && localSrcIds(pos) == clusterLocalSrcId) {
+          val localDstId = localDstIds(pos)
+          val dstId = local2global(localDstId)
+          val dstIsActive = !dstMustBeActive || isActive(dstId)
+          val edgeIsActive =
+            if (maySatisfyEither) srcIsActive || dstIsActive else srcIsActive && dstIsActive
+          if (edgeIsActive) {
+            val dstAttr =
+              if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD]
+            ctx.setRest(dstId, localDstId, dstAttr, data(pos))
+            sendMsg(ctx)
+          }
+          pos += 1
+        }
+      }
     }
 
-    override def next(): Edge[ED] = {
-      assert(srcIds(pos) == srcId)
-      edge.srcId = srcIds(pos)
-      edge.dstId = dstIds(pos)
-      edge.attr = data(pos)
-      pos += 1
-      edge
+    bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) }
+  }
+}
+
+private class AggregatingEdgeContext[VD, ED, A](
+    mergeMsg: (A, A) => A,
+    aggregates: Array[A],
+    bitset: BitSet)
+  extends EdgeContext[VD, ED, A] {
+
+  private[this] var _srcId: VertexId = _
+  private[this] var _dstId: VertexId = _
+  private[this] var _localSrcId: Int = _
+  private[this] var _localDstId: Int = _
+  private[this] var _srcAttr: VD = _
+  private[this] var _dstAttr: VD = _
+  private[this] var _attr: ED = _
+
+  def set(
+      srcId: VertexId, dstId: VertexId,
+      localSrcId: Int, localDstId: Int,
+      srcAttr: VD, dstAttr: VD,
+      attr: ED) {
+    _srcId = srcId
+    _dstId = dstId
+    _localSrcId = localSrcId
+    _localDstId = localDstId
+    _srcAttr = srcAttr
+    _dstAttr = dstAttr
+    _attr = attr
+  }
+
+  def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) {
+    _srcId = srcId
+    _localSrcId = localSrcId
+    _srcAttr = srcAttr
+  }
+
+  def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) {
+    _dstId = dstId
+    _localDstId = localDstId
+    _dstAttr = dstAttr
+    _attr = attr
+  }
+
+  override def srcId = _srcId
+  override def dstId = _dstId
+  override def srcAttr = _srcAttr
+  override def dstAttr = _dstAttr
+  override def attr = _attr
+
+  override def sendToSrc(msg: A) {
+    send(_localSrcId, msg)
+  }
+  override def sendToDst(msg: A) {
+    send(_localDstId, msg)
+  }
+
+  @inline private def send(localId: Int, msg: A) {
+    if (bitset.get(localId)) {
+      aggregates(localId) = mergeMsg(aggregates(localId), msg)
+    } else {
+      aggregates(localId) = msg
+      bitset.set(localId)
     }
   }
 }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
index 2b6137be2554784d195424ea643cfccdaff4bf8a..b0cb0fe47d4611eded9aa60cab87718b159c047a 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala
@@ -25,10 +25,11 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector}
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
 
+/** Constructs an EdgePartition from scratch. */
 private[graphx]
 class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
     size: Int = 64) {
-  var edges = new PrimitiveVector[Edge[ED]](size)
+  private[this] val edges = new PrimitiveVector[Edge[ED]](size)
 
   /** Add a new edge to the partition. */
   def add(src: VertexId, dst: VertexId, d: ED) {
@@ -38,8 +39,67 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
   def toEdgePartition: EdgePartition[ED, VD] = {
     val edgeArray = edges.trim().array
     Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering)
-    val srcIds = new Array[VertexId](edgeArray.size)
-    val dstIds = new Array[VertexId](edgeArray.size)
+    val localSrcIds = new Array[Int](edgeArray.size)
+    val localDstIds = new Array[Int](edgeArray.size)
+    val data = new Array[ED](edgeArray.size)
+    val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
+    val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
+    val local2global = new PrimitiveVector[VertexId]
+    var vertexAttrs = Array.empty[VD]
+    // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
+    // adding them to the index. Also populate a map from vertex id to a sequential local offset.
+    if (edgeArray.length > 0) {
+      index.update(edgeArray(0).srcId, 0)
+      var currSrcId: VertexId = edgeArray(0).srcId
+      var currLocalId = -1
+      var i = 0
+      while (i < edgeArray.size) {
+        val srcId = edgeArray(i).srcId
+        val dstId = edgeArray(i).dstId
+        localSrcIds(i) = global2local.changeValue(srcId,
+          { currLocalId += 1; local2global += srcId; currLocalId }, identity)
+        localDstIds(i) = global2local.changeValue(dstId,
+          { currLocalId += 1; local2global += dstId; currLocalId }, identity)
+        data(i) = edgeArray(i).attr
+        if (srcId != currSrcId) {
+          currSrcId = srcId
+          index.update(currSrcId, i)
+        }
+
+        i += 1
+      }
+      vertexAttrs = new Array[VD](currLocalId + 1)
+    }
+    new EdgePartition(
+      localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs,
+      None)
+  }
+}
+
+/**
+ * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables
+ * reuse of the local vertex ids. Intended for internal use in EdgePartition only.
+ */
+private[impl]
+class ExistingEdgePartitionBuilder[
+    @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag](
+    global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int],
+    local2global: Array[VertexId],
+    vertexAttrs: Array[VD],
+    activeSet: Option[VertexSet],
+    size: Int = 64) {
+  private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size)
+
+  /** Add a new edge to the partition. */
+  def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) {
+    edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d)
+  }
+
+  def toEdgePartition: EdgePartition[ED, VD] = {
+    val edgeArray = edges.trim().array
+    Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering)
+    val localSrcIds = new Array[Int](edgeArray.size)
+    val localDstIds = new Array[Int](edgeArray.size)
     val data = new Array[ED](edgeArray.size)
     val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int]
     // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and
@@ -49,8 +109,8 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
       var currSrcId: VertexId = edgeArray(0).srcId
       var i = 0
       while (i < edgeArray.size) {
-        srcIds(i) = edgeArray(i).srcId
-        dstIds(i) = edgeArray(i).dstId
+        localSrcIds(i) = edgeArray(i).localSrcId
+        localDstIds(i) = edgeArray(i).localDstId
         data(i) = edgeArray(i).attr
         if (edgeArray(i).srcId != currSrcId) {
           currSrcId = edgeArray(i).srcId
@@ -60,13 +120,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla
       }
     }
 
-    // Create and populate a VertexPartition with vids from the edges, but no attributes
-    val vidsIter = srcIds.iterator ++ dstIds.iterator
-    val vertexIds = new OpenHashSet[VertexId]
-    vidsIter.foreach(vid => vertexIds.add(vid))
-    val vertices = new VertexPartition(
-      vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet)
+    new EdgePartition(
+      localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet)
+  }
+}
 
-    new EdgePartition(srcIds, dstIds, data, index, vertices)
+private[impl] case class EdgeWithLocalIds[@specialized ED](
+    srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED)
+
+private[impl] object EdgeWithLocalIds {
+  implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] {
+    override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = {
+      if (a.srcId == b.srcId) {
+        if (a.dstId == b.dstId) 0
+        else if (a.dstId < b.dstId) -1
+        else 1
+      } else if (a.srcId < b.srcId) -1
+      else 1
+    }
   }
+
 }
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
deleted file mode 100644
index 56f79a7097fce11483bb3ca97424d8bc125a2e75..0000000000000000000000000000000000000000
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.graphx._
-import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
-
-/**
- * The Iterator type returned when constructing edge triplets. This could be an anonymous class in
- * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile.
- */
-private[impl]
-class EdgeTripletIterator[VD: ClassTag, ED: ClassTag](
-    val edgePartition: EdgePartition[ED, VD],
-    val includeSrc: Boolean,
-    val includeDst: Boolean)
-  extends Iterator[EdgeTriplet[VD, ED]] {
-
-  // Current position in the array.
-  private var pos = 0
-
-  override def hasNext: Boolean = pos < edgePartition.size
-
-  override def next() = {
-    val triplet = new EdgeTriplet[VD, ED]
-    triplet.srcId = edgePartition.srcIds(pos)
-    if (includeSrc) {
-      triplet.srcAttr = edgePartition.vertices(triplet.srcId)
-    }
-    triplet.dstId = edgePartition.dstIds(pos)
-    if (includeDst) {
-      triplet.dstAttr = edgePartition.vertices(triplet.dstId)
-    }
-    triplet.attr = edgePartition.data(pos)
-    pos += 1
-    triplet
-  }
-}
-
-/**
- * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous
- * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug /
- * profile.
- */
-private[impl]
-class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag](
-    val edgeIter: Iterator[Edge[ED]],
-    val edgePartition: EdgePartition[ED, VD],
-    val includeSrc: Boolean,
-    val includeDst: Boolean)
-  extends Iterator[EdgeTriplet[VD, ED]] {
-
-  private val triplet = new EdgeTriplet[VD, ED]
-
-  override def hasNext = edgeIter.hasNext
-
-  override def next() = {
-    triplet.set(edgeIter.next())
-    if (includeSrc) {
-      triplet.srcAttr = edgePartition.vertices(triplet.srcId)
-    }
-    if (includeDst) {
-      triplet.dstAttr = edgePartition.vertices(triplet.dstId)
-    }
-    triplet
-  }
-}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
index 33f35cfb69a26c01c2311d228fff1c70be263a45..a1fe158b7b490d49b33cc31c2097983e49ec737c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala
@@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.storage.StorageLevel
-
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl._
 import org.apache.spark.graphx.util.BytecodeUtils
@@ -127,13 +126,12 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
   }
 
   override def mapTriplets[ED2: ClassTag](
-      f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = {
+      f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2],
+      tripletFields: TripletFields): Graph[VD, ED2] = {
     vertices.cache()
-    val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr")
-    val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr")
-    replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
+    replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
     val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) =>
-      part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr)))
+      part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst)))
     }
     new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges))
   }
@@ -171,15 +169,38 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
   override def mapReduceTriplets[A: ClassTag](
       mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       reduceFunc: (A, A) => A,
-      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = {
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
+
+    def sendMsg(ctx: EdgeContext[VD, ED, A]) {
+      mapFunc(ctx.toEdgeTriplet).foreach { kv =>
+        val id = kv._1
+        val msg = kv._2
+        if (id == ctx.srcId) {
+          ctx.sendToSrc(msg)
+        } else {
+          assert(id == ctx.dstId)
+          ctx.sendToDst(msg)
+        }
+      }
+    }
 
-    vertices.cache()
+    val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
+    val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
+    val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true)
+
+    aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt)
+  }
+
+  override def aggregateMessagesWithActiveSet[A: ClassTag](
+      sendMsg: EdgeContext[VD, ED, A] => Unit,
+      mergeMsg: (A, A) => A,
+      tripletFields: TripletFields,
+      activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = {
 
+    vertices.cache()
     // For each vertex, replicate its attribute only to partitions where it is
     // in the relevant position in an edge.
-    val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr")
-    val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr")
-    replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr)
+    replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst)
     val view = activeSetOpt match {
       case Some((activeSet, _)) =>
         replicatedVertexView.withActiveSet(activeSet)
@@ -193,42 +214,40 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
       case (pid, edgePartition) =>
         // Choose scan method
         val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat
-        val edgeIter = activeDirectionOpt match {
+        activeDirectionOpt match {
           case Some(EdgeDirection.Both) =>
             if (activeFraction < 0.8) {
-              edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId))
-                .filter(e => edgePartition.isActive(e.dstId))
+              edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
+                true, true, false)
             } else {
-              edgePartition.iterator.filter(e =>
-                edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId))
+              edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+                true, true, false)
             }
           case Some(EdgeDirection.Either) =>
             // TODO: Because we only have a clustered index on the source vertex ID, we can't filter
             // the index here. Instead we have to scan all edges and then do the filter.
-            edgePartition.iterator.filter(e =>
-              edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId))
+            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+              true, true, true)
           case Some(EdgeDirection.Out) =>
             if (activeFraction < 0.8) {
-              edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId))
+              edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields,
+                true, false, false)
             } else {
-              edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId))
+              edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+                true, false, false)
             }
           case Some(EdgeDirection.In) =>
-            edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId))
+            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+              false, true, false)
           case _ => // None
-            edgePartition.iterator
+            edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields,
+              false, false, false)
         }
-
-        // Scan edges and run the map function
-        val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr)
-          .flatMap(mapFunc(_))
-        // Note: This doesn't allow users to send messages to arbitrary vertices.
-        edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator
-    }).setName("GraphImpl.mapReduceTriplets - preAgg")
+    }).setName("GraphImpl.aggregateMessages - preAgg")
 
     // do the final reduction reusing the index map
-    vertices.aggregateUsingIndex(preAgg, reduceFunc)
-  } // end of mapReduceTriplets
+    vertices.aggregateUsingIndex(preAgg, mergeMsg)
+  }
 
   override def outerJoinVertices[U: ClassTag, VD2: ClassTag]
       (other: RDD[(VertexId, U)])
@@ -306,9 +325,7 @@ object GraphImpl {
       vertices: VertexRDD[VD],
       edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = {
     // Convert the vertex partitions in edges to the correct type
-    val newEdges = edges.mapEdgePartitions(
-      (pid, part) => part.withVertices(part.vertices.map(
-        (vid, attr) => null.asInstanceOf[VD])))
+    val newEdges = edges.mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD])
     GraphImpl.fromExistingRDDs(vertices, newEdges)
   }
 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index 7a7fa91aadfe1948b5a7a41748036beb11e2f6c9..eb3c997e0f3c024973eca1bf76c66de397eac6de 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -56,11 +56,9 @@ object RoutingTablePartition {
     // Determine which positions each vertex id appears in using a map where the low 2 bits
     // represent src and dst
     val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte]
-    edgePartition.srcIds.iterator.foreach { srcId =>
-      map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
-    }
-    edgePartition.dstIds.iterator.foreach { dstId =>
-      map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
+    edgePartition.iterator.foreach { e =>
+      map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte)
+      map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte)
     }
     map.iterator.map { vidAndPosition =>
       val vid = vidAndPosition._1
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 257e2f3a361154cead41989ad793d5ee58a2bfa7..e40ae0d61546633084d2863e03dadc0ef61d48ff 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -85,7 +85,7 @@ object PageRank extends Logging {
       // Associate the degree with each vertex
       .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) }
       // Set the weight on the edges based on the degree
-      .mapTriplets( e => 1.0 / e.srcAttr )
+      .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.SrcOnly )
       // Set the vertex attributes to the initial pagerank values
       .mapVertices( (id, attr) => resetProb )
 
@@ -96,8 +96,8 @@ object PageRank extends Logging {
 
       // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and
       // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation.
-      val rankUpdates = rankGraph.mapReduceTriplets[Double](
-        e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _)
+      val rankUpdates = rankGraph.aggregateMessages[Double](
+        ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.SrcAndEdge)
 
       // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices
       // that didn't receive a message. Requires a shuffle for broadcasting updated ranks to the
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index ccd7de537b6e31fb8b17a3f7d1ef37e868d24055..f58587e10a820f254d197825015c0210c2236fec 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -74,9 +74,9 @@ object SVDPlusPlus {
     var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
 
     // Calculate initial bias and norm
-    val t0 = g.mapReduceTriplets(
-      et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))),
-        (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2))
+    val t0 = g.aggregateMessages[(Long, Double)](
+      ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
+      (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
 
     g = g.outerJoinVertices(t0) {
       (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
@@ -84,15 +84,17 @@ object SVDPlusPlus {
         (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
     }
 
-    def mapTrainF(conf: Conf, u: Double)
-        (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double])
-      : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = {
-      val (usr, itm) = (et.srcAttr, et.dstAttr)
+    def sendMsgTrainF(conf: Conf, u: Double)
+        (ctx: EdgeContext[
+          (DoubleMatrix, DoubleMatrix, Double, Double),
+          Double,
+          (DoubleMatrix, DoubleMatrix, Double)]) {
+      val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
       val (p, q) = (usr._1, itm._1)
       var pred = u + usr._3 + itm._3 + q.dot(usr._2)
       pred = math.max(pred, conf.minVal)
       pred = math.min(pred, conf.maxVal)
-      val err = et.attr - pred
+      val err = ctx.attr - pred
       val updateP = q.mul(err)
         .subColumnVector(p.mul(conf.gamma7))
         .mul(conf.gamma2)
@@ -102,16 +104,16 @@ object SVDPlusPlus {
       val updateY = q.mul(err * usr._4)
         .subColumnVector(itm._2.mul(conf.gamma7))
         .mul(conf.gamma2)
-      Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)),
-        (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)))
+      ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1))
+      ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))
     }
 
     for (i <- 0 until conf.maxIters) {
       // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes
       g.cache()
-      val t1 = g.mapReduceTriplets(
-        et => Iterator((et.srcId, et.dstAttr._2)),
-        (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2))
+      val t1 = g.aggregateMessages[DoubleMatrix](
+        ctx => ctx.sendToSrc(ctx.dstAttr._2),
+        (g1, g2) => g1.addColumnVector(g2))
       g = g.outerJoinVertices(t1) {
         (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
          msg: Option[DoubleMatrix]) =>
@@ -121,8 +123,8 @@ object SVDPlusPlus {
 
       // Phase 2, update p for user nodes and q, y for item nodes
       g.cache()
-      val t2 = g.mapReduceTriplets(
-        mapTrainF(conf, u),
+      val t2 = g.aggregateMessages(
+        sendMsgTrainF(conf, u),
         (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
           (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
       g = g.outerJoinVertices(t2) {
@@ -135,20 +137,18 @@ object SVDPlusPlus {
     }
 
     // calculate error on training set
-    def mapTestF(conf: Conf, u: Double)
-        (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double])
-      : Iterator[(VertexId, Double)] =
-    {
-      val (usr, itm) = (et.srcAttr, et.dstAttr)
+    def sendMsgTestF(conf: Conf, u: Double)
+        (ctx: EdgeContext[(DoubleMatrix, DoubleMatrix, Double, Double), Double, Double]) {
+      val (usr, itm) = (ctx.srcAttr, ctx.dstAttr)
       val (p, q) = (usr._1, itm._1)
       var pred = u + usr._3 + itm._3 + q.dot(usr._2)
       pred = math.max(pred, conf.minVal)
       pred = math.min(pred, conf.maxVal)
-      val err = (et.attr - pred) * (et.attr - pred)
-      Iterator((et.dstId, err))
+      val err = (ctx.attr - pred) * (ctx.attr - pred)
+      ctx.sendToDst(err)
     }
     g.cache()
-    val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2)
+    val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
     g = g.outerJoinVertices(t3) {
       (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
         if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
index 7c396e6e66a289d66e5ab985ba2c4c5b7755a92f..daf162085e3e4334c3781d6183ea771b73b2f266 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -61,26 +61,27 @@ object TriangleCount {
       (vid, _, optSet) => optSet.getOrElse(null)
     }
     // Edge function computes intersection of smaller vertex with larger vertex
-    def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexId, Int)] = {
-      assert(et.srcAttr != null)
-      assert(et.dstAttr != null)
-      val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) {
-        (et.srcAttr, et.dstAttr)
+    def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) {
+      assert(ctx.srcAttr != null)
+      assert(ctx.dstAttr != null)
+      val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) {
+        (ctx.srcAttr, ctx.dstAttr)
       } else {
-        (et.dstAttr, et.srcAttr)
+        (ctx.dstAttr, ctx.srcAttr)
       }
       val iter = smallSet.iterator
       var counter: Int = 0
       while (iter.hasNext) {
         val vid = iter.next()
-        if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) {
+        if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) {
           counter += 1
         }
       }
-      Iterator((et.srcId, counter), (et.dstId, counter))
+      ctx.sendToSrc(counter)
+      ctx.sendToDst(counter)
     }
     // compute the intersection along edges
-    val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _)
+    val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _)
     // Merge counters with the graph and divide by two since each triangle is counted twice
     g.outerJoinVertices(counters) {
       (vid, _, optCounter: Option[Int]) =>
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 6506bac73d71c9aca6467199c24fc7c94431c9c9..df773db6e432606fc05a874306b3a9194c83b1a6 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       // Each vertex should be replicated to at most 2 * sqrt(p) partitions
       val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter =>
         val part = iter.next()._2
-        Iterator((part.srcIds ++ part.dstIds).toSet)
+        Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet)
       }.collect
       if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) {
         val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound)
@@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       // This should not be true for the default hash partitioning
       val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { iter =>
         val part = iter.next()._2
-        Iterator((part.srcIds ++ part.dstIds).toSet)
+        Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet)
       }.collect
       assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound))
 
@@ -318,6 +318,21 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("aggregateMessages") {
+    withSpark { sc =>
+      val n = 5
+      val agg = starGraph(sc, n).aggregateMessages[String](
+        ctx => {
+          if (ctx.dstAttr != null) {
+            throw new Exception(
+              "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr)
+          }
+          ctx.sendToDst(ctx.srcAttr)
+        }, _ + _, TripletFields.SrcOnly)
+      assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet)
+    }
+  }
+
   test("outerJoinVertices") {
     withSpark { sc =>
       val n = 5
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index db1dac616008046299b0374453d5d7e89279840a..515f3a9cd02eb3132d3b9a15e457b7312eef211c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite {
     assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges)
   }
 
-  test("upgradeIterator") {
-    val edges = List((0, 1, 0), (1, 0, 0))
-    val verts = List((0L, 1), (1L, 2))
-    val part = makeEdgePartition(edges).updateVertices(verts.iterator)
-    assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList ===
-      part.tripletIterator().toList.map(_.toTuple))
-  }
-
-  test("indexIterator") {
-    val edgesFrom0 = List(Edge(0, 1, 0))
-    val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0))
-    val sortedEdges = edgesFrom0 ++ edgesFrom1
-    val builder = new EdgePartitionBuilder[Int, Nothing]
-    for (e <- Random.shuffle(sortedEdges)) {
-      builder.add(e.srcId, e.dstId, e.attr)
-    }
-
-    val edgePartition = builder.toEdgePartition
-    assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges)
-    assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0)
-    assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1)
-  }
-
   test("innerJoin") {
     val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
     val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0))
@@ -125,8 +102,18 @@ class EdgePartitionSuite extends FunSuite {
     assert(ep.numActives == Some(2))
   }
 
+  test("tripletIterator") {
+    val builder = new EdgePartitionBuilder[Int, Int]
+    builder.add(1, 2, 0)
+    builder.add(1, 3, 0)
+    builder.add(1, 4, 0)
+    val ep = builder.toEdgePartition
+    val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId))
+    assert(result === Seq((1, 2), (1, 3), (1, 4)))
+  }
+
   test("serialization") {
-    val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
+    val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5))
     val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
     val javaSer = new JavaSerializer(new SparkConf())
     val conf = new SparkConf()
@@ -135,11 +122,7 @@ class EdgePartitionSuite extends FunSuite {
 
     for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
       val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
-      assert(aSer.srcIds.toList === a.srcIds.toList)
-      assert(aSer.dstIds.toList === a.dstIds.toList)
-      assert(aSer.data.toList === a.data.toList)
-      assert(aSer.index != null)
-      assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet)
+      assert(aSer.tripletIterator().toList === a.tripletIterator().toList)
     }
   }
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
deleted file mode 100644
index 49b2704390fea17e8562413bc21f2b03f476072d..0000000000000000000000000000000000000000
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.graphx.impl
-
-import scala.reflect.ClassTag
-import scala.util.Random
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.graphx._
-
-class EdgeTripletIteratorSuite extends FunSuite {
-  test("iterator.toList") {
-    val builder = new EdgePartitionBuilder[Int, Int]
-    builder.add(1, 2, 0)
-    builder.add(1, 3, 0)
-    builder.add(1, 4, 0)
-    val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true)
-    val result = iter.toList.map(et => (et.srcId, et.dstId))
-    assert(result === Seq((1, 2), (1, 3), (1, 4)))
-  }
-}