From 62c4781400dd908c2fccdcebf0dc816ff0cb8ed4 Mon Sep 17 00:00:00 2001
From: Jey Kottalam <jey@cs.berkeley.edu>
Date: Fri, 10 May 2013 15:48:48 -0700
Subject: [PATCH] Add tests and fixes for Python daemon shutdown

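Use a multiprocessing.Event for the daemon's shutdown flag so that the
manager and its forked workers all observe the same state, and send
SIGHUP to the worker pool from a finally block so workers terminate
even if the manager exits abnormally. SparkEnv.stop() now also stops
the Python daemons, and unit tests cover daemon termination via both
stdin closure and SIGTERM.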
---
 core/src/main/scala/spark/SparkEnv.scala      |  1 +
 .../scala/spark/api/python/PythonWorker.scala |  4 ++
 python/pyspark/daemon.py                      | 46 ++++++++++---------
 python/pyspark/tests.py                       | 43 +++++++++++++++++
 python/pyspark/worker.py                      |  2 +
 5 files changed, 74 insertions(+), 22 deletions(-)

diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 5691e24c32..5b55d45212 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -44,6 +44,7 @@ class SparkEnv (
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorker]()
 
   def stop() {
+    pythonWorkers.foreach { case (key, worker) => worker.stop() }
     httpFileServer.stop()
     mapOutputTracker.stop()
     shuffleFetcher.stop()
diff --git a/core/src/main/scala/spark/api/python/PythonWorker.scala b/core/src/main/scala/spark/api/python/PythonWorker.scala
index 8ee3c6884f..74c8c6d37a 100644
--- a/core/src/main/scala/spark/api/python/PythonWorker.scala
+++ b/core/src/main/scala/spark/api/python/PythonWorker.scala
@@ -33,6 +33,10 @@ private[spark] class PythonWorker(pythonExec: String, envVars: Map[String, Strin
     }
   }
 
+  def stop() {
+    stopDaemon()
+  }
+
   private def startDaemon() {
     synchronized {
       // Is it already running?
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 642f30b2b9..ab9c19df57 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -12,7 +12,7 @@ try:
 except NotImplementedError:
     POOLSIZE = 4
 
-should_exit = False
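+# An Event's state is shared across fork(), so the manager and every worker
+# observe the same shutdown flag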
+should_exit = multiprocessing.Event()
 
 
 def worker(listen_sock):
@@ -21,14 +21,13 @@ def worker(listen_sock):
 
     # Manager sends SIGHUP to request termination of workers in the pool
     def handle_sighup(signum, frame):
-        global should_exit
-        should_exit = True
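+        # The manager sets the event before sending SIGHUP, so it should
+        # already be set by the time a worker receives the signal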
+        assert should_exit.is_set()
     signal(SIGHUP, handle_sighup)
 
-    while not should_exit:
+    while not should_exit.is_set():
         # Wait until a client arrives or we have to exit
         sock = None
-        while not should_exit and sock is None:
+        while not should_exit.is_set() and sock is None:
             try:
                 sock, addr = listen_sock.accept()
             except EnvironmentError as err:
@@ -36,8 +35,8 @@ def worker(listen_sock):
                     raise
 
         if sock is not None:
-            # Fork a child to handle the client
-            if os.fork() == 0:
+            # Fork: the parent leaves the pool to handle the client, and
+            # the new child takes its place as a pool member
+            if os.fork() != 0:
                 # Leave the worker pool
                 signal(SIGHUP, SIG_DFL)
                 listen_sock.close()
@@ -50,7 +49,7 @@ def worker(listen_sock):
             else:
                 sock.close()
 
-    assert should_exit
+    assert should_exit.is_set()
     os._exit(0)
 
 
@@ -73,9 +72,7 @@ def manager():
     listen_sock.close()
 
     def shutdown():
-        global should_exit
-        os.kill(0, SIGHUP)
-        should_exit = True
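+        # SIGHUP is now sent from the finally block below, after the
+        # event has been set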
+        should_exit.set()
 
     # Gracefully exit on SIGTERM, don't die on SIGHUP
     signal(SIGTERM, lambda signum, frame: shutdown())
@@ -85,8 +82,8 @@ def manager():
     def handle_sigchld(signum, frame):
         try:
             pid, status = os.waitpid(0, os.WNOHANG)
-            if (pid, status) != (0, 0) and not should_exit:
-                raise RuntimeError("pool member crashed: %s, %s" % (pid, status))
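+            # A worker that finishes handling a client exits with status 0,
+            # so only a nonzero status indicates a crash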
+            if status != 0 and not should_exit.is_set():
+                raise RuntimeError("worker crashed: %s, %s" % (pid, status))
         except EnvironmentError as err:
             if err.errno not in (ECHILD, EINTR):
                 raise
@@ -94,15 +91,20 @@ def manager():
 
     # Initialization complete
     sys.stdout.close()
-    while not should_exit:
-        try:
-            # Spark tells us to exit by closing stdin
-            if sys.stdin.read() == '':
-                shutdown()
-        except EnvironmentError as err:
-            if err.errno != EINTR:
-                shutdown()
-                raise
+    try:
+        while not should_exit.is_set():
+            try:
+                # Spark tells us to exit by closing stdin
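+                # (os.read returns '' at EOF, unlike the buffered sys.stdin.read)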
+                if os.read(0, 512) == '':
+                    shutdown()
+            except EnvironmentError as err:
+                if err.errno != EINTR:
+                    shutdown()
+                    raise
+    finally:
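+        # Runs even if the read loop raised, so workers are always notified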
+        should_exit.set()
+        # Send SIGHUP to notify workers of shutdown
+        os.kill(0, SIGHUP)
 
 
 if __name__ == '__main__':
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a1962d267..1e34d47365 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -12,6 +12,7 @@ import unittest
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import SPARK_HOME
+from pyspark.serializers import read_int
 
 
 class PySparkTestCase(unittest.TestCase):
@@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
         self.sc.parallelize([1]).foreach(func)
 
 
+class TestDaemon(unittest.TestCase):
+    def connect(self, port):
+        from socket import socket, AF_INET, SOCK_STREAM
+        sock = socket(AF_INET, SOCK_STREAM)
+        sock.connect(('127.0.0.1', port))
+        # send a split index of -1 ("\xFF\xFF\xFF\xFF", big-endian) to shut down the worker
+        sock.send("\xFF\xFF\xFF\xFF")
+        sock.close()
+        return True
+
+    def do_termination_test(self, terminator):
+        from subprocess import Popen, PIPE
+        from errno import ECONNREFUSED
+
+        # start daemon
+        daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
+        daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
+
+        # read the port number
+        port = read_int(daemon.stdout)
+
+        # daemon should accept connections
+        self.assertTrue(self.connect(port))
+
+        # request shutdown
+        terminator(daemon)
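+        # give the daemon a moment to shut down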
+        time.sleep(1)
+
+        # daemon should no longer accept connections
+        with self.assertRaises(EnvironmentError) as trap:
+            self.connect(port)
+        self.assertEqual(trap.exception.errno, ECONNREFUSED)
+
+    def test_termination_stdin(self):
+        """Ensure that daemon and workers terminate when stdin is closed."""
+        self.do_termination_test(lambda daemon: daemon.stdin.close())
+
+    def test_termination_sigterm(self):
+        """Ensure that daemon and workers terminate on SIGTERM."""
+        from signal import SIGTERM
+        self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 94d612ea6e..f76ee3c236 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,6 +30,8 @@ def report_times(outfile, boot, init, finish):
 def main(infile, outfile):
     boot_time = time.time()
     split_index = read_int(infile)
+    if split_index == -1:  # for unit tests
+        return
     spark_files_dir = load_pickle(read_with_length(infile))
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
-- 
GitLab