diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 81c420ce16541c3fb1baf7d13b4d7d2bd3260b1f..67752c0d150b99f365b2bffe03add3f413363c2e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -486,7 +486,7 @@ class ExternalSorter(object): goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch, limit = 100, self.memory_limit + batch, limit = 100, self._next_limit() chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -512,9 +512,6 @@ class ExternalSorter(object): f.close() chunks.append(load(open(path, 'rb'))) current_chunk = [] - gc.collect() - batch //= 2 - limit = self._next_limit() MemoryBytesSpilled += max(used_memory - get_used_memory(), 0) << 20 DiskBytesSpilled += os.path.getsize(path) os.unlink(path) # data will be deleted after close diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 11b402e6df6c15020f779a904d04aba6c916c75d..78265423682b062835f93436376f4fdf63d5164d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -179,9 +179,12 @@ class SorterTests(unittest.TestCase): list(sorter.sorted(l, key=lambda x: -x, reverse=True))) def test_external_sort(self): + class CustomizedSorter(ExternalSorter): + def _next_limit(self): + return self.memory_limit l = list(range(1024)) random.shuffle(l) - sorter = ExternalSorter(1) + sorter = CustomizedSorter(1) self.assertEqual(sorted(l), list(sorter.sorted(l))) self.assertGreater(shuffle.DiskBytesSpilled, 0) last = shuffle.DiskBytesSpilled