diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0d5913ec607ec59258df5c1db18d883cd7347469..eb0b0db0cc2efb9331b12c0e6e09778d115a7883 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -75,7 +75,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           // Partition index
           dataOut.writeInt(split.index)
           // sparkFilesDir
-          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          dataOut.writeUTF(SparkFiles.getRootDirectory)
           // Broadcast variables
           dataOut.writeInt(broadcastVars.length)
           for (broadcast <- broadcastVars) {
@@ -85,9 +85,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           }
           // Python includes (*.zip and *.egg files)
           dataOut.writeInt(pythonIncludes.length)
-          for (f <- pythonIncludes) {
-            PythonRDD.writeAsPickle(f, dataOut)
-          }
+          pythonIncludes.foreach(dataOut.writeUTF)
           dataOut.flush()
           // Serialized user code
           for (elem <- command) {
@@ -96,7 +94,7 @@ private[spark] class PythonRDD[T: ClassManifest](
           printOut.flush()
           // Data values
           for (elem <- parent.iterator(split, context)) {
-            PythonRDD.writeAsPickle(elem, dataOut)
+            PythonRDD.writeToStream(elem, dataOut)
           }
           dataOut.flush()
           printOut.flush()
@@ -205,60 +203,7 @@ private object SpecialLengths {
 
 private[spark] object PythonRDD {
 
-  /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
-  def stripPickle(arr: Array[Byte]) : Array[Byte] = {
-    arr.slice(2, arr.length - 1)
-  }
-
-  /**
-   * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
-   * The data format is a 32-bit integer representing the pickled object's length (in bytes),
-   * followed by the pickled data.
-   *
-   * Pickle module:
-   *
-   *    http://docs.python.org/2/library/pickle.html
-   *
-   * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
-   *
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickle.py
-   *    http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
-   *
-   * @param elem the object to write
-   * @param dOut a data output stream
-   */
-  def writeAsPickle(elem: Any, dOut: DataOutputStream) {
-    if (elem.isInstanceOf[Array[Byte]]) {
-      val arr = elem.asInstanceOf[Array[Byte]]
-      dOut.writeInt(arr.length)
-      dOut.write(arr)
-    } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
-      val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
-      val length = t._1.length + t._2.length - 3 - 3 + 4  // stripPickle() removes 3 bytes
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(PythonRDD.stripPickle(t._1))
-      dOut.write(PythonRDD.stripPickle(t._2))
-      dOut.writeByte(Pickle.TUPLE2)
-      dOut.writeByte(Pickle.STOP)
-    } else if (elem.isInstanceOf[String]) {
-      // For uniformity, strings are wrapped into Pickles.
-      val s = elem.asInstanceOf[String].getBytes("UTF-8")
-      val length = 2 + 1 + 4 + s.length + 1
-      dOut.writeInt(length)
-      dOut.writeByte(Pickle.PROTO)
-      dOut.writeByte(Pickle.TWO)
-      dOut.write(Pickle.BINUNICODE)
-      dOut.writeInt(Integer.reverseBytes(s.length))
-      dOut.write(s)
-      dOut.writeByte(Pickle.STOP)
-    } else {
-      throw new SparkException("Unexpected RDD type")
-    }
-  }
-
-  def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -276,15 +221,46 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+  def writeStringAsPickle(elem: String, dOut: DataOutputStream) {
+    val s = elem.getBytes("UTF-8")
+    val length = 2 + 1 + 4 + s.length + 1
+    dOut.writeInt(length)
+    dOut.writeByte(Pickle.PROTO)
+    dOut.writeByte(Pickle.TWO)
+    dOut.write(Pickle.BINUNICODE)
+    dOut.writeInt(Integer.reverseBytes(s.length))
+    dOut.write(s)
+    dOut.writeByte(Pickle.STOP)
+  }
+
+  def writeToStream(elem: Any, dataOut: DataOutputStream) {
+    elem match {
+      case bytes: Array[Byte] =>
+        dataOut.writeInt(bytes.length)
+        dataOut.write(bytes)
+      case pair: (Array[Byte], Array[Byte]) =>
+        dataOut.writeInt(pair._1.length)
+        dataOut.write(pair._1)
+        dataOut.writeInt(pair._2.length)
+        dataOut.write(pair._2)
+      case str: String =>
+        // Until we've implemented full custom serializer support, we need to return
+        // strings as Pickles to properly support union() and cartesian():
+        writeStringAsPickle(str, dataOut)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+  }
+
+  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
     import scala.collection.JavaConverters._
-    writeIteratorToPickleFile(items.asScala, filename)
+    writeToFile(items.asScala, filename)
   }
 
-  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+  def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
-      writeAsPickle(item, file)
+      writeToStream(item, file)
     }
     file.close()
   }
@@ -300,10 +276,6 @@ private object Pickle {
   val TWO: Byte = 0x02.toByte
   val BINUNICODE: Byte = 'X'
   val STOP: Byte = '.'
-  val TUPLE2: Byte = 0x86.toByte
-  val EMPTY_LIST: Byte = ']'
-  val MARK: Byte = '('
-  val APPENDS: Byte = 'e'
 }
 
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc888c6759aff5784d26ad7df015d2fe2f4..0fec1a6bf67d4d2afeeb06cd605c64f3055f722b 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -42,7 +42,7 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeIteratorToPickleFile = None
+    _writeToFile = None
     _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
@@ -125,8 +125,8 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeIteratorToPickleFile = \
-                    SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                SparkContext._writeToFile = \
+                    SparkContext._jvm.PythonRDD.writeToFile
                 SparkContext._takePartition = \
                     SparkContext._jvm.PythonRDD.takePartition
 
@@ -190,8 +190,8 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
-        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8beefc8f5c8c2060899e2c28e8ae10643e..d3c4d13a1e0a28d5a66abd97442f67fba1686374 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -54,6 +54,7 @@ class RDD(object):
         self.is_checkpointed = False
         self.ctx = ctx
         self._partitionFunc = None
+        self._stage_input_is_pairs = False
 
     @property
     def context(self):
@@ -344,6 +345,7 @@ class RDD(object):
                     yield pair
             else:
                 yield pair
+        java_cartesian._stage_input_is_pairs = True
         return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numPartitions=None):
