diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 92c809d85416795833f135178d136395b87ebbb9..0bce531aaba3e68cf6c7c7984d92d72d01928de2 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.net.Socket import scala.collection.JavaConversions._ import scala.collection.mutable @@ -102,10 +103,10 @@ class SparkEnv ( } private[spark] - def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) { + def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { synchronized { val key = (pythonExec, envVars) - pythonWorkers(key).stop() + pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index fe9a9e50ef21d032ce1542cc4135d625c97c0985..0b5322c6fb96571876bcfbb15c2efc40c0e5d06b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -62,8 +62,8 @@ private[spark] class PythonRDD( val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") - val worker: Socket = env.createPythonWorker(pythonExec, - envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir)) + envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread + val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) @@ -241,7 +241,7 @@ private[spark] class PythonRDD( if (!context.completed) { try { logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.toMap) + env.destroyPythonWorker(pythonExec, envVars.toMap, worker) } catch { case e: Exception => logError("Exception when trying to kill worker", e) @@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging { /** * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - * This function is outdated, PySpark does not use it anymore */ - @deprecated + @deprecated("PySpark does not use it anymore", "1.1") def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { pyRDD.rdd.mapPartitions { iter => val unpickle = new Unpickler diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 15fe8a9be6bfee01d04e9d05f2d2fd54b94ecd99..7af260d0b7f26b7b4bcd3691c8f78b83ca4c939d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.python -import java.io.{DataInputStream, InputStream, OutputStreamWriter} +import java.lang.Runtime +import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import scala.collection.mutable import scala.collection.JavaConversions._ import org.apache.spark._ @@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 + var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + + var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() val pythonPath = PythonUtils.mergePythonPaths( PythonUtils.sparkPythonPath, @@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems. */ private def createThroughDaemon(): Socket = { + + def createSocket(): Socket = { + val socket = new Socket(daemonHost, daemonPort) + val pid = new DataInputStream(socket.getInputStream).readInt() + if (pid < 0) { + throw new IllegalStateException("Python daemon failed to launch worker") + } + daemonWorkers.put(socket, pid) + socket + } + synchronized { // Start the daemon if it hasn't been started startDaemon() // Attempt to connect, restart and retry once if it fails try { - val socket = new Socket(daemonHost, daemonPort) - val launchStatus = new DataInputStream(socket.getInputStream).readInt() - if (launchStatus != 0) { - throw new IllegalStateException("Python daemon failed to launch worker") - } - socket + createSocket() } catch { case exc: SocketException => logWarning("Failed to open socket to Python daemon:", exc) logWarning("Assuming that daemon unexpectedly quit, attempting to restart") stopDaemon() startDaemon() - new Socket(daemonHost, daemonPort) + createSocket() } } } @@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Wait for it to connect to our socket serverSocket.setSoTimeout(10000) try { - return serverSocket.accept() + val socket = serverSocket.accept() + simpleWorkers.put(socket, worker) + return socket } catch { case e: Exception => throw new SparkException("Python worker did not connect back in time", e) @@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private def stopDaemon() { synchronized { - // Request shutdown of existing daemon by sending SIGTERM - if (daemon != null) { - daemon.destroy() - } + if (useDaemon) { + // Request shutdown of existing daemon by sending SIGTERM + if (daemon != null) { + daemon.destroy() + } - daemon = null - daemonPort = 0 + daemon = null + daemonPort = 0 + } else { + simpleWorkers.mapValues(_.destroy()) + } } } def stop() { stopDaemon() } + + def stopWorker(worker: Socket) { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + daemon.getOutputStream.flush() + } + } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) + } + worker.close() + } } private object PythonWorkerFactory { diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 9fde0dde0f4b4e7ffc12483e03c2a71c4e8d8ab4..b00da833d06f1d2af89505694c716014b7bf5c9d 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -26,7 +26,7 @@ from errno import EINTR, ECHILD from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN from pyspark.worker import main as worker_main -from pyspark.serializers import write_int +from pyspark.serializers import read_int, write_int def compute_real_exit_code(exit_code): @@ -67,7 +67,8 @@ def worker(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - write_int(0, outfile) # Acknowledge that the fork was successful + # Acknowledge that the fork was successful + write_int(os.getpid(), outfile) outfile.flush() worker_main(infile, outfile) except SystemExit as exc: @@ -125,14 +126,23 @@ def manager(): else: raise if 0 in ready_fds: - # Spark told us to exit by closing stdin - shutdown(0) + try: + worker_pid = read_int(sys.stdin) + except EOFError: + # Spark told us to exit by closing stdin + shutdown(0) + try: + os.kill(worker_pid, signal.SIGKILL) + except OSError: + pass # process already died + + if listen_sock in ready_fds: sock, addr = listen_sock.accept() # Launch a worker process try: - fork_return_code = os.fork() - if fork_return_code == 0: + pid = os.fork() + if pid == 0: listen_sock.close() try: worker(sock) @@ -143,11 +153,13 @@ def manager(): os._exit(0) else: sock.close() + except OSError as e: print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) write_int(-1, outfile) # Signal that the fork failed outfile.flush() + outfile.close() sock.close() finally: shutdown(1) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 16fb5a925622034e3177f83514ac824714e9094b..acc3c303716215c16ca2a671c670a9eaef9a5d7d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -790,6 +790,57 @@ class TestDaemon(unittest.TestCase): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) +class TestWorker(PySparkTestCase): + def test_cancel_task(self): + temp = tempfile.NamedTemporaryFile(delete=True) + temp.close() + path = temp.name + def sleep(x): + import os, time + with open(path, 'w') as f: + f.write("%d %d" % (os.getppid(), os.getpid())) + time.sleep(100) + + # start job in background thread + def run(): + self.sc.parallelize(range(1)).foreach(sleep) + import threading + t = threading.Thread(target=run) + t.daemon = True + t.start() + + daemon_pid, worker_pid = 0, 0 + while True: + if os.path.exists(path): + data = open(path).read().split(' ') + daemon_pid, worker_pid = map(int, data) + break + time.sleep(0.1) + + # cancel jobs + self.sc.cancelAllJobs() + t.join() + + for i in range(50): + try: + os.kill(worker_pid, 0) + time.sleep(0.1) + except OSError: + break # worker was killed + else: + self.fail("worker has not been killed after 5 seconds") + + try: + os.kill(daemon_pid, 0) + except OSError: + self.fail("daemon had been killed") + + def test_fd_leak(self): + N = 1100 # fd limit is 1024 by default + rdd = self.sc.parallelize(range(N), N) + self.assertEquals(N, rdd.count()) + + class TestSparkSubmit(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp()