diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index b9088603100936b7b0627ed2f1fb8cf846d4d43e..796082721d696a46c5bebb7f718576ecef7c3635 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -151,7 +151,7 @@ object Pregel extends Logging {
       // count the iteration
       i += 1
     }
-
+    messages.unpersist(blocking = false)
     g
   } // end of apply
 
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
index 859f896039047a4125e524e4ad5d05f4539cdff1..f72cbb15242ecc477b26d3c8a222494a6c75c577 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala
@@ -47,9 +47,11 @@ object ConnectedComponents {
       }
     }
     val initialMessage = Long.MaxValue
-    Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
+    val pregelGraph = Pregel(ccGraph, initialMessage, activeDirection = EdgeDirection.Either)(
       vprog = (id, attr, msg) => math.min(attr, msg),
       sendMsg = sendMessage,
       mergeMsg = (a, b) => math.min(a, b))
+    ccGraph.unpersist()
+    pregelGraph
   } // end of connectedComponents
 }
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 1f5e27d5508b8eb14609e5f8bd177e440a98ff30..2fbc6f069d48d2e0089e37e42812701f9f7a46fb 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -428,4 +428,20 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("unpersist graph RDD") {
+    withSpark { sc =>
+      val vert = sc.parallelize(List((1L, "a"), (2L, "b"), (3L, "c")), 1)
+      val edges = sc.parallelize(List(Edge[Long](1L, 2L), Edge[Long](1L, 3L)), 1)
+      val g0 = Graph(vert, edges)
+      val g = g0.partitionBy(PartitionStrategy.EdgePartition2D, 2)
+      val cc = g.connectedComponents()
+      assert(sc.getPersistentRDDs.nonEmpty)
+      cc.unpersist()
+      g.unpersist()
+      g0.unpersist()
+      vert.unpersist()
+      edges.unpersist()
+      assert(sc.getPersistentRDDs.isEmpty)
+    }
+  }
 }