diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index 97a82239a97ee5db9d8e3ba7ff956bf8e7ce2e62..3e8c38530252bbf6b174e859f28828deb94003cc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -17,10 +17,15 @@
 
 package org.apache.spark.graphx
 
+import scala.reflect.{classTag, ClassTag}
 import scala.reflect.ClassTag
 import scala.util.Random
 
+import org.apache.spark.HashPartitioner
+import org.apache.spark.SparkContext._
 import org.apache.spark.SparkException
+import org.apache.spark.graphx.impl.EdgePartitionBuilder
+import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.graphx.lib._
 import org.apache.spark.rdd.RDD
 
@@ -183,6 +188,15 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
     }
   }
 
+  /**
+   * Remove self edges, i.e. edges whose source and destination vertex ids are equal.
+   *
+   * @return a graph with all self edges removed
+   */
+  def removeSelfEdges(): Graph[VD, ED] = {
+    graph.subgraph(epred = e => e.srcId != e.dstId)
+  }
+
   /**
    * Join the vertices with an RDD and then apply a function from the
    * vertex and RDD entry to a new vertex value.  The input table
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
index a5d598053f9ca304649d98c916f9247dd10c82ef..51bcdf20dec46b68d82cc4cd03cebd46bb21236d 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala
@@ -20,6 +20,7 @@ package org.apache.spark.graphx.lib
 import scala.reflect.ClassTag
 
 import org.apache.spark.graphx._
+import org.apache.spark.graphx.PartitionStrategy.EdgePartition2D
 
 /**
  * Compute the number of triangles passing through each vertex.
@@ -27,25 +28,47 @@ import org.apache.spark.graphx._
  * The algorithm is relatively straightforward and can be computed in three steps:
  *
  * <ul>
- * <li>Compute the set of neighbors for each vertex
- * <li>For each edge compute the intersection of the sets and send the count to both vertices.
- * <li> Compute the sum at each vertex and divide by two since each triangle is counted twice.
+ * <li> Compute the set of neighbors for each vertex</li>
+ * <li> For each edge compute the intersection of the sets and send the count to both vertices.</li>
+ * <li> Compute the sum at each vertex and divide by two since each triangle is counted twice.</li>
  * </ul>
  *
- * Note that the input graph should have its edges in canonical direction
- * (i.e. the `sourceId` less than `destId`). Also the graph must have been partitioned
- * using [[org.apache.spark.graphx.Graph#partitionBy]].
+ * There are two implementations.  The default `TriangleCount.run` implementation first removes
+ * self cycles and canonicalizes the graph to ensure that the following conditions hold:
+ * <ul>
+ * <li> There are no self edges</li>
+ * <li> All edges are oriented in canonical direction (src is less than dst)</li>
+ * <li> There are no duplicate edges</li>
+ * </ul>
+ * However, the canonicalization procedure is costly as it requires repartitioning the graph.
+ * If the input data is already in "canonical form" with self cycles removed then
+ * `TriangleCount.runPreCanonicalized` should be used instead.
+ *
+ * {{{
+ * val canonicalGraph = graph.mapEdges(e => 1).removeSelfEdges().convertToCanonicalEdges()
+ * val counts = TriangleCount.runPreCanonicalized(canonicalGraph).vertices
+ * }}}
+ *
  */
 object TriangleCount {
 
   def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = {
-    // Remove redundant edges
-    val g = graph.groupEdges((a, b) => a).cache()
+    // Transform the edge data to something cheap to shuffle and then canonicalize
+    val canonicalGraph = graph.mapEdges(e => true).removeSelfEdges().convertToCanonicalEdges()
+    // Get the triangle counts
+    val counters = runPreCanonicalized(canonicalGraph).vertices
+    // Join them back with the original graph
+    graph.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) =>
+      optCounter.getOrElse(0)
+    }
+  }
 
