diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index aba713cb4267a94f5168d36d580bd1a8220d9b9d..906a00b0bd17c7d8ebc1f0213b3823f78e02ba97 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -68,6 +68,7 @@ class SparkEnv (
     val shuffleMemoryManager: ShuffleMemoryManager,
     val conf: SparkConf) extends Logging {
 
+  private[spark] var isStopped = false
   private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
 
   // A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -75,6 +76,7 @@ class SparkEnv (
   private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
 
   private[spark] def stop() {
+    isStopped = true
     pythonWorkers.foreach { case(key, worker) => worker.stop() }
     Option(httpFileServer).foreach(_.stop())
     mapOutputTracker.stop()
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 29ca751519abdf90b3ca1b006321cc2c4465dde1..163dca6cade5a3e10c8874b73986462fcf687233 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,6 +75,7 @@ private[spark] class PythonRDD(
     var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
+      writerThread.join()
       if (reuse_worker && complete_cleanly) {
         env.releasePythonWorker(pythonExec, envVars.toMap, worker)
       } else {
@@ -145,7 +146,9 @@ private[spark] class PythonRDD(
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
               }
-               complete_cleanly = true
+              if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+                complete_cleanly = true
+              }
               null
           }
         } catch {
@@ -154,6 +157,10 @@ private[spark] class PythonRDD(
             logDebug("Exception thrown after task interruption", e)
             throw new TaskKilledException
 
+          case e: Exception if env.isStopped =>
+            logDebug("Exception thrown after context is stopped", e)
+            null  // exit silently
+
           case e: Exception if writerThread.exception.isDefined =>
             logError("Python worker exited unexpectedly (crashed)", e)
             logError("This may have been caused by a prior exception:", writerThread.exception.get)
@@ -235,6 +242,7 @@ private[spark] class PythonRDD(
         // Data values
         PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -306,6 +314,7 @@ private object SpecialLengths {
   val END_OF_DATA_SECTION = -1
   val PYTHON_EXCEPTION_THROWN = -2
   val TIMING_DATA = -3
+  val END_OF_STREAM = -4
 }
 
 private[spark] object PythonRDD extends Logging {
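
The SpecialLengths constants above are control codes that share the JVM-to-worker stream with ordinary length-prefixed payloads: a 4-byte length that is negative is one of these markers rather than the size of a payload. Below is a minimal sketch of that framing, assuming the big-endian signed-int encoding used by pyspark.serializers.read_int/write_int; the helper and variable names are illustrative only, not part of the patch.

import struct
from io import BytesIO

# The full marker set, mirroring SpecialLengths.
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
END_OF_STREAM = -4   # new marker introduced by this patch


def write_int(stream, value):
    stream.write(struct.pack("!i", value))


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


def read_frame(stream):
    """Return ('special', marker) for a negative length, else ('data', payload)."""
    length = read_int(stream)
    if length < 0:
        return "special", length
    return "data", stream.read(length)


# One payload followed by the two markers, in the order the JVM now writes them.
buf = BytesIO()
write_int(buf, 5)
buf.write(b"hello")
write_int(buf, END_OF_DATA_SECTION)
write_int(buf, END_OF_STREAM)
buf.seek(0)

assert read_frame(buf) == ("data", b"hello")
assert read_frame(buf) == ("special", END_OF_DATA_SECTION)
assert read_frame(buf) == ("special", END_OF_STREAM)
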
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 64d6202acb27deded1385dd7609ed8b719a93dd6..dbb34775d9ac50bf9864c60d2c40b810521fa2e2 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@ import time
 import gc
 from errno import EINTR, ECHILD, EAGAIN
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
 from pyspark.worker import main as worker_main
 from pyspark.serializers import read_int, write_int
 
@@ -46,6 +46,9 @@ def worker(sock):
     signal.signal(SIGHUP, SIG_DFL)
     signal.signal(SIGCHLD, SIG_DFL)
     signal.signal(SIGTERM, SIG_DFL)
+    # restore Python's default SIGINT handler;
+    # it's useful for debugging (shows the stacktrace before exit)
+    signal.signal(SIGINT, signal.default_int_handler)
 
     # Read the socket using fdopen instead of socket.makefile() because the latter
     # seems to be very slow; note that we need to dup() the file descriptor because
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 08a0f0d8ffb3eccbb957d7e20fee85b445c70b4d..904bd9f2652d39a65e2531acb0a4ccd09ec69dd9 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -80,6 +80,7 @@ class SpecialLengths(object):
     END_OF_DATA_SECTION = -1
     PYTHON_EXCEPTION_THROWN = -2
     TIMING_DATA = -3
+    END_OF_STREAM = -4
 
 
 class Serializer(object):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1a8e4150e63c3e40111776e72d35e267e677ad6e..7a2107ec326ee833b07970b4701aef338080a755 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,7 +31,7 @@ import tempfile
 import time
 import zipfile
 import random
-from platform import python_implementation
+import threading
 
 if sys.version_info[:2] <= (2, 6):
     try:
@@ -1380,6 +1380,23 @@ class WorkerTests(PySparkTestCase):
         self.assertEqual(sum(range(100)), acc2.value)
         self.assertEqual(sum(range(100)), acc1.value)
 
+    def test_reuse_worker_after_take(self):
+        rdd = self.sc.parallelize(range(100000), 1)
+        self.assertEqual(0, rdd.first())
+
+        def count():
+            try:
+                rdd.count()
+            except Exception:
+                pass
+
+        t = threading.Thread(target=count)
+        t.daemon = True
+        t.start()
+        t.join(5)
+        self.assertTrue(not t.isAlive())
+        self.assertEqual(100000, rdd.count())
+
 
 class SparkSubmitTests(unittest.TestCase):
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8257dddfee1c39188b1afd01e8cb3da20ed310fc..2bdccb5e93f0953fc68f8a057e9dd8311ea2eda0 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -57,7 +57,7 @@ def main(infile, outfile):
         boot_time = time.time()
         split_index = read_int(infile)
         if split_index == -1:  # for unit tests
-            return
+            exit(-1)
 
         # initialize global state
         shuffle.MemoryBytesSpilled = 0
@@ -111,7 +111,6 @@ def main(infile, outfile):
         try:
             write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
             write_with_length(traceback.format_exc(), outfile)
-            outfile.flush()
         except IOError:
             # JVM close the socket
             pass
@@ -131,6 +130,14 @@ def main(infile, outfile):
     for (aid, accum) in _accumulatorRegistry.items():
         pickleSer._write_with_length((aid, accum._value), outfile)
 
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell the JVM not to reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        exit(-1)
+
 
 if __name__ == '__main__':
     # Read a local port to connect to from stdin
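
Taken together, the Scala and Python changes form a small handshake: after the data section the JVM writes END_OF_STREAM, the worker echoes it back once it has finished writing its own output and accumulator updates, and the JVM only marks the task complete_cleanly (and therefore reuses the worker) when it reads that echo. Below is a rough end-to-end sketch of just that handshake, over an in-process socket pair rather than Spark's real wiring, again assuming the 4-byte big-endian int framing; everything here is illustrative, not part of the patch.

import socket
import struct

END_OF_STREAM = -4


def write_int(f, value):
    f.write(struct.pack("!i", value))
    f.flush()


def read_int(f):
    return struct.unpack("!i", f.read(4))[0]


# Unix-only; stands in for the JVM <-> worker socket.
jvm_sock, worker_sock = socket.socketpair()
jvm = jvm_sock.makefile("rwb")
worker = worker_sock.makefile("rwb")

# "JVM" side: end of the task's input stream.
write_int(jvm, END_OF_STREAM)

# Worker side: it consumed everything, so echo the marker back.
if read_int(worker) == END_OF_STREAM:
    write_int(worker, END_OF_STREAM)

# "JVM" side: only a clean echo makes the worker eligible for reuse.
assert read_int(jvm) == END_OF_STREAM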