diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2a1326947f4f5b6fe623a37afe983b16d3d5119f..c4f2f08cb4445ffcae6cccc14ab6f05636a035f2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -61,7 +61,7 @@ import itertools
 if sys.version < '3':
     import cPickle as pickle
     protocol = 2
-    from itertools import izip as zip
+    from itertools import izip as zip, imap as map
 else:
     import pickle
     protocol = 3
@@ -96,7 +96,12 @@ class Serializer(object):
         raise NotImplementedError
 
     def _load_stream_without_unbatching(self, stream):
-        return self.load_stream(stream)
+        """
+        Return an iterator of deserialized batches (lists) of objects from the input stream.
+        If the serializer does not operate on batches, the default implementation returns an
+        iterator of single-element lists.
+        """
+        return map(lambda x: [x], self.load_stream(stream))
 
     # Note: our notion of "equality" is that output generated by
     # equal serializers can be deserialized using the same serializer.
@@ -278,50 +283,57 @@ class AutoBatchedSerializer(BatchedSerializer):
         return "AutoBatchedSerializer(%s)" % self.serializer
 
 
-class CartesianDeserializer(FramedSerializer):
+class CartesianDeserializer(Serializer):
 
     """
     Deserializes the JavaRDD cartesian() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD cartesian;
+    we additionally need to do the cartesian within each pair of batches.
     """
 
     def __init__(self, key_ser, val_ser):
-        FramedSerializer.__init__(self)
         self.key_ser = key_ser
         self.val_ser = val_ser
 
-    def prepare_keys_values(self, stream):
-        key_stream = self.key_ser._load_stream_without_unbatching(stream)
-        val_stream = self.val_ser._load_stream_without_unbatching(stream)
-        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
-        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
-        for (keys, vals) in zip(key_stream, val_stream):
-            keys = keys if key_is_batched else [keys]
-            vals = vals if val_is_batched else [vals]
-            yield (keys, vals)
+    def _load_stream_without_unbatching(self, stream):
+        key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
+        val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
+        for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # for correctness with repeated cartesian/zip this must be returned as one batch
+            yield product(key_batch, val_batch)
 
     def load_stream(self, stream):
-        for (keys, vals) in self.prepare_keys_values(stream):
-            for pair in product(keys, vals):
-                yield pair
+        return chain.from_iterable(self._load_stream_without_unbatching(stream))
 
     def __repr__(self):
         return "CartesianDeserializer(%s, %s)" % \
             (str(self.key_ser), str(self.val_ser))
 
 
-class PairDeserializer(CartesianDeserializer):
+class PairDeserializer(Serializer):
 
     """
     Deserializes the JavaRDD zip() of two PythonRDDs.
+    Due to pyspark batching we cannot simply use the result of the Java RDD zip;
+    we additionally need to do the zip within each pair of batches.
""" + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def _load_stream_without_unbatching(self, stream): + key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) + val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) + for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): + if len(key_batch) != len(val_batch): + raise ValueError("Can not deserialize PairRDD with different number of items" + " in batches: (%d, %d)" % (len(key_batch), len(val_batch))) + # for correctness with repeated cartesian/zip this must be returned as one batch + yield zip(key_batch, val_batch) + def load_stream(self, stream): - for (keys, vals) in self.prepare_keys_values(stream): - if len(keys) != len(vals): - raise ValueError("Can not deserialize RDD with different number of items" - " in pair: (%d, %d)" % (len(keys), len(vals))) - for pair in zip(keys, vals): - yield pair + return chain.from_iterable(self._load_stream_without_unbatching(stream)) def __repr__(self): return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index ab4bef8329cd00c2e9ba113eac580edc42430f2e..89fce8ab25bafa5605efd0805d2e5ddd50de6454 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -548,6 +548,24 @@ class RDDTests(ReusedPySparkTestCase): self.assertEqual(u"Hello World!", x.strip()) self.assertEqual(u"Hello World!", y.strip()) + def test_cartesian_chaining(self): + # Tests for SPARK-16589 + rdd = self.sc.parallelize(range(10), 2) + self.assertSetEqual( + set(rdd.cartesian(rdd).cartesian(rdd).collect()), + set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.cartesian(rdd)).collect()), + set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)]) + ) + + self.assertSetEqual( + set(rdd.cartesian(rdd.zip(rdd)).collect()), + set([(x, (y, y)) for x in range(10) for y in range(10)]) + ) + def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False)