diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 57754776faaa200dafb2a34531fad65d8bae2e79..bd2ff00c0f1bee8bce284cfd26f39436d0912eec 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -469,8 +469,7 @@ class RDD(object):
     def _reserialize(self, serializer=None):
         serializer = serializer or self.ctx.serializer
         if self._jrdd_deserializer != serializer:
-            if not isinstance(self, PipelinedRDD):
-                self = self.map(lambda x: x, preservesPartitioning=True)
+            self = self.map(lambda x: x, preservesPartitioning=True)
             self._jrdd_deserializer = serializer
         return self
 
@@ -1798,23 +1797,21 @@ class RDD(object):
         def get_batch_size(ser):
             if isinstance(ser, BatchedSerializer):
                 return ser.batchSize
-            return 1
+            return 1  # not batched
 
         def batch_as(rdd, batchSize):
-            ser = rdd._jrdd_deserializer
-            if isinstance(ser, BatchedSerializer):
-                ser = ser.serializer
-            return rdd._reserialize(BatchedSerializer(ser, batchSize))
+            return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
 
         my_batch = get_batch_size(self._jrdd_deserializer)
         other_batch = get_batch_size(other._jrdd_deserializer)
-        # use the smallest batchSize for both of them
-        batchSize = min(my_batch, other_batch)
-        if batchSize <= 0:
-            # auto batched or unlimited
-            batchSize = 100
-        other = batch_as(other, batchSize)
-        self = batch_as(self, batchSize)
+        if my_batch != other_batch:
+            # use the smallest batchSize for both of them
+            batchSize = min(my_batch, other_batch)
+            if batchSize <= 0:
+                # auto batched or unlimited
+                batchSize = 100
+            other = batch_as(other, batchSize)
+            self = batch_as(self, batchSize)
 
         if self.getNumPartitions() != other.getNumPartitions():
             raise ValueError("Can only zip with RDD which has the same number of partitions")
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 33aa55f7f142951e54d8a996012fcdfff22f0865..bd08c9a6d20d6c073d50181cde50440406410f00 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -463,6 +463,9 @@ class CompressedSerializer(FramedSerializer):
     def loads(self, obj):
         return self.serializer.loads(zlib.decompress(obj))
 
+    def __eq__(self, other):
+        return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
+
 
 class UTF8Deserializer(Serializer):
@@ -489,6 +492,9 @@ class UTF8Deserializer(Serializer):
         except EOFError:
             return
 
+    def __eq__(self, other):
+        return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
+
 
 def read_long(stream):
     length = stream.read(8)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 32645778c2b8f47f7d973ccab90280428d96bfce..bca52a7ce6d589bc0abd5888766001a12d7bf94b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -533,6 +533,15 @@ class RDDTests(ReusedPySparkTestCase):
         a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
         b = b._reserialize(MarshalSerializer())
         self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+        # regression test for SPARK-4841
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        t = self.sc.textFile(path)
+        cnt = t.count()
+        self.assertEqual(cnt, t.zip(t).count())
+        rdd = t.map(str)
+        self.assertEqual(cnt, t.zip(rdd).count())
+        # regression test for bug in _reserialize()
+        self.assertEqual(cnt, t.zip(rdd).count())
 
     def test_zip_with_different_number_of_items(self):
         a = self.sc.parallelize(range(5), 2)