@@ -391,8 +393,8 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(picklesInJava))
+        bytesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
@@ -400,7 +402,7 @@ class RDD(object):
         # file and read it back.
         tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+        self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
@@ -941,6 +943,7 @@ class PipelinedRDD(RDD):
             self.func = func
             self.preservesPartitioning = preservesPartitioning
             self._prev_jrdd = prev._jrdd
+        self._stage_input_is_pairs = prev._stage_input_is_pairs
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
@@ -959,7 +962,7 @@ class PipelinedRDD(RDD):
             def batched_func(split, iterator):
                 return batched(oldfunc(split, iterator), batchSize)
             func = batched_func
-        cmds = [func, self._bypass_serializer]
+        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index fbc280fd3705a4c1b8cb49503d077e659806f530..fd02e1ee8fbd7a8a83e2249f08da965f97703362 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -93,6 +93,14 @@ def write_with_length(obj, stream):
     stream.write(obj)
 
 
+def read_mutf8(stream):
+    """
+    Read a string written with Java's DataOutputStream.writeUTF() method.
+    """
+    length = struct.unpack('>H', stream.read(2))[0]
+    return stream.read(length).decode('utf8')
+
+
 def read_with_length(stream):
     length = read_int(stream)
     obj = stream.read(length)
@@ -112,3 +120,13 @@ def read_from_pickle_file(stream):
                 yield obj
     except EOFError:
         return
+
+
+def read_pairs_from_pickle_file(stream):
+    try:
+        while True:
+            a = load_pickle(read_with_length(stream))
+            b = load_pickle(read_with_length(stream))
+            yield (a, b)
+    except EOFError:
+        return
\ No newline at end of file
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7696df9d1c4c51dd2c2bd211cafd3bad89dd5250..4e64557fc49e414657f14cdfe645807f7e5bffbd 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, read_with_length, write_int, \
-    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \
-    SpecialLengths
+    read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \
+    SpecialLengths, read_mutf8, read_pairs_from_pickle_file
 
 
 def load_obj(infile):
@@ -53,7 +53,7 @@ def main(infile, outfile):
         return
 
     # fetch name of workdir
-    spark_files_dir = load_pickle(read_with_length(infile))
+    spark_files_dir = read_mutf8(infile)
     SparkFiles._root_directory = spark_files_dir
     SparkFiles._is_running_on_worker = True
 
@@ -68,17 +68,21 @@ def main(infile, outfile):
     sys.path.append(spark_files_dir) # *.py files that were added will be copied here
     num_python_includes =  read_int(infile)
     for _ in range(num_python_includes):
-        sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+        sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile)))
 
     # now load function
     func = load_obj(infile)
     bypassSerializer = load_obj(infile)
+    stageInputIsPairs = load_obj(infile)
     if bypassSerializer:
         dumps = lambda x: x
     else:
         dumps = dump_pickle
     init_time = time.time()
-    iterator = read_from_pickle_file(infile)
+    if stageInputIsPairs:
+        iterator = read_pairs_from_pickle_file(infile)
+    else:
+        iterator = read_from_pickle_file(infile)
     try:
         for obj in func(split_index, iterator):
             write_with_length(dumps(obj), outfile)