From efd6418c1b99c1ecc2b0a4c72e6430eea4d86260 Mon Sep 17 00:00:00 2001
From: Evan Chan <ev@ooyala.com>
Date: Tue, 23 Jul 2013 10:40:41 -0700
Subject: [PATCH] Move getPersistentRDDs testing to a new Suite
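
Drop the sc.getCachedRDDs assertions from RDDSuite's caching tests and
cover the SparkContext introspection API in a dedicated
SparkContextInfoSuite instead: getPersistentRDDs only lists RDDs marked
as cached, the map it returns is immutable, and getRDDStorageInfo only
reports RDDs that actually persist data.

As a rough sketch of the API under test (assuming a live SparkContext
named sc, as in the suite below):

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    sc.getPersistentRDDs   // immutable map containing only RDDs marked as cached
    rdd.collect()          // materializes the cached partitions
    sc.getRDDStorageInfo   // reports only RDDs that actually hold persisted data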

---
 core/src/test/scala/spark/RDDSuite.scala      |  6 --
 .../scala/spark/SparkContextInfoSuite.scala   | 60 +++++++++++++++++++
 2 files changed, 60 insertions(+), 6 deletions(-)
 create mode 100644 core/src/test/scala/spark/SparkContextInfoSuite.scala

diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index ff2dcd72d8..cbddf4e523 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -90,19 +90,15 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   }
 
   test("basic caching") {
-    val origCachedRdds = sc.getCachedRDDs.size
     val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
     assert(rdd.collect().toList === List(1, 2, 3, 4))
     assert(rdd.collect().toList === List(1, 2, 3, 4))
     assert(rdd.collect().toList === List(1, 2, 3, 4))
-    // Should only result in one cached RDD
-    assert(sc.getCachedRDDs.size === origCachedRdds + 1)
   }
 
   test("caching with failures") {
     val onlySplit = new Partition { override def index: Int = 0 }
     var shouldFail = true
-    val origCachedRdds = sc.getCachedRDDs.size
     val rdd = new RDD[Int](sc, Nil) {
       override def getPartitions: Array[Partition] = Array(onlySplit)
       override val getDependencies = List[Dependency[_]]()
@@ -114,14 +110,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
         }
       }
     }.cache()
-    assert(sc.getCachedRDDs.size === origCachedRdds + 1)
     val thrown = intercept[Exception]{
       rdd.collect()
     }
     assert(thrown.getMessage.contains("injected failure"))
     shouldFail = false
     assert(rdd.collect().toList === List(1, 2, 3, 4))
-    assert(sc.getCachedRDDs.size === origCachedRdds + 1)
   }
 
   test("empty RDD") {
diff --git a/core/src/test/scala/spark/SparkContextInfoSuite.scala b/core/src/test/scala/spark/SparkContextInfoSuite.scala
new file mode 100644
index 0000000000..6d50bf5e1b
--- /dev/null
+++ b/core/src/test/scala/spark/SparkContextInfoSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark
+
+import org.scalatest.FunSuite
+import spark.SparkContext._
+
+class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
+  test("getPersistentRDDs only returns RDDs that are marked as cached") {
+    sc = new SparkContext("local", "test")
+    assert(sc.getPersistentRDDs.isEmpty === true)
+
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    assert(sc.getPersistentRDDs.isEmpty === true)
+
+    rdd.cache()
+    assert(sc.getPersistentRDDs.size === 1)
+    assert(sc.getPersistentRDDs.values.head === rdd)
+  }
+
+  test("getPersistentRDDs returns an immutable map") {
+    sc = new SparkContext("local", "test")
+    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+
+    val myRdds = sc.getPersistentRDDs
+    assert(myRdds.size === 1)
+    assert(myRdds.values.head === rdd1)
+
+    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
+
+    // getPersistentRDDs should now contain 2 RDDs, but the previously captured myRdds should not change
+    assert(sc.getPersistentRDDs.size === 2)
+    assert(myRdds.size === 1)
+  }
+
+  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
+    sc = new SparkContext("local", "test")
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
+
+    assert(sc.getRDDStorageInfo.size === 0)
+
+    rdd.collect()
+    assert(sc.getRDDStorageInfo.size === 1)
+  }
+}
-- 
GitLab