diff --git a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
index 145be3c126a3845f91f801364d908a78e2034f0d..2d74ce92e25efc15ca7a37d95b6c27d9d857ebdf 100644
--- a/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
+++ b/graph/src/test/scala/org/apache/spark/graph/GraphSuite.scala
@@ -8,8 +8,6 @@ import org.apache.spark.graph.LocalSparkContext._
 
 class GraphSuite extends FunSuite with LocalSparkContext {
 
-//  val sc = new SparkContext("local[4]", "test")
-
   System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
   System.setProperty("spark.kryo.registrator", "org.apache.spark.graph.GraphKryoRegistrator")
 
@@ -22,48 +20,57 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("mapEdges") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val n = 3
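+      // Star graph: edges from the center vertex 0 to each leaf 1..n; mapEdges sets each edge's attribute to its destination id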
+      val star = Graph(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))))
+      val starWithEdgeAttrs = star.mapEdges(e => e.dstId)
+
+      // map(_.copy()) is a workaround for https://github.com/amplab/graphx/issues/25
+      val edges = starWithEdgeAttrs.edges.map(_.copy()).collect()
+      assert(edges.size === n)
+      assert(edges.toSet === (1 to n).map(x => Edge(0, x, x)).toSet)
+    }
+  }
+
   test("aggregateNeighbors") {
     withSpark(new SparkContext("local", "test")) { sc =>
-      val star = Graph(sc.parallelize(List((0, 1), (0, 2), (0, 3))))
+      val n = 3
+      val star = Graph(sc.parallelize((1 to n).map(x => (0: Vid, x: Vid))))
 
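+      // Each leaf 1..n has exactly one incoming edge from the center; vertex 0 has none and is absent from the result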
       val indegrees = star.aggregateNeighbors(
         (vid, edge) => Some(1),
         (a: Int, b: Int) => a + b,
-        EdgeDirection.In)// .map((vid, attr) => (vid, attr._2.getOrElse(0)))
-      assert(indegrees.collect().toSet === Set((1, 1), (2, 1), (3, 1))) // (0, 0),
+        EdgeDirection.In)
+      assert(indegrees.collect().toSet === (1 to n).map(x => (x, 1)).toSet)
 
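+      // The center vertex 0 is the source of all n edges, so it is the only vertex with a nonzero out-degree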
       val outdegrees = star.aggregateNeighbors(
         (vid, edge) => Some(1),
         (a: Int, b: Int) => a + b,
-        EdgeDirection.Out) //.map((vid, attr) => (vid, attr._2.getOrElse(0)))
-      assert(outdegrees.collect().toSet === Set((0, 3))) //, (1, 0), (2, 0), (3, 0)))
+        EdgeDirection.Out)
+      assert(outdegrees.collect().toSet === Set((0, n)))
 
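+      // mapFunc returns None for every edge, so reduceFunc must never be invoked and the result is empty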
       val noVertexValues = star.aggregateNeighbors[Int](
         (vid: Vid, edge: EdgeTriplet[Int, Int]) => None,
         (a: Int, b: Int) => throw new Exception("reduceFunc called unexpectedly"),
-        EdgeDirection.In)//.map((vid, attr) => (vid, attr))
-      assert(noVertexValues.collect().toSet === Set.empty[(Vid, Int)] ) // ((0, None), (1, None), (2, None), (3, None)))
+        EdgeDirection.In)
+      assert(noVertexValues.collect().toSet === Set.empty[(Vid, Int)])
     }
   }
 
- /* test("joinVertices") {
-    sc = new SparkContext("local", "test")
-    val vertices = sc.parallelize(Seq(Vertex(1, "one"), Vertex(2, "two"), Vertex(3, "three")), 2)
-    val edges = sc.parallelize((Seq(Edge(1, 2, "onetwo"))))
-    val g: Graph[String, String] = new GraphImpl(vertices, edges)
-
-    val tbl = sc.parallelize(Seq((1, 10), (2, 20)))
-    val g1 = g.joinVertices(tbl, (v: Vertex[String], u: Int) => v.data + u)
+  test("joinVertices") {
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val vertices = sc.parallelize(Seq[(Vid, String)]((1, "one"), (2, "two"), (3, "three")), 2)
+      val edges = sc.parallelize(Seq(Edge(1, 2, "onetwo")))
+      val g: Graph[String, String] = Graph(vertices, edges)
 
-    val v = g1.vertices.collect().sortBy(_.id)
-    assert(v(0).data === "one10")
-    assert(v(1).data === "two20")
-    assert(v(2).data === "three")
+      val tbl = sc.parallelize(Seq[(Vid, Int)]((1, 10), (2, 20)))
+      val g1 = g.joinVertices(tbl) { (vid: Vid, attr: String, u: Int) => attr + u }
 
-    val e = g1.edges.collect()
-    assert(e(0).data === "onetwo")
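+      // Only vertices 1 and 2 appear in tbl, so vertex 3 keeps its original attribute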
+      val v = g1.vertices.collect().toSet
+      assert(v === Set((1, "one10"), (2, "two20"), (3, "three")))
+    }
   }
-  */
 
 //  test("graph partitioner") {
 //    sc = new SparkContext("local", "test")