diff --git a/bagel/src/test/scala/bagel/BagelSuite.scala b/bagel/src/test/scala/bagel/BagelSuite.scala index 1b47fc9ed5478d728157312dc7ebb7fb8cd62912..9e64d3f1362639873a760b186ae06dc2824483b8 100644 --- a/bagel/src/test/scala/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/bagel/BagelSuite.scala @@ -28,6 +28,7 @@ class BagelSuite extends FunSuite with Assertions { }) for (vert <- result.collect) assert(vert.age === numSupersteps) + sc.stop() } test("halting by message silence") { @@ -49,5 +50,6 @@ class BagelSuite extends FunSuite with Assertions { }) for (vert <- result.collect) assert(vert.age === numSupersteps) + sc.stop() } } diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 07fd605cca0af1ab61c3a6eee734141b1dedc497..4334034ecbc1c93af79ae1d616267de3b517b221 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -9,6 +9,7 @@ import scala.collection.mutable.HashSet sealed trait MapOutputTrackerMessage case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage +case object StopMapOutputTracker extends MapOutputTrackerMessage class MapOutputTracker(serverUris: ConcurrentHashMap[Int, Array[String]]) extends DaemonActor with Logging { @@ -23,6 +24,9 @@ extends DaemonActor with Logging { case GetMapOutputLocations(shuffleId: Int) => logInfo("Asked to get map output locations for shuffle " + shuffleId) reply(serverUris.get(shuffleId)) + case StopMapOutputTracker => + reply('OK) + exit() } } } @@ -95,4 +99,10 @@ object MapOutputTracker extends Logging { def getMapOutputUri(serverUri: String, shuffleId: Int, mapId: Int, reduceId: Int): String = { "%s/shuffle/%s/%s/%s".format(serverUri, shuffleId, mapId, reduceId) } -} \ No newline at end of file + + def stop() { + trackerActor !? StopMapOutputTracker + serverUris.clear() + trackerActor = null + } +} diff --git a/core/src/main/scala/spark/RDDCache.scala b/core/src/main/scala/spark/RDDCache.scala index d6c63e61ec6bd2708e0fd11be3f282e14333f097..c5557159a6a02216600263f0d5c29277e50ea820 100644 --- a/core/src/main/scala/spark/RDDCache.scala +++ b/core/src/main/scala/spark/RDDCache.scala @@ -12,6 +12,7 @@ case class DroppedFromCache(rddId: Int, partition: Int, host: String) extends Ca case class MemoryCacheLost(host: String) extends CacheMessage case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheMessage case object GetCacheLocations extends CacheMessage +case object StopCacheTracker extends CacheMessage class RDDCacheTracker extends DaemonActor with Logging { val locs = new HashMap[Int, Array[List[String]]] @@ -50,6 +51,10 @@ class RDDCacheTracker extends DaemonActor with Logging { locsCopy(rddId) = array.clone() } reply(locsCopy) + + case StopCacheTracker => + reply('OK) + exit() } } } @@ -57,15 +62,15 @@ class RDDCacheTracker extends DaemonActor with Logging { private object RDDCache extends Logging { // Stores map results for various splits locally - val cache = Cache.newKeySpace() + var cache: KeySpace = null - // Remembers which splits are currently being loaded + // Remembers which splits are currently being loaded (on worker nodes) val loading = new HashSet[(Int, Int)] // Tracker actor on the master, or remote reference to it on workers var trackerActor: AbstractActor = null - val registeredRddIds = new HashSet[Int] + var registeredRddIds: HashSet[Int] = null def initialize(isMaster: Boolean) { if (isMaster) { @@ -77,6 +82,8 @@ private object RDDCache extends Logging { val port = System.getProperty("spark.master.port").toInt trackerActor = RemoteActor.select(Node(host, port), 'RDDCacheTracker) } + registeredRddIds = new HashSet[Int] + cache = Cache.newKeySpace() } // Registers an RDD (on master only) @@ -138,4 +145,10 @@ private object RDDCache extends Logging { return Iterator.fromArray(array) } } + + def stop() { + trackerActor !? StopCacheTracker + registeredRddIds.clear() + trackerActor = null + } } diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 9357db22c4842185d07d6384dd8fe8957803c807..dc6964e14bd28443508fc667b7fbd8652a6ea154 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -121,6 +121,9 @@ extends Logging { def stop() { scheduler.stop() scheduler = null + // TODO: Broadcast.stop(), Cache.stop()? + MapOutputTracker.stop() + RDDCache.stop() } // Wait for the scheduler to be registered diff --git a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala index d4974009ce055afc5d682fe952f5c78f7055d61f..a118abf3cad27d19cc02fcfcccefacbb1a3ae027 100644 --- a/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala +++ b/core/src/main/scala/spark/repl/SparkInterpreterLoop.scala @@ -260,6 +260,8 @@ extends InterpreterControl { plushln("Type :help for more information.") } + var sparkContext: SparkContext = null + def createSparkContext(): SparkContext = { val master = this.master match { case Some(m) => m @@ -268,7 +270,8 @@ extends InterpreterControl { if (prop != null) prop else "local" } } - new SparkContext(master, "Spark shell") + sparkContext = new SparkContext(master, "Spark shell") + sparkContext } /** The main read-eval-print loop for the interpreter. It calls diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index a5773614e86aafe8b32075989427c3558bfa6209..308936075690596607862ec9551fd97985e5f91c 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -18,6 +18,7 @@ class ShuffleSuite extends FunSuite { assert(valuesFor1.toList.sorted === List(1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) + sc.stop() } test("groupByKey with duplicates") { @@ -29,6 +30,7 @@ class ShuffleSuite extends FunSuite { assert(valuesFor1.toList.sorted === List(1, 1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) + sc.stop() } test("groupByKey with negative key hash codes") { @@ -40,6 +42,7 @@ class ShuffleSuite extends FunSuite { assert(valuesForMinus1.toList.sorted === List(1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) + sc.stop() } test("groupByKey with many output partitions") { @@ -51,6 +54,7 @@ class ShuffleSuite extends FunSuite { assert(valuesFor1.toList.sorted === List(1, 2, 3)) val valuesFor2 = groups.find(_._1 == 2).get._2 assert(valuesFor2.toList.sorted === List(1)) + sc.stop() } test("reduceByKey") { @@ -58,6 +62,7 @@ class ShuffleSuite extends FunSuite { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_).collect() assert(sums.toSet === Set((1, 7), (2, 1))) + sc.stop() } test("reduceByKey with collectAsMap") { @@ -67,6 +72,7 @@ class ShuffleSuite extends FunSuite { assert(sums.size === 2) assert(sums(1) === 7) assert(sums(2) === 1) + sc.stop() } test("reduceByKey with many output partitons") { @@ -74,6 +80,7 @@ class ShuffleSuite extends FunSuite { val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1))) val sums = pairs.reduceByKey(_+_, 10).collect() assert(sums.toSet === Set((1, 7), (2, 1))) + sc.stop() } test("join") { @@ -88,6 +95,7 @@ class ShuffleSuite extends FunSuite { (2, (1, 'y')), (2, (1, 'z')) )) + sc.stop() } test("join all-to-all") { @@ -104,6 +112,7 @@ class ShuffleSuite extends FunSuite { (1, (3, 'x')), (1, (3, 'y')) )) + sc.stop() } test("join with no matches") { @@ -112,6 +121,7 @@ class ShuffleSuite extends FunSuite { val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w'))) val joined = rdd1.join(rdd2).collect() assert(joined.size === 0) + sc.stop() } test("join with many output partitions") { @@ -126,5 +136,6 @@ class ShuffleSuite extends FunSuite { (2, (1, 'y')), (2, (1, 'z')) )) + sc.stop() } } diff --git a/core/src/test/scala/spark/repl/ReplSuite.scala b/core/src/test/scala/spark/repl/ReplSuite.scala index 225e766c7114494bb7b67c757145c87a7c83a39e..829b1d934eec9f4394ab96123dc646ab60fa4e7c 100644 --- a/core/src/test/scala/spark/repl/ReplSuite.scala +++ b/core/src/test/scala/spark/repl/ReplSuite.scala @@ -27,6 +27,8 @@ class ReplSuite extends FunSuite { val separator = System.getProperty("path.separator") interp.main(Array("-classpath", paths.mkString(separator))) spark.repl.Main.interp = null + if (interp.sparkContext != null) + interp.sparkContext.stop() return out.toString }