diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b1cec0f6472b0ea3ce1181b0d0fabe9498653ae7..8d4a53b4ca9b05d73d5b42f53d6c4e23696b5dce 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,26 +19,26 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.language.existentials
+import scala.util.control.NonFatal
 
 import com.google.common.base.Charsets.UTF_8
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
 import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
 private[spark] class PythonRDD(
     @transient parent: RDD[_],
     command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
-   * This method will return an iterator of an array that contains all elements in the RDD
+   * This method serves an iterator over an array that contains all elements in the RDD
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
       partitions: JArrayList[Int],
-      allowLocal: Boolean): Iterator[Array[Byte]] = {
+      allowLocal: Boolean): Int = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
-    flattenedPartition.iterator
+    serveIterator(flattenedPartition.iterator,
+      s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+  }
+
+  /**
+   * A helper function to collect an RDD as an iterator, then serve it via a local socket.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
+   */
+  def collectAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
     dataOut.write(bytes)
   }
 
-  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
-    import scala.collection.JavaConverters._
-    writeToFile(items.asScala, filename)
-  }
+  /**
+   * Create a socket server and a background thread to serve the data in `items`.
+   *
+   * The socket server accepts only one connection, and closes itself if no
+   * connection is made within 3 seconds.
+   *
+   * Once a connection is accepted, it serializes all the data in `items`
+   * and sends them over that connection.
+   *
+   * The thread terminates once all the data have been sent or an exception occurs.
+   */
+  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+    val serverSocket = new ServerSocket(0, 1)
+    serverSocket.setReuseAddress(true)
+    // Close the socket if no connection arrives within 3 seconds
+    serverSocket.setSoTimeout(3000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run() {
+        try {
+          val sock = serverSocket.accept()
+          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          try {
+            writeIteratorToStream(items, out)
+          } finally {
+            out.close()
+          }
+        } catch {
+          case NonFatal(e) =>
+            logError("Error while sending iterator", e)
+        } finally {
+          serverSocket.close()
+        }
+      }
+    }.start()
 
-  def writeToFile[T](items: Iterator[T], filename: String) {
-    val file = new DataOutputStream(new FileOutputStream(filename))
-    writeIteratorToStream(items, file)
-    file.close()
+    serverSocket.getLocalPort
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
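The scaladoc on serveIterator above describes the full lifecycle: bind an ephemeral port, accept at most one connection (giving up after 3 seconds), stream the serialized items, then shut the socket down. For illustration only, the same pattern can be sketched in standalone Python; this is not Spark code, and the helper names serve_items and read_items are hypothetical:

    import pickle
    import socket
    import threading

    def serve_items(items, timeout=3.0):
        # Bind an ephemeral port on loopback and serve the pickled items to a
        # single client from a daemon thread; give up if nobody connects in time.
        server = socket.socket()
        server.bind(("localhost", 0))
        server.listen(1)
        server.settimeout(timeout)
        port = server.getsockname()[1]

        def run():
            try:
                conn, _ = server.accept()          # at most one connection
                out = conn.makefile("wb", 65536)
                try:
                    for item in items:
                        pickle.dump(item, out)
                finally:
                    out.close()
            except (socket.timeout, socket.error):
                pass                               # timed out or client dropped
            finally:
                server.close()

        t = threading.Thread(target=run)
        t.daemon = True
        t.start()
        return port                                # the client connects here

    def read_items(port):
        # Client side, mirroring what _load_from_socket does with Spark's
        # serializers: connect to the port and read items until EOF.
        sock = socket.socket()
        try:
            sock.connect(("localhost", port))
            rf = sock.makefile("rb", 65536)
            while True:
                try:
                    yield pickle.load(rf)
                except EOFError:
                    return
        finally:
            sock.close()

For example, list(read_items(serve_items(["a", "b", "c"]))) returns the original list.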
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6011caf9f1c5a35a020b41f9874b892663fad703..78dccc40470e3bdd447145a0d318477aa4922139 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,8 @@ import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
+from py4j.java_collections import ListConverter
+
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
@@ -30,13 +32,11 @@ from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, AutoBatchedSerializer, NoOpSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
 from pyspark.profiler import ProfilerCollector, BasicProfiler
 
-from py4j.java_collections import ListConverter
-
 
 __all__ = ['SparkContext']
 
@@ -59,7 +59,6 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeToFile = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -221,7 +220,6 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = gateway or launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
 
             if instance:
                 if (SparkContext._active_spark_context and
@@ -840,8 +838,9 @@ class SparkContext(object):
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
-        return list(mappedRDD._collect_iterator_through_file(it))
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+                                          allowLocal)
+        return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
 
     def show_profiles(self):
         """ Print the profile stats to stdout """
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cb12fed98c53d172529d6c2b83bcdd7507e31620..bf17f513c0bc3e01ef084bba64d45f2b159333e0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -19,7 +19,6 @@ import copy
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
-import os
 import sys
 import shlex
 from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@ import warnings
 import heapq
 import bisect
 import random
+import socket
 from math import sqrt, log, isinf, isnan, pow, ceil
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])
 
 
+def _load_from_socket(port, serializer):
+    sock = socket.socket()
+    try:
+        sock.connect(("localhost", port))
+        rf = sock.makefile("rb", 65536)
+        for item in serializer.load_stream(rf):
+            yield item
+    finally:
+        sock.close()
+
+
 class Partitioner(object):
     def __init__(self, numPartitions, partitionFunc):
         self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ class RDD(object):
         Return a list that contains all of the elements in this RDD.
         """
         with SCCallSiteSync(self.context) as css:
-            bytesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(bytesInJava))
-
-    def _collect_iterator_through_file(self, iterator):
-        # Transferring lots of data through Py4J can be slow because
-        # socket.readline() is inefficient.  Instead, we'll dump the data to a
-        # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
-        tempFile.close()
-        self.ctx._writeToFile(iterator, tempFile.name)
-        # Read the data into Python and deserialize it:
-        with open(tempFile.name, 'rb') as tempFile:
-            for item in self._jrdd_deserializer.load_stream(tempFile):
-                yield item
-        os.unlink(tempFile.name)
+            port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+        return list(_load_from_socket(port, self._jrdd_deserializer))
 
     def reduce(self, f):
         """
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 5c3b7377c33b53168df1441753da3e095c4a6303..e8ce4547455a5ce491f69da6c5447d1565bfae55 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -19,13 +19,11 @@ import sys
 import itertools
 import warnings
 import random
-import os
-from tempfile import NamedTemporaryFile
 
 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            bytesInJava = self._jdf.javaToPython().collect().iterator()
-        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-        tempFile.close()
-        self._sc._writeToFile(bytesInJava, tempFile.name)
-        # Read the data into Python and deserialize it:
-        with open(tempFile.name, 'rb') as tempFile:
-            rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
-        os.unlink(tempFile.name)
+            port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
+        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
         cls = _create_cls(self.schema)
         return [cls(r) for r in rs]
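DataFrame.collect follows the same route through collectAndServe. A short usage sketch, assuming a Spark 1.3-style SQLContext named `sqlContext` (mirroring the doctest above):

    df = sqlContext.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
    df.collect()   # [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]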