+
+  def runPreCanonicalized[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[Int, ED] = {
     // Construct set representations of the neighborhoods
     val nbrSets: VertexRDD[VertexSet] =
-      g.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) =>
-        val set = new VertexSet(4)
+      graph.collectNeighborIds(EdgeDirection.Either).mapValues { (vid, nbrs) =>
+        val set = new VertexSet(nbrs.length)
         var i = 0
         while (i < nbrs.size) {
           // prevent self cycle
@@ -56,14 +79,14 @@ object TriangleCount {
         }
         set
       }
+
     // join the sets with the graph
-    val setGraph: Graph[VertexSet, ED] = g.outerJoinVertices(nbrSets) {
+    val setGraph: Graph[VertexSet, ED] = graph.outerJoinVertices(nbrSets) {
       (vid, _, optSet) => optSet.getOrElse(null)
     }
+
     // Edge function computes intersection of smaller vertex with larger vertex
     def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) {
-      assert(ctx.srcAttr != null)
-      assert(ctx.dstAttr != null)
       val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) {
         (ctx.srcAttr, ctx.dstAttr)
       } else {
@@ -80,15 +103,15 @@ object TriangleCount {
       ctx.sendToSrc(counter)
       ctx.sendToDst(counter)
     }
+
     // compute the intersection along edges
     val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _)
     // Merge counters with the graph and divide by two since each triangle is counted twice
-    g.outerJoinVertices(counters) {
-      (vid, _, optCounter: Option[Int]) =>
-        val dblCount = optCounter.getOrElse(0)
-        // double count should be even (divisible by two)
-        assert((dblCount & 1) == 0)
-        dblCount / 2
+    graph.outerJoinVertices(counters) { (_, _, optCounter: Option[Int]) =>
+      val dblCount = optCounter.getOrElse(0)
+      // This algorithm double counts each triangle so the final count should be even
+      require(dblCount % 2 == 0, "Triangle count resulted in an invalid number of triangles.")
+      dblCount / 2
     }
-  } // end of TriangleCount
+  }
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
index 57a8b95dd12e956e9727575f143582efc3dffaaa..3967f6683de71d2f8ecb82659692b44d87fa6f69 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala
@@ -55,6 +55,21 @@ class GraphOpsSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("removeSelfEdges") {
+    withSpark { sc =>
+      val edgeArray = Array((1 -> 2), (2 -> 3), (3 -> 3), (4 -> 3), (1 -> 1))
+        .map {
+          case (a, b) => (a.toLong, b.toLong)
+        }
+      val correctEdges = edgeArray.filter { case (a, b) => a != b }.toSet
+      val graph = Graph.fromEdgeTuples(sc.parallelize(edgeArray), 1)
+      val canonicalizedEdges = graph.removeSelfEdges().edges.map(e => (e.srcId, e.dstId))
+        .collect
+      assert(canonicalizedEdges.toSet.size === canonicalizedEdges.size)
+      assert(canonicalizedEdges.toSet === correctEdges)
+    }
+  }
+
   test ("filter") {
     withSpark { sc =>
       val n = 5
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
index 608e43cf3ff539bb72e4755620977dcddd14a6ee..f19c3acdc85cfa7684edd0d12e66f03df950d45d 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/TriangleCountSuite.scala
@@ -64,9 +64,9 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext {
       val verts = triangleCount.vertices
       verts.collect().foreach { case (vid, count) =>
         if (vid == 0) {
-          assert(count === 4)
-        } else {
           assert(count === 2)
+        } else {
+          assert(count === 1)
         }
       }
     }
@@ -75,7 +75,8 @@ class TriangleCountSuite extends SparkFunSuite with LocalSparkContext {
   test("Count a single triangle with duplicate edges") {
     withSpark { sc =>
       val rawEdges = sc.parallelize(Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
-        Array(0L -> 1L, 1L -> 2L, 2L -> 0L), 2)
+        Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
+        Array(1L -> 0L, 1L -> 1L), 2)
       val graph = Graph.fromEdgeTuples(rawEdges, true, uniqueEdges = Some(RandomVertexCut)).cache()
       val triangleCount = graph.triangleCount()
       val verts = triangleCount.vertices