Commit abd58175 authored by Xiangrui Meng

[SPARK-4398][PySpark] specialize sc.parallelize(xrange)

`sc.parallelize(range(1 << 20), 1).count()` may take 15 seconds to finish, and the RDD object stores the entire list, which makes the task size very large. This PR adds a specialized version for xrange.

JoshRosen davies

Author: Xiangrui Meng <meng@databricks.com>

Closes #3264 from mengxr/SPARK-4398 and squashes the following commits:

8953c41 [Xiangrui Meng] follow davies' suggestion
cbd58e3 [Xiangrui Meng] specialize sc.parallelize(xrange)
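
The task-size concern in the description can be illustrated outside Spark. The snippet below is only a rough sketch under stated assumptions: it uses Python 3, where `range` plays the role of Python 2's `xrange`, and it measures plain `pickle` output rather than Spark's actual serialization path. `pickled_size` is a hypothetical helper introduced for illustration.

```python
import pickle


def pickled_size(obj):
    """Return the number of bytes pickle needs to serialize obj."""
    return len(pickle.dumps(obj))


n = 1 << 20

# A materialized list must write out every element.
list_bytes = pickled_size(list(range(n)))

# A lazy range only records start, stop, and step.
range_bytes = pickled_size(range(n))

print("list: %d bytes, range: %d bytes" % (list_bytes, range_bytes))
```

The list payload grows with the number of elements, while the range payload stays constant, which is the motivation for giving range-like inputs their own code path.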
parent 77e845ca
@@ -289,12 +289,29 @@ class SparkContext(object):
     def parallelize(self, c, numSlices=None):
         """
-        Distribute a local Python collection to form an RDD.
+        Distribute a local Python collection to form an RDD. Using xrange
+        is recommended if the input represents a range for performance.
-        >>> sc.parallelize(range(5), 5).glom().collect()
-        [[0], [1], [2], [3], [4]]
+        >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
+        [[0], [2], [3], [4], [6]]
+        >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
+        [[], [0], [], [2], [4]]
         """
-        numSlices = numSlices or self.defaultParallelism
+        numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
+        if isinstance(c, xrange):
+            size = len(c)
+            if size == 0:
+                return self.parallelize([], numSlices)
+            step = c[1] - c[0] if size > 1 else 1
+            start0 = c[0]
+            def getStart(split):
+                return start0 + (split * size / numSlices) * step
+            def f(split, iterator):
+                return xrange(getStart(split), getStart(split + 1), step)
+            return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands. As an alternative, serialized
         # objects are written to a file and loaded through textFile().
......
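
The partition-boundary arithmetic in `getStart` can be tried without a SparkContext. The following is a minimal sketch, not the patch itself: it assumes Python 3, so `range` stands in for `xrange` and `//` reproduces Python 2's integer `/`, and `range_partitions` is a hypothetical helper that returns the per-partition sub-ranges as a plain list instead of building an RDD.

```python
def range_partitions(r, num_slices):
    """Split range r into num_slices sub-ranges using the getStart rule."""
    size = len(r)
    if size == 0:
        # Mirror the empty case: num_slices partitions, all empty.
        return [range(0) for _ in range(num_slices)]
    step = r[1] - r[0] if size > 1 else 1
    start0 = r[0]

    def get_start(split):
        # Floor division matches Python 2's integer `/` used in the patch.
        return start0 + (split * size // num_slices) * step

    # Partition i covers [get_start(i), get_start(i + 1)) in steps of `step`.
    return [range(get_start(i), get_start(i + 1), step)
            for i in range(num_slices)]


parts = [list(p) for p in range_partitions(range(0, 6, 2), 5)]
print(parts)  # [[], [0], [], [2], [4]]
```

With 3 elements spread over 5 slices, the floor division maps several split indices to the same start, which is why the doctest in the diff shows empty partitions.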