diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index dd95e406f2a8e662f21154cc801e5766a6da0b69..009ed6477584411bd4de7b6e179733c9be90e6b5 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -108,6 +108,14 @@ class SparkEnv (
       pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
+
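+  // Return a worker to the idle pool so it can be reused by a later task,
+  // rather than destroying it.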
+  private[spark]
+  def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+    }
+  }
 }
 
 object SparkEnv extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ae8010300a500d08e76bb8b92418c2166b6c3d05..ca8eef5f99edfdb1fef0dbbb2172b29b43ad5c2e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+  val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)
 
   override def getPartitions = parent.partitions
 
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
     envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+    if (reuseWorker) {
+      envVars += ("SPARK_REUSE_WORKER" -> "1")
+    }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
+    var completedCleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-
-      // Cleanup the worker socket. This will also cause the Python worker to exit.
-      try {
-        worker.close()
-      } catch {
-        case e: Exception => logWarning("Failed to close worker socket", e)
+      if (reuseWorker && completedCleanly) {
+        env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+      } else {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
       }
     }
 
@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
               }
+              completedCleanly = true
               null
           }
         } catch {
@@ -195,11 +205,26 @@ private[spark] class PythonRDD(
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of broadcasts to remove from the worker plus new broadcasts to send
+        val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+        dataOut.writeInt(cnt)
+        for (bid <- oldBids.diff(newBids)) {
+          // tell the worker to remove a broadcast it no longer needs
+          dataOut.writeLong(- bid - 1)  // bid >= 0, so the negative value signals removal
+          oldBids.remove(bid)
+        }
         for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+            oldBids.add(broadcast.id)
+          }
         }
         dataOut.flush()
         // Serialized command:
@@ -207,17 +232,18 @@ private[spark] class PythonRDD(
         dataOut.write(command)
         // Data values
         PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
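+        // Mark the end of the data section so a reused worker knows this task's input is done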
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+          worker.shutdownOutput()
 
         case e: Exception =>
           // We must avoid throwing exceptions here, because the thread uncaught exception handler
           // will kill the whole executor (see org.apache.spark.executor.Executor).
           _exception = e
-      } finally {
-        Try(worker.shutdownOutput()) // kill Python worker process
+          worker.shutdownOutput()
       }
     }
   }
@@ -278,6 +304,14 @@ private object SpecialLengths {
 private[spark] object PythonRDD extends Logging {
   val UTF8 = Charset.forName("UTF-8")
 
+  // remember the broadcasts sent to each worker
+  private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+  private def getWorkerBroadcasts(worker: Socket) = {
+    synchronized {
+      workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+    }
+  }
+
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 4c4796f6c59bada6902846443cbdbc958fa6b934..71bdf0fe1b917f00a3fd88b6ea703b98c19d0777 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
-  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
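+  // Workers returned via releaseWorker(), queued here for reuse by create()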
+  val idleWorkers = new mutable.Queue[Socket]()
+  var lastActivity = 0L
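+  // Periodically closes workers that have been idle longer than IDLE_WORKER_TIMEOUT_MS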
+  new MonitorThread().start()
 
   var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
   def create(): Socket = {
     if (useDaemon) {
+      synchronized {
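+        // Hand back an idle worker if one is available instead of forking a new one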
+        if (idleWorkers.size > 0) {
+          return idleWorkers.dequeue()
+        }
+      }
       createThroughDaemon()
     } else {
       createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     }
   }
 
+  /**
+   * Monitor the idle workers and close them after the idle timeout expires.
+   */
+  private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      while (true) {
+        // Synchronize on the factory, which owns idleWorkers and lastActivity
+        PythonWorkerFactory.this.synchronized {
+          if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+            cleanupIdleWorkers()
+            lastActivity = System.currentTimeMillis()
+          }
+        }
+        Thread.sleep(10000)
+      }
+    }
+  }
+
+  private def cleanupIdleWorkers() {
+    while (idleWorkers.length > 0) {
+      val worker = idleWorkers.dequeue()
+      try {
+        // Closing the socket causes the Python worker process to exit
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
+
   private def stopDaemon() {
     synchronized {
       if (useDaemon) {
+        cleanupIdleWorkers()
+
         // Request shutdown of existing daemon by sending SIGTERM
         if (daemon != null) {
           daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   }
 
   def stopWorker(worker: Socket) {
-    if (useDaemon) {
-      if (daemon != null) {
-        daemonWorkers.get(worker).foreach { pid =>
-          // tell daemon to kill worker by pid
-          val output = new DataOutputStream(daemon.getOutputStream)
-          output.writeInt(pid)
-          output.flush()
-          daemon.getOutputStream.flush()
+    synchronized {
+      if (useDaemon) {
+        if (daemon != null) {
+          daemonWorkers.get(worker).foreach { pid =>
+            // tell daemon to kill worker by pid
+            val output = new DataOutputStream(daemon.getOutputStream)
+            output.writeInt(pid)
+            output.flush()
+            daemon.getOutputStream.flush()
+          }
         }
+      } else {
+        simpleWorkers.get(worker).foreach(_.destroy())
       }
-    } else {
-      simpleWorkers.get(worker).foreach(_.destroy())
     }
     worker.close()
   }
+
+  def releaseWorker(worker: Socket) {
+    if (useDaemon) {
+      synchronized {
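+        // Record activity and queue the worker; the monitor thread closes it if it
+        // stays idle past the timeout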
+        lastActivity = System.currentTimeMillis()
+        idleWorkers.enqueue(worker)
+      }
+    } else {
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
+      try {
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
 }
 
 private object PythonWorkerFactory {
   val PROCESS_WAIT_TIMEOUT_MS = 10000
+  val IDLE_WORKER_TIMEOUT_MS = 60000  // kill idle workers after 1 minute
 }
diff --git a/docs/configuration.md b/docs/configuration.md
index 36178efb971039c6b03614a523f60ca9453281d7..af16489a4428145079ace61eb25452d642190760 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
     used during aggregation goes above this amount, it will spill the data into disks.
   </td>
 </tr>
+<tr>
+  <td><code>spark.python.worker.reuse</code></td>
+  <td>true</td>
+  <td>
+    Whether to reuse Python workers. If yes, a fixed pool of Python workers is kept alive,
+    so a Python process does not need to be fork()ed for every task. This is most useful
+    when there is a large broadcast, since the broadcast then does not need to be
+    transferred from the JVM to each Python worker for every task.
+  </td>
+</tr>
 <tr>
   <td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
   <td>(none)</td>
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 15445abf67147c810cad9a87f2057a37f0bea91b..64d6202acb27deded1385dd7609ed8b719a93dd6 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -23,6 +23,7 @@ import socket
 import sys
 import traceback
 import time
+import gc
 from errno import EINTR, ECHILD, EAGAIN
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -46,17 +47,6 @@ def worker(sock):
     signal.signal(SIGCHLD, SIG_DFL)
     signal.signal(SIGTERM, SIG_DFL)
 
-    # Blocks until the socket is closed by draining the input stream
-    # until it raises an exception or returns EOF.
-    def waitSocketClose(sock):
-        try:
-            while True:
-                # Empty string is returned upon EOF (and only then).
-                if sock.recv(4096) == '':
-                    return
-        except:
-            pass
-
     # Read the socket using fdopen instead of socket.makefile() because the latter
     # seems to be very slow; note that we need to dup() the file descriptor because
     # otherwise writes also cause a seek that makes us miss data on the read side.
@@ -64,17 +54,13 @@ def worker(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        # Acknowledge that the fork was successful
-        write_int(os.getpid(), outfile)
-        outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
-        exit_code = exc.code
+        exit_code = compute_real_exit_code(exc.code)
     finally:
         outfile.flush()
-        # The Scala side will close the socket upon task completion.
-        waitSocketClose(sock)
-        os._exit(compute_real_exit_code(exit_code))
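+        # Only exit on failure; on success return to manager()'s loop so the
+        # process can serve another task when reuse is enabled.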
+        if exit_code:
+            os._exit(exit_code)
 
 
 # Cleanup zombie children
@@ -111,6 +97,8 @@ def manager():
     signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
     signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
 
+    # When SPARK_REUSE_WORKER is set, forked workers keep serving tasks in a loop
+    # instead of exiting after a single task.
+    reuse = os.environ.get("SPARK_REUSE_WORKER")
+
     # Initialization complete
     try:
         while True:
@@ -163,7 +151,19 @@ def manager():
                     # in child process
                     listen_sock.close()
                     try:
-                        worker(sock)
+                        # Acknowledge that the fork was successful
+                        outfile = sock.makefile("w")
+                        write_int(os.getpid(), outfile)
+                        outfile.flush()
+                        outfile.close()
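+                        # Serve tasks in a loop when reuse is enabled; otherwise run a
+                        # single task, wait for the JVM to close the socket, then exit.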
+                        while True:
+                            worker(sock)
+                            if not reuse:
+                                # wait for the JVM to close the socket, then exit
+                                while sock.recv(1024):
+                                    pass
+                                break
+                            gc.collect()
                     except:
                         traceback.print_exc()
                         os._exit(1)
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index bb60d3d0c8463bfb431109146d3a0b32f4af9cd5..68f6033616726d06f187350ba1d51487169d74a1 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -21,7 +21,7 @@ import numpy
 from numpy import ndarray, float64, int64, int32, array_equal, array
 from pyspark import SparkContext, RDD
 from pyspark.mllib.linalg import SparseVector
-from pyspark.serializers import Serializer
+from pyspark.serializers import FramedSerializer
 
 
 """
@@ -451,18 +451,16 @@ def _serialize_rating(r):
     return ba
 
 
-class RatingDeserializer(Serializer):
+class RatingDeserializer(FramedSerializer):
 
-    def loads(self, stream):
-        length = struct.unpack("!i", stream.read(4))[0]
-        ba = stream.read(length)
-        res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4)
+    def loads(self, string):
+        res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4)
         return int(res[0]), int(res[1]), res[2]
 
     def load_stream(self, stream):
         while True:
             try:
-                yield self.loads(stream)
+                yield self._read_with_length(stream)
             except struct.error:
                 return
             except EOFError:
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a5f9341e819a98f22f94e90d4ac9cc697b815355..ec3c6f055441d8e9850df188c59f8c95b0a97fb7 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -144,6 +144,8 @@ class FramedSerializer(Serializer):
 
     def _read_with_length(self, stream):
         length = read_int(stream)
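+        # END_OF_DATA_SECTION from the JVM marks the end of this task's input;
+        # treat it as EOF instead of waiting for the socket to close.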
+        if length == SpecialLengths.END_OF_DATA_SECTION:
+            raise EOFError
         obj = stream.read(length)
         if obj == "":
             raise EOFError
@@ -438,6 +440,8 @@ class UTF8Deserializer(Serializer):
 
     def loads(self, stream):
         length = read_int(stream)
+        if length == SpecialLengths.END_OF_DATA_SECTION:
+            raise EOFError
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b687d695b01c4e704a167c06fd67764316c167e0..747cd1767de7bee0a228f08b4b9d69c1b3bec901 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1222,11 +1222,46 @@ class TestWorker(PySparkTestCase):
         except OSError:
             self.fail("daemon had been killed")
 
+        # run a normal job
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
     def test_fd_leak(self):
         N = 1100  # fd limit is 1024 by default
         rdd = self.sc.parallelize(range(N), N)
         self.assertEquals(N, rdd.count())
 
+    def test_after_exception(self):
+        def raise_exception(_):
+            raise Exception()
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_after_jvm_exception(self):
+        tempFile = tempfile.NamedTemporaryFile(delete=False)
+        tempFile.write("Hello World!")
+        tempFile.close()
+        data = self.sc.textFile(tempFile.name, 1)
+        filtered_data = data.filter(lambda x: True)
+        self.assertEqual(1, filtered_data.count())
+        os.unlink(tempFile.name)
+        self.assertRaises(Exception, lambda: filtered_data.count())
+
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_accumulator_when_reuse_worker(self):
+        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
+        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
+        self.assertEqual(sum(range(100)), acc1.value)
+
+        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+        self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
+        self.assertEqual(sum(range(100)), acc2.value)
+        self.assertEqual(sum(range(100)), acc1.value)
+
 
 class TestSparkSubmit(unittest.TestCase):
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6805063e067988feddc86dfe37731994472025e5..61b8a74d060e81f16866f3ee5ee770a9acdab14d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -69,9 +69,14 @@ def main(infile, outfile):
         ser = CompressedSerializer(pickleSer)
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
-            value = ser._read_with_length(infile)
-            _broadcastRegistry[bid] = Broadcast(bid, value)
+            if bid >= 0:
+                value = ser._read_with_length(infile)
+                _broadcastRegistry[bid] = Broadcast(bid, value)
+            else:
+                # a negative id tells the worker to drop a broadcast it no longer needs
+                bid = - bid - 1
+                _broadcastRegistry.pop(bid, None)
 
+        # Reset accumulator values left over from a previous task on a reused worker
+        _accumulatorRegistry.clear()
         command = pickleSer._read_with_length(infile)
         (func, deserializer, serializer) = command
         init_time = time.time()