diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index d7bd832e5266b84292ed198f0ed1a238f395bd6f..5d0f2950d61c3e58b317bf57bee7ea8ecd219671 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -67,7 +67,7 @@ class SparkContext(
     System.setProperty("spark.master.port", "0")
   }
 
-  private val isLocal = (master == "local" || master.startsWith("local\\["))
+  private val isLocal = (master == "local" || master.startsWith("local["))
 
   // Create the Spark execution environment (cache, map output tracker, etc)
   val env = SparkEnv.createFromSystemProperties(
@@ -84,7 +84,7 @@ class SparkContext(
   // Regular expression for local[N, maxRetries], used in tests with failing tasks
   val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
   // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
-  val SPARK_LOCALCLUSTER_REGEX = """local-cluster\[([0-9]+)\,([0-9]+),([0-9]+)]""".r
+  val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r
   // Regular expression for connecting to Spark deploy clusters
   val SPARK_REGEX = """(spark://.*)""".r
 
@@ -104,13 +104,13 @@ class SparkContext(
         scheduler.initialize(backend)
         scheduler
 
-      case SPARK_LOCALCLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
+      case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
         val scheduler = new ClusterScheduler(this)
         val localCluster = new LocalSparkCluster(numSlaves.toInt, coresPerSlave.toInt, memoryPerlave.toInt)
         val sparkUrl = localCluster.start()
         val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
         scheduler.initialize(backend)
-        backend.shutdownHook = (backend: SparkDeploySchedulerBackend) => {
+        backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
           localCluster.stop()
         }
         scheduler
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index 03986ea756302d00d20e03f971d4a9d8c0d443b2..eacf2375089e86f21cf901041aaf6c4cbbeb84f6 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -76,9 +76,12 @@ private object HttpBroadcast extends Logging {
   }
 
   def stop() {
-    if (server != null) {
-      server.stop()
-      server = null
+    synchronized {
+      if (server != null) {
+        server.stop()
+        server = null
+        initialized = false
+      }
     }
   }
 
diff --git a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
index da74df4dcf37ac6846d506706d06ef2e5e9f6937..1591bfdeb65d995426c6ca6f0739df4cb1d8ad11 100644
--- a/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -9,10 +9,8 @@ import spark.{Logging, Utils}
 
 import scala.collection.mutable.ArrayBuffer
 
-class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
-  memoryPerSlave : Int) extends Logging {
+class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
 
-  val threadPool = Utils.newDaemonFixedThreadPool(numSlaves + 1)
   val localIpAddress = Utils.localIpAddress
 
   var masterActor : ActorRef = _
@@ -24,35 +22,25 @@ class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
   val slaveActors = ArrayBuffer[ActorRef]()
 
   def start() : String = {
-    logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
 
     /* Start the Master */
-    val (masterActorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+    val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
+    masterActorSystem = actorSystem
     masterUrl = "spark://" + localIpAddress + ":" + masterPort
-    threadPool.execute(new Runnable {
-      def run() {
-        val actor = masterActorSystem.actorOf(
-          Props(new Master(localIpAddress, masterPort, 8080)), name = "Master")
-        masterActor = actor
-        masterActorSystem.awaitTermination()
-      }
-    })
+    val actor = masterActorSystem.actorOf(
+      Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
+    masterActor = actor
 
     /* Start the Slaves */
-    (1 to numSlaves).foreach { slaveNum =>
+    for (slaveNum <- 1 to numSlaves) {
       val (actorSystem, boundPort) =
         AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
       slaveActorSystems += actorSystem
-      threadPool.execute(new Runnable {
-        def run() {
-          val actor = actorSystem.actorOf(
-            Props(new Worker(localIpAddress, boundPort, 8080 + slaveNum, coresPerSlave, memoryPerSlave, masterUrl)),
-            name = "Worker")
-          slaveActors += actor
-          actorSystem.awaitTermination()
-        }
-      })
+      val actor = actorSystem.actorOf(
+        Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
+        name = "Worker")
+      slaveActors += actor
     }
 
     return masterUrl
@@ -60,9 +48,10 @@ class LocalSparkCluster(numSlaves : Int, coresPerSlave : Int,
 
   def stop() {
     logInfo("Shutting down local Spark cluster.")
-    masterActorSystem.shutdown()
+    // Stop the slaves before the master so they don't get upset that it disconnected
     slaveActorSystems.foreach(_.shutdown())
+    slaveActorSystems.foreach(_.awaitTermination())
+    masterActorSystem.shutdown()
+    masterActorSystem.awaitTermination()
   }
-
-
-}
\ No newline at end of file
+}
diff --git a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
index 393f4a3ee6d40baf47a5833dbfa372772f5ec5bc..1740a42a7eff69bffd4ccbe3cfe881d534a8793f 100644
--- a/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -29,6 +29,7 @@ class ExecutorRunner(
   val fullId = jobId + "/" + execId
   var workerThread: Thread = null
   var process: Process = null
+  var shutdownHook: Thread = null
 
   def start() {
     workerThread = new Thread("ExecutorRunner for " + fullId) {
@@ -37,17 +38,16 @@ class ExecutorRunner(
     workerThread.start()
 
     // Shutdown hook that kills actors on shutdown.
-    Runtime.getRuntime.addShutdownHook(
-      new Thread() {
-        override def run() {
-          if(process != null) {
-            logInfo("Shutdown Hook killing process.")
-            process.destroy()
-            process.waitFor()
-          }
+    shutdownHook = new Thread() {
+      override def run() {
+        if (process != null) {
+          logInfo("Shutdown hook killing child process.")
+          process.destroy()
+          process.waitFor()
         }
-      })
-
+      }
+    }
+    Runtime.getRuntime.addShutdownHook(shutdownHook)
   }
 
   /** Stop this executor runner, including killing the process it launched */
@@ -58,8 +58,10 @@ class ExecutorRunner(
       if (process != null) {
         logInfo("Killing process!")
         process.destroy()
+        process.waitFor()
       }
       worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
+      Runtime.getRuntime.removeShutdownHook(shutdownHook)
     }
   }
 
@@ -114,7 +116,12 @@ class ExecutorRunner(
     val out = new FileOutputStream(file)
     new Thread("redirect output to " + file) {
       override def run() {
-        Utils.copyStream(in, out, true)
+        try {
+          Utils.copyStream(in, out, true)
+        } catch {
+          case e: IOException =>
+            logInfo("Redirection to " + file + " closed: " + e.getMessage)
+        }
       }
     }.start()
   }
diff --git a/core/src/main/scala/spark/deploy/worker/Worker.scala b/core/src/main/scala/spark/deploy/worker/Worker.scala
index 0a80463c0bdcc220b3c4bf91d6d971c2d7720b32..175464d40ddedb23b11fad3eca80e69f8b1245ee 100644
--- a/core/src/main/scala/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -153,6 +153,10 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
   def generateWorkerId(): String = {
     "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
   }
+
+  override def postStop() {
+    executors.values.foreach(_.kill())
+  }
 }
 
 object Worker {
diff --git a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index ec3ff38d5c42aa073a1e3108e29d450bbc1c7e31..9093a329a369d257decd39694630c61224d7b7c6 100644
--- a/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -16,7 +16,7 @@ class SparkDeploySchedulerBackend(
 
   var client: Client = null
   var stopping = false
-  var shutdownHook : (SparkDeploySchedulerBackend) => Unit = _
+  var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
 
   val maxCores = System.getProperty("spark.cores.max", Int.MaxValue.toString).toInt
 
@@ -62,8 +62,8 @@ class SparkDeploySchedulerBackend(
     stopping = true;
     super.stop()
     client.stop()
-    if (shutdownHook != null) {
-      shutdownHook(this)
+    if (shutdownCallback != null) {
+      shutdownCallback(this)
     }
   }
 
diff --git a/core/src/test/scala/spark/DistributedSuite.scala b/core/src/test/scala/spark/DistributedSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b7b8a79327748ba63092d46dc5a309e5a30a6e73
--- /dev/null
+++ b/core/src/test/scala/spark/DistributedSuite.scala
@@ -0,0 +1,68 @@
+package spark
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import com.google.common.io.Files
+
+import scala.collection.mutable.ArrayBuffer
+
+import SparkContext._
+
+class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
+
+  val clusterUrl = "local-cluster[2,1,512]"
+
+  var sc: SparkContext = _
+
+  after {
+    if (sc != null) {
+      sc.stop()
+      sc = null
+    }
+  }
+
+  test("simple groupByKey") {
+    sc = new SparkContext(clusterUrl, "test")
+    val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5)
+    val groups = pairs.groupByKey(5).collect()
+    assert(groups.size === 2)
+    val valuesFor1 = groups.find(_._1 == 1).get._2
+    assert(valuesFor1.toList.sorted === List(1, 2, 3))
+    val valuesFor2 = groups.find(_._1 == 2).get._2
+    assert(valuesFor2.toList.sorted === List(1))
+  }
+
+  test("accumulators") {
+    sc = new SparkContext(clusterUrl, "test")
+    val accum = sc.accumulator(0)
+    sc.parallelize(1 to 10, 10).foreach(x => accum += x)
+    assert(accum.value === 55)
+  }
+
+  test("broadcast variables") {
+    sc = new SparkContext(clusterUrl, "test")
+    val array = new Array[Int](100)
+    val bv = sc.broadcast(array)
+    array(2) = 3   // Change the array -- this should not be seen on workers
+    val rdd = sc.parallelize(1 to 10, 10)
+    val sum = rdd.map(x => bv.value.sum).reduce(_ + _)
+    assert(sum === 0)
+  }
+
+  test("repeatedly failing task") {
+    sc = new SparkContext(clusterUrl, "test")
+    val accum = sc.accumulator(0)
+    val thrown = intercept[SparkException] {
+      sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
+    }
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getMessage.contains("more than 4 times"))
+  }
+}
+
diff --git a/run b/run
index 8f7256b4e566685ef29b85ca0dcde9ad7c08e8f8..2946a04d3f7e2d22c59c7a748dc396c19da9bf32 100755
--- a/run
+++ b/run
@@ -52,6 +52,7 @@ CLASSPATH="$SPARK_CLASSPATH"
 CLASSPATH+=":$MESOS_CLASSPATH"
 CLASSPATH+=":$FWDIR/conf"
 CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/classes"
+CLASSPATH+=":$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
 CLASSPATH+=":$CORE_DIR/src/main/resources"
 CLASSPATH+=":$REPL_DIR/target/scala-$SCALA_VERSION/classes"
 CLASSPATH+=":$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"