Commit 2893b305 authored by Matei Zaharia

Various fixes to get unit tests running. In particular, shut down ConnectionManager and DAGScheduler properly, plus a fix to LocalScheduler that was not merged in from 0.5 and was actually caught by one of the tests.
parent 08579ffa
@@ -271,7 +271,6 @@ class SparkContext(
     env.shuffleManager.stop()
     env.blockManager.stop()
     BlockManagerMaster.stopBlockManagerMaster()
-    env.connectionManager.stop()
     SparkEnv.set(null)
     ShuffleMapTask.clearCache()
   }
@@ -68,8 +68,7 @@ class ConnectionManager(port: Int) extends Logging {
   def run() {
     try {
-      var interrupted = false
-      while(!interrupted) {
+      while(!selectorThread.isInterrupted) {
         while(!connectionRequests.isEmpty) {
           val sendingConnection = connectionRequests.dequeue
           sendingConnection.connect()
@@ -103,10 +102,14 @@ class ConnectionManager(port: Int) extends Logging {
         }
         val selectedKeysCount = selector.select()
-        if (selectedKeysCount == 0) logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
-        interrupted = selectorThread.isInterrupted
+        if (selectedKeysCount == 0) {
+          logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
+        }
+        if (selectorThread.isInterrupted) {
+          logInfo("Selector thread was interrupted!")
+          return
+        }
         val selectedKeys = selector.selectedKeys().iterator()
         while (selectedKeys.hasNext()) {
           val key = selectedKeys.next.asInstanceOf[SelectionKey]
@@ -330,18 +333,16 @@ class ConnectionManager(port: Int) extends Logging {
   }
 
   def stop() {
-    if (!selectorThread.isAlive) {
-      selectorThread.interrupt()
-      selectorThread.join()
-      selector.close()
-      val connections = connectionsByKey.values
-      connections.foreach(_.close())
-      if (connectionsByKey.size != 0) {
-        logWarning("All connections not cleaned up")
-      }
-      handleMessageExecutor.shutdown()
-      logInfo("ConnectionManager stopped")
-    }
+    selectorThread.interrupt()
+    selectorThread.join()
+    selector.close()
+    val connections = connectionsByKey.values
+    connections.foreach(_.close())
+    if (connectionsByKey.size != 0) {
+      logWarning("All connections not cleaned up")
+    }
+    handleMessageExecutor.shutdown()
+    logInfo("ConnectionManager stopped")
   }
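For readers following along, here is a minimal, self-contained sketch of the shutdown pattern these ConnectionManager hunks move to. The SelectorLoop class and every name in it are illustrative stand-ins, not Spark code: the loop checks the thread's interrupt flag around select(), and stop() interrupts the thread, joins it, and only then releases resources.

    import java.nio.channels.Selector

    // Illustrative sketch only (not Spark's ConnectionManager): a selector loop
    // that exits when its thread is interrupted, and a stop() that interrupts,
    // joins, and then closes the selector.
    class SelectorLoop {
      private val selector = Selector.open()

      private val thread = new Thread("selector-loop") {
        override def run() {
          while (!isInterrupted) {
            // Interrupting the thread wakes up a blocked select() call.
            val readyCount = selector.select()
            if (isInterrupted) {
              return  // stop() has been called; exit the loop promptly
            }
            // ... handle the readyCount selected keys here ...
          }
        }
      }

      def start() { thread.start() }

      def stop() {
        thread.interrupt()   // wake up select() and flag the loop to exit
        thread.join()        // wait until run() has actually returned
        selector.close()     // now it is safe to release the selector
      }
    }

Joining before closing matters: closing the selector while the loop thread is still inside select() would make the loop fail instead of returning cleanly.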
@@ -223,7 +223,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
    * events and responds by launching tasks. This runs in a dedicated thread and receives events
    * via the eventQueue.
    */
-  def run() = {
+  def run() {
     SparkEnv.set(env)
     while (true) {
@@ -258,6 +258,14 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
         case completion: CompletionEvent =>
           handleTaskCompletion(completion)
+        case StopDAGScheduler =>
+          // Cancel any active jobs
+          for (job <- activeJobs) {
+            val error = new SparkException("Job cancelled because SparkContext was shut down")
+            job.listener.jobFailed(error)
+          }
+          return
+
         case null =>
           // queue.poll() timed out, ignore it
       }
@@ -529,7 +537,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
   }
 
   def stop() {
-    // TODO: Put a stop event on our queue and break the event loop
+    eventQueue.put(StopDAGScheduler)
     taskSched.stop()
   }
 }
@@ -28,3 +28,5 @@ case class CompletionEvent(
   extends DAGSchedulerEvent
 
 case class HostLost(host: String) extends DAGSchedulerEvent
+
+case object StopDAGScheduler extends DAGSchedulerEvent
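The eventQueue.put(StopDAGScheduler) / return pair above is the classic stop-event (poison-pill) pattern for shutting down an event loop. A rough, self-contained sketch of the same idea under made-up names (Event, Work, Stop, and EventLoop are not Spark's API) might look like this:

    import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

    // Illustrative sketch only: an event loop that drains a blocking queue and
    // exits cleanly when it dequeues the Stop sentinel posted by stop().
    sealed trait Event
    case class Work(id: Int) extends Event
    case object Stop extends Event

    class EventLoop {
      private val queue = new LinkedBlockingQueue[Event]

      def post(event: Event) { queue.put(event) }

      def stop() { post(Stop) }

      def run() {
        while (true) {
          // Poll with a timeout so the loop can also do periodic housekeeping;
          // poll() returns null when nothing arrives before the timeout.
          queue.poll(10, TimeUnit.MILLISECONDS) match {
            case Work(id) => println("handling work item " + id)
            case Stop     => return   // sentinel seen: leave the loop
            case null     => ()       // timed out, nothing to do this round
          }
        }
      }
    }

Because the sentinel travels through the same queue as ordinary events, any work already queued is still processed before the loop exits.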
@@ -48,14 +48,20 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
       // Serialize and deserialize the task so that accumulators are changed to thread-local ones;
       // this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
       Accumulators.clear
-      val bytes = Utils.serialize(task)
-      logInfo("Size of task " + idInJob + " is " + bytes.size + " bytes")
-      val deserializedTask = Utils.deserialize[Task[_]](
+      val ser = SparkEnv.get.closureSerializer.newInstance()
+      val bytes = ser.serialize(task)
+      logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
+      val deserializedTask = ser.deserialize[Task[_]](
         bytes, Thread.currentThread.getContextClassLoader)
       val result: Any = deserializedTask.run(attemptId)
+      // Serialize and deserialize the result to emulate what the Mesos
+      // executor does. This is useful to catch serialization errors early
+      // on in development (so when users move their local Spark programs
+      // to the cluster, they don't get surprised by serialization errors).
+      val resultToReturn = ser.deserialize[Any](ser.serialize(result))
       val accumUpdates = Accumulators.values
       logInfo("Finished task " + idInJob)
-      listener.taskEnded(task, Success, result, accumUpdates)
+      listener.taskEnded(task, Success, resultToReturn, accumUpdates)
     } catch {
       case t: Throwable => {
         logError("Exception in task " + idInJob, t)
@@ -77,7 +83,9 @@ class LocalScheduler(threads: Int, maxFailures: Int) extends TaskScheduler with
     }
   }
 
-  override def stop() {}
+  override def stop() {
+    threadPool.shutdownNow()
+  }
 
   override def defaultParallelism() = threads
 }
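The LocalScheduler change above round-trips the task result through Spark's closureSerializer so that local runs hit serialization problems early. As a stand-alone illustration of the same idea using plain Java serialization (RoundTrip and roundTrip are made-up names here, and Spark's actual serializer is pluggable rather than Java serialization):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    // Illustrative sketch only: serialize and deserialize a value to surface
    // NotSerializableException while still running locally.
    object RoundTrip {
      def roundTrip[T <: AnyRef](value: T): T = {
        val buffer = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(buffer)
        out.writeObject(value)   // fails fast if the value is not serializable
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
        in.readObject().asInstanceOf[T]   // a fresh copy, as a remote executor would see it
      }

      def main(args: Array[String]) {
        println(roundTrip(List(1, 2, 3)))   // prints List(1, 2, 3)
        // roundTrip(new Object())          // would throw NotSerializableException
      }
    }

The extra copy costs a little time in local mode, but it means a program that works locally will not fail later on a cluster purely because a result type turned out not to be serializable.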