diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b90596ecc2bee41cba9e0852c3dd92eb651028b1..6172d69dcff97fba5546494732f0199f2c12e021 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -4,7 +4,7 @@ from tempfile import NamedTemporaryFile
 
 from pyspark.broadcast import Broadcast
 from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length
+from pyspark.serializers import dump_pickle, write_with_length, batched
 from pyspark.rdd import RDD
 
 from py4j.java_collections import ListConverter
@@ -91,6 +91,8 @@ class SparkContext(object):
         # objects are written to a file and loaded through textFile().
         tempFile = NamedTemporaryFile(delete=False)
         atexit.register(lambda: os.unlink(tempFile.name))
+        if self.batchSize != 1:
+            c = batched(c, self.batchSize)
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 20f84b2dd08e3394db6bfffbc10dd591630993bb..203f7377d2c049182afdc5d114392ab355e9a22c 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -2,7 +2,7 @@ import atexit
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from itertools import chain, ifilter, imap
+from itertools import chain, ifilter, imap, product
 import operator
 import os
 import shlex
@@ -123,12 +123,6 @@ class RDD(object):
         >>> rdd = sc.parallelize([1, 1, 2, 3])
         >>> rdd.union(rdd).collect()
         [1, 1, 2, 3, 1, 1, 2, 3]
-
-        Union of batched and unbatched RDDs (internal test):
-
-        >>> batchedRDD = sc.parallelize([Batch([1, 2, 3, 4, 5])])
-        >>> rdd.union(batchedRDD).collect()
-        [1, 1, 2, 3, 1, 2, 3, 4, 5]
         """
         return RDD(self._jrdd.union(other._jrdd), self.ctx)
 
@@ -168,7 +162,18 @@ class RDD(object):
         >>> sorted(rdd.cartesian(rdd).collect())
         [(1, 1), (1, 2), (2, 1), (2, 2)]
         """
-        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        # Due to batching, we can't use the Java cartesian method.
+        java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
+        def unpack_batches(pair):
+            (x, y) = pair
+            if type(x) == Batch or type(y) == Batch:
+                xs = x.items if type(x) == Batch else [x]
+                ys = y.items if type(y) == Batch else [y]
+                for pair in product(xs, ys):
+                    yield pair
+            else:
+                yield pair
+        return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numSplits=None):
         """
@@ -293,8 +298,6 @@ class RDD(object):
 
         >>> sc.parallelize([2, 3, 4]).count()
         3
-        >>> sc.parallelize([Batch([2, 3, 4])]).count()
-        3
         """
         return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum()
 
@@ -667,12 +670,8 @@ class PipelinedRDD(RDD):
         if not self._bypass_serializer and self.ctx.batchSize != 1:
             oldfunc = self.func
             batchSize = self.ctx.batchSize
-            if batchSize == -1:  # unlimited batch size
-                def batched_func(iterator):
-                    yield Batch(list(oldfunc(iterator)))
-            else:
-                def batched_func(iterator):
-                    return batched(oldfunc(iterator), batchSize)
+            def batched_func(iterator):
+                return batched(oldfunc(iterator), batchSize)
             func = batched_func
         cmds = [func, self._bypass_serializer]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
diff --git a/pyspark/pyspark/serializers.py b/pyspark/pyspark/serializers.py
index 8b08f7ef0f982173774d1d59abdfb8da28ebb9b5..9a5151ea00341459f83b85938b82a8555a91945b 100644
--- a/pyspark/pyspark/serializers.py
+++ b/pyspark/pyspark/serializers.py
@@ -16,17 +16,20 @@ class Batch(object):
 
 
 def batched(iterator, batchSize):
-    items = []
-    count = 0
-    for item in iterator:
-        items.append(item)
-        count += 1
-        if count == batchSize:
+    if batchSize == -1: # unlimited batch size
+        yield Batch(list(iterator))
+    else:
+        items = []
+        count = 0
+        for item in iterator:
+            items.append(item)
+            count += 1
+            if count == batchSize:
+                yield Batch(items)
+                items = []
+                count = 0
+        if items:
             yield Batch(items)
-            items = []
-            count = 0
-    if items:
-        yield Batch(items)
 
 
 def dump_pickle(obj):