From 4db50e26c75263b2edae468b0e8a9283b5c2e6f1 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Fri, 13 May 2011 12:03:58 -0700
Subject: [PATCH] Fixed unit tests by making them clean up the SparkContext
 after use and thus shut down the various singletons (RDDCache,
 MapOutputTracker, etc.). This isn't perfect yet (ideally we shouldn't use
 singleton objects at all), but we can fix that later.

---
 bagel/src/test/scala/bagel/BagelSuite.scala   |  2 ++
 .../main/scala/spark/MapOutputTracker.scala   | 12 +++++++++++-
 core/src/main/scala/spark/RDDCache.scala      | 19 ++++++++++++++++---
 core/src/main/scala/spark/SparkContext.scala  |  3 +++
 .../spark/repl/SparkInterpreterLoop.scala     |  5 ++++-
 core/src/test/scala/spark/ShuffleSuite.scala  | 11 +++++++++++
 .../src/test/scala/spark/repl/ReplSuite.scala |  2 ++
 7 files changed, 49 insertions(+), 5 deletions(-)
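
For reference, a minimal sketch (hypothetical suite, not part of this patch) of the
cleanup pattern these tests now follow. It wraps the test body in try/finally so the
context is stopped even when an assertion fails; the tests changed below simply call
sc.stop() as their last statement:

    import org.scalatest.FunSuite
    import spark.SparkContext

    class ExampleSuite extends FunSuite {
      test("parallelize and collect") {
        // Creating a local context initializes the singleton trackers.
        val sc = new SparkContext("local", "ExampleSuite")
        try {
          val nums = sc.parallelize(Array(1, 2, 3, 4))
          assert(nums.collect().toList === List(1, 2, 3, 4))
        } finally {
          // Shuts down the scheduler and stops RDDCache and MapOutputTracker,
          // so the next suite can re-initialize them cleanly.
          sc.stop()
        }
      }
    }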

diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala
index 1b47fc9ed5..9e64d3f136 100644
--- a/bagel/src/test/scala/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/bagel/BagelSuite.scala
@@ -28,6 +28,7 @@ class BagelSuite extends FunSuite with Assertions {
     })
     for (vert <- result.collect)
       assert(vert.age === numSupersteps)
+    sc.stop()
   }
 
   test("halting by message silence") {
@@ -49,5 +50,6 @@ class BagelSuite extends FunSuite with Assertions {
     })
     for (vert <- result.collect)
       assert(vert.age === numSupersteps)
+    sc.stop()
   }
 }
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 07fd605cca..4334034ecb 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -9,6 +9,7 @@ import scala.collection.mutable.HashSet
 
 sealed trait MapOutputTrackerMessage
 case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage 
+case object StopMapOutputTracker extends MapOutputTrackerMessage
 
 class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]])
 extends DaemonActor with Logging {
@@ -23,6 +24,9 @@ extends DaemonActor with Logging {
         case GetMapOutputLocations(shuffleId: Int) =>
           logInfo("Asked to get map output locations for shuffle " + shuffleId)
           reply(serverUris.get(shuffleId))
+        case StopMapOutputTracker =>
+          reply('OK)
+          exit()
       }
     }
   }
@@ -95,4 +99,10 @@ object MapOutputTracker extends Logging {
   def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = {
     "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId)
   }
-}
\ No newline at end of file
+
+  def stop() {
+    trackerActor !? StopMapOutputTracker
+    serverUris.clear()
+    trackerActor = null
+  }
+}
diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala
index d6c63e61ec..c5557159a6 100644
--- a/core/src/main/scala/spark/RDDCache.scala
+++ b/core/src/main/scala/spark/RDDCache.scala
@@ -12,6 +12,7 @@ case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends Ca
 case class MemoryCacheLost(host: String) extends CacheMessage
 case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage
 case object GetCacheLocations extends CacheMessage
+case object StopCacheTracker extends CacheMessage
 
 class RDDCacheTracker extends DaemonActor with Logging {
   val locs = new HashMap[Int, Array[List[String]]]
@@ -50,6 +51,10 @@ class RDDCacheTracker extends DaemonActor with Logging {
             locsCopy(rddId) = array.clone()
           }
           reply(locsCopy)
+
+        case StopCacheTracker =>
+          reply('OK)
+          exit()
       }
     }
   }
@@ -57,15 +62,15 @@ class RDDCacheTracker extends DaemonActor with Logging {
 
 private object RDDCache extends Logging {
   // Stores map results for various splits locally
-  val cache = Cache.newKeySpace()
+  var cache: KeySpace = null
 
-  // Remembers which splits are currently being loaded
+  // Remembers which splits are currently being loaded (on worker nodes)
   val loading = new HashSet[(Int, Int)]
   
   // Tracker actor on the master, or remote reference to it on workers
   var trackerActor: AbstractActor = null
   
-  val registeredRddIds = new HashSet[Int]
+  var registeredRddIds: HashSet[Int] = null
   
   def initialize(isMaster: Boolean) {
     if (isMaster) {
@@ -77,6 +82,8 @@ private object RDDCache extends Logging {
       val port = System.getProperty("spark.master.port").toInt
       trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker)
     }
+    registeredRddIds = new HashSet[Int]
+    cache = Cache.newKeySpace()
   }
   
   // Registers an RDD (on master only)
@@ -138,4 +145,10 @@ private object RDDCache extends Logging {
       return Iterator.fromArray(array)
     }
   }
+
+  def stop() {
+    trackerActor !? StopCacheTracker
+    registeredRddIds.clear()
+    trackerActor = null
+  }
 }
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 9357db22c4..dc6964e14b 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -121,6 +121,9 @@ extends Logging {
   def stop() {
      scheduler.stop()
      scheduler = null
+     // TODO: Broadcast.stop(), Cache.stop()?
+     MapOutputTracker.stop()
+     RDDCache.stop()
   }
 
   // Wait for the scheduler to be registered
diff --git a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
index d4974009ce..a118abf3ca 100644
--- a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
+++ b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala
@@ -260,6 +260,8 @@ extends InterpreterControl {
     plushln("Type :help for more information.")
   }
 
+  var sparkContext: SparkContext = null
+
   def createSparkContext(): SparkContext = {
     val master = this.master match {
       case Some(m) => m
@@ -268,7 +270,8 @@ extends InterpreterControl {
         if (prop != null) prop else "local"
       }
     }
-    new SparkContext(master, "Spark shell")
+    sparkContext = new SparkContext(master, "Spark shell")
+    sparkContext
   }
 
   /** The main read-eval-print loop for the interpreter.  It calls
diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala
index a5773614e8..3089360756 100644
--- a/core/src/test/scala/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/spark/ShuffleSuite.scala
@@ -18,6 +18,7 @@ class ShuffleSuite extends FunSuite {
     assert(valuesFor1.toList.sorted === List(1, 2, 3))
     val valuesFor2 = groups.find(_._1 == 2).get._2
     assert(valuesFor2.toList.sorted === List(1))
+    sc.stop()
   }
 
   test("groupByKey with duplicates") {
@@ -29,6 +30,7 @@ class ShuffleSuite extends FunSuite {
     assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
     val valuesFor2 = groups.find(_._1 == 2).get._2
     assert(valuesFor2.toList.sorted === List(1))
+    sc.stop()
   }
 
   test("groupByKey with negative key hash codes") {
@@ -40,6 +42,7 @@ class ShuffleSuite extends FunSuite {
     assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
     val valuesFor2 = groups.find(_._1 == 2).get._2
     assert(valuesFor2.toList.sorted === List(1))
+    sc.stop()
   }
 
   test("groupByKey with many output partitions") {
@@ -51,6 +54,7 @@ class ShuffleSuite extends FunSuite {
     assert(valuesFor1.toList.sorted === List(1, 2, 3))
     val valuesFor2 = groups.find(_._1 == 2).get._2
     assert(valuesFor2.toList.sorted === List(1))
+    sc.stop()
   }
 
   test("reduceByKey") {
@@ -58,6 +62,7 @@ class ShuffleSuite extends FunSuite {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
     val sums = pairs.reduceByKey(_+_).collect()
     assert(sums.toSet === Set((1, 7), (2, 1)))
+    sc.stop()
   }
 
   test("reduceByKey with collectAsMap") {
@@ -67,6 +72,7 @@ class ShuffleSuite extends FunSuite {
     assert(sums.size === 2)
     assert(sums(1) === 7)
     assert(sums(2) === 1)
+    sc.stop()
   }
 
   test("reduceByKey with many output partitons") {
@@ -74,6 +80,7 @@ class ShuffleSuite extends FunSuite {
     val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
     val sums = pairs.reduceByKey(_+_, 10).collect()
     assert(sums.toSet === Set((1, 7), (2, 1)))
+    sc.stop()
   }
 
   test("join") {
@@ -88,6 +95,7 @@ class ShuffleSuite extends FunSuite {
       (2, (1, 'y')),
       (2, (1, 'z'))
     ))
+    sc.stop()
   }
 
   test("join all-to-all") {
@@ -104,6 +112,7 @@ class ShuffleSuite extends FunSuite {
       (1, (3, 'x')),
       (1, (3, 'y'))
     ))
+    sc.stop()
   }
 
   test("join with no matches") {
@@ -112,6 +121,7 @@ class ShuffleSuite extends FunSuite {
     val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
     val joined = rdd1.join(rdd2).collect()
     assert(joined.size === 0)
+    sc.stop()
   }
 
   test("join with many output partitions") {
@@ -126,5 +136,6 @@ class ShuffleSuite extends FunSuite {
       (2, (1, 'y')),
       (2, (1, 'z'))
     ))
+    sc.stop()
   }
 }
diff --git a/core/src/test/scala/spark/repl/ReplSuite.scala b/core/src/test/scala/spark/repl/ReplSuite.scala
index 225e766c71..829b1d934e 100644
--- a/core/src/test/scala/spark/repl/ReplSuite.scala
+++ b/core/src/test/scala/spark/repl/ReplSuite.scala
@@ -27,6 +27,8 @@ class ReplSuite extends FunSuite {
     val separator = System.getProperty("path.separator")
     interp.main(Array("-classpath", paths.mkString(separator)))
     spark.repl.Main.interp = null
+    if (interp.sparkContext != null)
+      interp.sparkContext.stop()
     return out.toString
   }
   
-- 
GitLab