diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index eeaf1d7c11e5d80763d9cfc45fd5c606b2e79ea1..b43aca2b97facac3b68dd84355469027f3ed78dd 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -271,6 +271,7 @@ class SparkContext( env.shuffleManager.stop() env.blockManager.stop() BlockManagerMaster.stopBlockManagerMaster() + env.connectionManager.stop() SparkEnv.set(null) ShuffleMapTask.clearCache() } diff --git a/core/src/main/scala/spark/network/ConnectionManager.scala b/core/src/main/scala/spark/network/ConnectionManager.scala index a5a707a57def3f50917ff0c44f85ac52c10c04b2..3222187990eaa630997d012a0babbcd1a61cbb20 100644 --- a/core/src/main/scala/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/spark/network/ConnectionManager.scala @@ -68,7 +68,8 @@ class ConnectionManager(port: Int) extends Logging { def run() { try { - while(!selectorThread.isInterrupted) { + var interrupted = false + while(!interrupted) { while(!connectionRequests.isEmpty) { val sendingConnection = connectionRequests.dequeue sendingConnection.connect() @@ -102,14 +103,10 @@ class ConnectionManager(port: Int) extends Logging { } val selectedKeysCount = selector.select() - if (selectedKeysCount == 0) { - logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") - } - if (selectorThread.isInterrupted) { - logInfo("Selector thread was interrupted!") - return - } + if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys") + interrupted = selectorThread.isInterrupted + val selectedKeys = selector.selectedKeys().iterator() while (selectedKeys.hasNext()) { val key = selectedKeys.next.asInstanceOf[SelectionKey] @@ -333,16 +330,18 @@ class ConnectionManager(port: Int) extends Logging { } def stop() { - selectorThread.interrupt() - selectorThread.join() - selector.close() - val connections = connectionsByKey.values - connections.foreach(_.close()) - if (connectionsByKey.size != 0) { - logWarning("All connections not cleaned up") + if (!selectorThread.isAlive) { + selectorThread.interrupt() + selectorThread.join() + selector.close() + val connections = connectionsByKey.values + connections.foreach(_.close()) + if (connectionsByKey.size != 0) { + logWarning("All connections not cleaned up") + } + handleMessageExecutor.shutdown() + logInfo("ConnectionManager stopped") } - handleMessageExecutor.shutdown() - logInfo("ConnectionManager stopped") } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index fc8adbc517c0ddfb79504b4b700e76f366f905bf..f9d53d3b5d4457a975b696552af87e3ded3f7bc5 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -223,7 +223,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with * events and responds by launching tasks. This runs in a dedicated thread and receives events * via the eventQueue. */ - def run() { + def run() = { SparkEnv.set(env) while (true) { @@ -258,14 +258,6 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with case completion: CompletionEvent => handleTaskCompletion(completion) - case StopDAGScheduler => - // Cancel any active jobs - for (job <- activeJobs) { - val error = new SparkException("Job cancelled because SparkContext was shut down") - job.listener.jobFailed(error) - } - return - case null => // queue.poll() timed out, ignore it } @@ -537,7 +529,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with } def stop() { - eventQueue.put(StopDAGScheduler) + // TODO: Put a stop event on our queue and break the event loop taskSched.stop() } } diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index 0fc73059c347929471fa2de5be0bf07302e78abe..c10abc92028993d9200676d60139493ee5df5f62 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -28,5 +28,3 @@ case class CompletionEvent( extends DAGSchedulerEvent case class HostLost(host: String) extends DAGSchedulerEvent - -case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala index 1a47f3fddf7c147668281c2957276f25189cbc3b..8339c0ae9025aab942f26f97a078d31235f99613 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala @@ -48,20 +48,14 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with // Serialize and deserialize the task so that accumulators are changed to thread-local ones; // this adds a bit of unnecessary overhead but matches how the Mesos Executor works. Accumulators.clear - val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = ser.serialize(task) - logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes") - val deserializedTask = ser.deserialize[Task[_]]( + val bytes = Utils.serialize(task) + logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes") + val deserializedTask = Utils.deserialize[Task[_]]( bytes, Thread.currentThread.getContextClassLoader) val result: Any = deserializedTask.run(attemptId) - // Serialize and deserialize the result to emulate what the Mesos - // executor does. This is useful to catch serialization errors early - // on in development (so when users move their local Spark programs - // to the cluster, they don't get surprised by serialization errors). - val resultToReturn = ser.deserialize[Any](ser.serialize(result)) val accumUpdates = Accumulators.values logInfo("Finished task " + idInJob) - listener.taskEnded(task, Success, resultToReturn, accumUpdates) + listener.taskEnded(task, Success, result, accumUpdates) } catch { case t: Throwable => { logError("Exception in task " + idInJob, t) @@ -83,9 +77,7 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with } } - override def stop() { - threadPool.shutdownNow() - } + override def stop() {} override def defaultParallelism() = threads }