Commit 62c47814 authored by Jey Kottalam

Add tests and fixes for Python daemon shutdown

parent c79a6078
@@ -44,6 +44,7 @@ class SparkEnv (
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]()

   def stop() {
+    pythonWorkers.foreach { case(key, worker) => worker.stop() }
     httpFileServer.stop()
     mapOutputTracker.stop()
     shuffleFetcher.stop()
......
@@ -33,6 +33,10 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, String])
     }
   }

+  def stop() {
+    stopDaemon
+  }
+
   private def startDaemon() {
     synchronized {
       // Is it already running?
......
@@ -12,7 +12,7 @@ try:
 except NotImplementedError:
     POOLSIZE = 4

-should_exit = False
+should_exit = multiprocessing.Event()

 def worker(listen_sock):
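
The switch from a plain module-level boolean to multiprocessing.Event is the heart of the fix: pool workers are forked from the manager, so should_exit = True executed in one process is never seen by the others, while an Event is backed by shared state that survives fork(). A minimal sketch of that behavior (not part of the commit):

    import os
    import multiprocessing

    # Sketch only: an Event created before fork() lives in shared
    # memory, so set() in one process is observed in the other.
    flag = multiprocessing.Event()

    if os.fork() == 0:
        flag.wait()   # child blocks until the parent signals
        os._exit(0)
    else:
        flag.set()    # visible across the fork boundary
        os.wait()     # reap the child once it sees the event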
@@ -21,14 +21,13 @@ def worker(listen_sock):
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(signum, frame):
-        global should_exit
-        should_exit = True
+        assert should_exit.is_set()
     signal(SIGHUP, handle_sighup)

-    while not should_exit:
+    while not should_exit.is_set():
         # Wait until a client arrives or we have to exit
         sock = None
-        while not should_exit and sock is None:
+        while not should_exit.is_set() and sock is None:
             try:
                 sock, addr = listen_sock.accept()
             except EnvironmentError as err:
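
Since the manager now sets the event before broadcasting SIGHUP, the handler's only remaining job is to interrupt the blocking accept(): under Python 2 semantics the interrupted call raises an EnvironmentError with errno EINTR, returning control to the while not should_exit.is_set() checks, which is why a bare assert suffices. A self-contained sketch of the mechanism, with SIGALRM standing in for SIGHUP:

    import errno
    import signal
    import socket

    # Sketch only: any caught signal interrupts a blocking accept(),
    # which surfaces as an EnvironmentError with errno == EINTR.
    signal.signal(signal.SIGALRM, lambda signum, frame: None)

    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listen_sock.bind(('127.0.0.1', 0))
    listen_sock.listen(1)

    signal.alarm(1)  # deliver SIGALRM to this process in one second
    try:
        listen_sock.accept()
    except EnvironmentError as err:
        assert err.errno == errno.EINTR  # loop re-checks should_exit here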
@@ -36,8 +35,8 @@ def worker(listen_sock):
                 raise

         if sock is not None:
-            # Fork a child to handle the client
-            if os.fork() == 0:
+            # Fork to handle the client
+            if os.fork() != 0:
                 # Leave the worker pool
                 signal(SIGHUP, SIG_DFL)
                 listen_sock.close()
@@ -50,7 +49,7 @@ def worker(listen_sock):
         else:
             sock.close()

-    assert should_exit
+    assert should_exit.is_set()
     os._exit(0)
@@ -73,9 +72,7 @@ def manager():
     listen_sock.close()

     def shutdown():
-        global should_exit
-        os.kill(0, SIGHUP)
-        should_exit = True
+        should_exit.set()

     # Gracefully exit on SIGTERM, don't die on SIGHUP
     signal(SIGTERM, lambda signum, frame: shutdown())
@@ -85,8 +82,8 @@ def manager():
     def handle_sigchld(signum, frame):
         try:
             pid, status = os.waitpid(0, os.WNOHANG)
-            if (pid, status) != (0, 0) and not should_exit:
-                raise RuntimeError("pool member crashed: %s, %s" % (pid, status))
+            if status != 0 and not should_exit.is_set():
+                raise RuntimeError("worker crashed: %s, %s" % (pid, status))
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
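
The old guard (pid, status) != (0, 0) treated any reaped child as a crash, even one that exited cleanly with status 0; checking status != 0 only raises on abnormal exits. For reference, a sketch (not from the commit) of how the status word from os.waitpid decomposes:

    import os

    # Sketch only: status == 0 is a clean exit; anything else encodes
    # a non-zero exit code or the signal that killed the child.
    def describe_status(status):
        if os.WIFEXITED(status):
            return "exited with code %d" % os.WEXITSTATUS(status)
        if os.WIFSIGNALED(status):
            return "killed by signal %d" % os.WTERMSIG(status)
        return "stopped or continued"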
@@ -94,15 +91,20 @@ def manager():
     # Initialization complete
     sys.stdout.close()
-    while not should_exit:
-        try:
-            # Spark tells us to exit by closing stdin
-            if sys.stdin.read() == '':
-                shutdown()
-        except EnvironmentError as err:
-            if err.errno != EINTR:
-                shutdown()
-                raise
+    try:
+        while not should_exit.is_set():
+            try:
+                # Spark tells us to exit by closing stdin
+                if os.read(0, 512) == '':
+                    shutdown()
+            except EnvironmentError as err:
+                if err.errno != EINTR:
+                    shutdown()
+                    raise
+    finally:
+        should_exit.set()
+        # Send SIGHUP to notify workers of shutdown
+        os.kill(0, SIGHUP)

 if __name__ == '__main__':
......
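
Two shutdown fixes land in this last daemon.py hunk. Replacing sys.stdin.read() with os.read(0, 512) sidesteps Python's buffered file object, which blocks until end of stream: os.read is a single read(2) call that returns '' as soon as Spark closes the pipe, and a signal arriving mid-read surfaces cleanly as EnvironmentError with EINTR so the loop can re-check the event. Wrapping the loop in try/finally also guarantees that the event is set and SIGHUP reaches the pool on every exit path, including the SIGTERM-initiated one. A condensed sketch of the EOF-detection idiom (function name hypothetical):

    import errno
    import os

    # Sketch only: block until the parent closes our stdin, re-checking
    # a shutdown flag whenever a signal interrupts the raw read.
    def wait_for_stdin_eof(should_exit):
        while not should_exit.is_set():
            try:
                if os.read(0, 512) == '':  # EOF: parent closed the pipe
                    return
            except EnvironmentError as err:
                if err.errno != errno.EINTR:
                    raise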
@@ -12,6 +12,7 @@ import unittest
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int


 class PySparkTestCase(unittest.TestCase):
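
read_int comes from PySpark's serializers module and is used below to read the daemon's port number off its stdout; the way it is used implies the usual PySpark framing of a 4-byte big-endian signed integer. Roughly, under that assumption:

    import struct

    # Sketch of read_int (assumed behavior matching its use here):
    # consume exactly four bytes and decode a big-endian signed int32.
    def read_int(stream):
        data = stream.read(4)
        if not data:
            raise EOFError
        return struct.unpack("!i", data)[0]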
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
         self.sc.parallelize([1]).foreach(func)


+class TestDaemon(unittest.TestCase):
+    def connect(self, port):
+        from socket import socket, AF_INET, SOCK_STREAM
+        sock = socket(AF_INET, SOCK_STREAM)
+        sock.connect(('127.0.0.1', port))
+        # send a split index of -1 to shut down the worker
+        sock.send("\xFF\xFF\xFF\xFF")
+        sock.close()
+        return True
+
+    def do_termination_test(self, terminator):
+        from subprocess import Popen, PIPE
+        from errno import ECONNREFUSED
+
+        # start daemon
+        daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+        # read the port number
+        port = read_int(daemon.stdout)
+
+        # daemon should accept connections
+        self.assertTrue(self.connect(port))
+
+        # request shutdown
+        terminator(daemon)
+        time.sleep(1)
+
+        # daemon should no longer accept connections
+        with self.assertRaises(EnvironmentError) as trap:
+            self.connect(port)
+        self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+    def test_termination_stdin(self):
+        """Ensure that daemon and workers terminate when stdin is closed."""
+        self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+    def test_termination_sigterm(self):
+        """Ensure that daemon and workers terminate on SIGTERM."""
+        from signal import SIGTERM
+        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
+
 if __name__ == "__main__":
     unittest.main()
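
The four bytes "\xFF\xFF\xFF\xFF" that connect() sends are exactly the big-endian encoding of -1, the sentinel split index that the worker main() hunk below bails out on before touching any task state. A one-line check of the framing:

    import struct

    # The test's magic bytes are just -1 framed the way read_int
    # expects it: big-endian, signed, 32-bit.
    assert struct.pack("!i", -1) == b"\xff\xff\xff\xff"
    assert struct.unpack("!i", b"\xff\xff\xff\xff")[0] == -1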
@@ -30,6 +30,8 @@ def report_times(outfile, boot, init, finish):
 def main(infile, outfile):
     boot_time = time.time()
     split_index = read_int(infile)
+    if split_index == -1:  # for unit tests
+        return
     spark_files_dir = load_pickle(read_with_length(infile))
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
......