diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index dc012cc381346643d05db8954b5876e12d764bb3..fc4812753d005b0087203f32090fc0b913844d4b 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -42,9 +42,13 @@ class TaskContext(
   // List of callback functions to execute when the task completes.
   @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
+  // Set to true when the task is completed, before the onCompleteCallbacks are executed.
+  @volatile var completed: Boolean = false
+
   /**
    * Add a callback function to be executed on task completion. An example use
    * is for HadoopRDD to register a callback to close the input stream.
+   * This will be called in any case: success, failure, or cancellation.
    * @param f Callback function.
    */
   def addOnCompleteCallback(f: () => Unit) {
@@ -52,6 +56,7 @@ class TaskContext(
   }
 
   def executeOnCompleteCallbacks() {
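+    // Set before running the callbacks so that they (and other threads) see the task as completed.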
+    completed = true
     // Process complete callbacks in the reverse order of registration
     onCompleteCallbacks.reverse.foreach{_()}
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 61407007087c6820216c262855ee624b7e8682ff..fecd9762f3f608190a0ce89fe72c928057fce0a3 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
     val env = SparkEnv.get
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
-    // Ensure worker socket is closed on task completion. Closing sockets is idempotent.
-    context.addOnCompleteCallback(() =>
+    // Thread that feeds input from our parent's iterator to the Python process; started below.
+    val writerThread = new WriterThread(env, worker, split, context)
+
+    context.addOnCompleteCallback { () =>
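+      // The writer may still be blocked writing to the worker; interrupt it now.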
+      writerThread.shutdownOnTaskCompletion()
+
+      // Clean up the worker socket. This will also cause the Python worker to exit.
       try {
         worker.close()
       } catch {
         case e: Exception => logWarning("Failed to close worker socket", e)
       }
-    )
-
-    @volatile var readerException: Exception = null
-
-    // Start a thread to feed the process input from our parent's iterator
-    new Thread("stdin writer for " + pythonExec) {
-      override def run() {
-        try {
-          SparkEnv.set(env)
-          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-          val dataOut = new DataOutputStream(stream)
-          // Partition index
-          dataOut.writeInt(split.index)
-          // sparkFilesDir
-          PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
-          // Broadcast variables
-          dataOut.writeInt(broadcastVars.length)
-          for (broadcast <- broadcastVars) {
-            dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
-          }
-          // Python includes (*.zip and *.egg files)
-          dataOut.writeInt(pythonIncludes.length)
-          for (include <- pythonIncludes) {
-            PythonRDD.writeUTF(include, dataOut)
-          }
-          dataOut.flush()
-          // Serialized command:
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-          // Data values
-          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
-          dataOut.flush()
-          worker.shutdownOutput()
-        } catch {
-
-          case e: java.io.FileNotFoundException =>
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-
-          case e: IOException =>
-            // This can happen for legitimate reasons if the Python code stops returning data
-            // before we are done passing elements through, e.g., for take(). Just log a message to
-            // say it happened (as it could also be hiding a real IOException from a data source).
-            logInfo("stdin writer to Python finished early (may not be an error)", e)
-
-          case e: Exception =>
-            // We must avoid throwing exceptions here, because the thread uncaught exception handler
-            // will kill the whole executor (see Executor).
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-        }
-      }
-    }.start()
-
-    // Necessary to distinguish between a task that has failed and a task that is finished
-    @volatile var complete: Boolean = false
-
-    // It is necessary to have a monitor thread for python workers if the user cancels with
-    // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-    // threads can block indefinitely.
-    new Thread(s"Worker Monitor for $pythonExec") {
-      override def run() {
-        // Kill the worker if it is interrupted or completed
-        // When a python task completes, the context is always set to interupted
-        while (!context.interrupted) {
-          Thread.sleep(2000)
-        }
-        if (!complete) {
-          try {
-            logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.toMap)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }.start()
-
-    /*
-     * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
-     * other completion callbacks might invalidate the input. Because interruption
-     * is not synchronous this still leaves a potential race where the interruption is
-     * processed only after the stream becomes invalid.
-     */
-    context.addOnCompleteCallback{ () =>
-      complete = true // Indicate that the task has completed successfully
-      context.interrupted = true
     }
 
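+    // Feed the worker's input and watch for task cancellation on separate daemon threads.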
+    writerThread.start()
+    new MonitorThread(env, worker, context).start()
+
     // Return an iterator that read lines from the process's stdout
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     val stdoutIterator = new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
         if (hasNext) {
-          // FIXME: can deadlock if worker is waiting for us to
-          // respond to current message (currently irrelevant because
-          // output is shutdown before we read any input)
           _nextObj = read()
         }
         obj
       }
 
       private def read(): Array[Byte] = {
-        if (readerException != null) {
-          throw readerException
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
         }
         try {
           stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                 init, finish))
-              read
+              read()
             case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
-              throw new PythonException(new String(obj, "utf-8"), readerException)
+              throw new PythonException(new String(obj, "utf-8"), writerThread.exception.orNull)
             case SpecialLengths.END_OF_DATA_SECTION =>
               // We've finished the data section of the output, but we can still
               // read some accumulator updates:
@@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag](
               Array.empty[Byte]
           }
         } catch {
-          case e: Exception if readerException != null =>
+
+          case e: Exception if context.interrupted =>
+            logDebug("Exception thrown after task interruption", e)
+            throw new TaskKilledException
+
+          case e: Exception if writerThread.exception.isDefined =>
             logError("Python worker exited unexpectedly (crashed)", e)
-            logError("Python crash may have been caused by prior exception:", readerException)
-            throw readerException
+            logError("This may have been caused by a prior exception:", writerThread.exception.get)
+            throw writerThread.exception.get
 
           case eof: EOFException =>
             throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag](
 
       def hasNext = _nextObj.length != 0
     }
-    stdoutIterator
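+    // An InterruptibleIterator makes the read loop stop promptly if the task is killed.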
+    new InterruptibleIterator(context, stdoutIterator)
   }
 
   val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+   * Python process.
+   */
+  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Exception = null
+
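+    // Run as a daemon thread so it can never prevent the executor JVM from exiting.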
+    setDaemon(true)
+
+    /** Contains the exception thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Exception] = Option(_exception)
+
+    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+    def shutdownOnTaskCompletion() {
+      assert(context.completed)
+      this.interrupt()
+    }
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(split.index)
+        // sparkFilesDir
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+        // Broadcast variables
+        dataOut.writeInt(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
+        }
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.length)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        dataOut.flush()
+        // Serialized command:
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+        // Data values
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.completed || context.interrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+      } finally {
+        Try(worker.shutdownOutput()) // kill Python worker process
+      }
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for Python workers if the user cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the worker; otherwise,
+   * the threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if the task is interrupted; keep checking until the task completes.
+      // TODO: This has a race condition if interruption occurs, as completed may still become true.
+      while (!context.interrupted && !context.completed) {
+        Thread.sleep(2000)
+      }
+      if (!context.completed) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+          env.destroyPythonWorker(pythonExec, envVars.toMap)
+        } catch {
+          case e: Exception =>
+            logError("Exception when trying to kill worker", e)
+        }
+      }
+    }
+  }
 }
 
 /** Thrown for exceptions in user Python code. */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 02b62de7e36b6ece3ca6ac884ef67f18a07b62dc..2259df0b56badeaf128f8ef79e7fa7ca2dbcda41 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,11 +17,13 @@
 
 package org.apache.spark.scheduler
 
+import scala.language.existentials
+
 import java.io._
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.mutable.HashMap
-import scala.language.existentials
+import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
@@ -196,7 +198,11 @@ private[spark] class ShuffleMapTask(
     } finally {
       // Release the writers back to the shuffle block manager.
       if (shuffle != null && shuffle.writers != null) {
-        shuffle.releaseWriters(success)
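+        // Don't let a failure here prevent the completion callbacks below from running.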
+        try {
+          shuffle.releaseWriters(success)
+        } catch {
+          case e: Exception => logError("Failed to release shuffle writers", e)
+        }
       }
       // Execute the callbacks on task completion.
       context.executeOnCompleteCallbacks()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c7dc85ea03544c696ddfb620800d281d3b4bbf28..cac133d0fcf6c21ba4bf567c0a0e3e4533ce6b32 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -453,7 +453,7 @@ class SparkContext(object):
         >>> lock = threading.Lock()
         >>> def map_func(x):
         ...     sleep(100)
-        ...     return x * x
+        ...     raise Exception("Task should have been cancelled")
         >>> def start_job(x):
         ...     global result
         ...     try:
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index eb18ec08c9139ceff767414f3da4d82053a03237..b2f226a55ec1367042b92aff861472c06f90c98a 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -74,6 +74,17 @@ def worker(listen_sock):
                 raise
     signal.signal(SIGCHLD, handle_sigchld)
 
+    # Blocks until the socket is closed, draining the input stream
+    # until recv() raises an exception or returns EOF.
+    def waitSocketClose(sock):
+        try:
+            while True:
+                # Empty string is returned upon EOF (and only then).
+                if sock.recv(4096) == '':
+                    return
+        except:
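+            # A socket error also means the connection is closed; treat it like EOF.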
+            pass
+
     # Handle clients
     while not should_exit():
         # Wait until a client arrives or we have to exit
@@ -105,7 +116,8 @@ def worker(listen_sock):
                     exit_code = exc.code
                 finally:
                     outfile.flush()
-                    sock.close()
+                    # The Scala side will close the socket upon task completion.
+                    waitSocketClose(sock)
                     os._exit(compute_real_exit_code(exit_code))
             else:
                 sock.close()