Commit 7ec3595d authored by Josh Rosen

Fix bug (introduced by batching) in PySpark take()

parent fbadb1cd
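
Background on the fix (not part of the original commit message): before this change, take() called take(num) on the underlying Java RDD. With batching enabled via SparkContext's batchSize parameter, each element of the Java-side RDD is a pickled batch of Python objects, so a Java-side take(num) returns num batches rather than num elements. The new take() instead pulls one partition at a time through the pickle-file transfer path and stops once enough deserialized elements have been collected. A minimal standalone sketch of that idea, using plain Python lists in place of partitions (take_from_partitions is a hypothetical helper, not PySpark code):

    def take_from_partitions(partitions, num):
        # Collect whole partitions until at least `num` elements have been
        # gathered, then slice down to exactly `num`.
        items = []
        for part in partitions:
            if len(items) >= num:
                break
            items.extend(part)
        return items[:num]

    print(take_from_partitions([[2, 3], [4, 5], [6]], 2))   # [2, 3]
    print(take_from_partitions([[2, 3], [4, 5], [6]], 10))  # [2, 3, 4, 5, 6]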
@@ -194,7 +194,7 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeArrayToPickleFile[T](items: Array[T], filename: String) {
+  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
       writeAsPickle(item, file)
@@ -19,8 +19,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
+    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
@@ -94,7 +94,7 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
@@ -30,7 +30,7 @@ def launch_gateway():
                 sys.stderr.write(line)
     EchoOutputThread(proc.stdout).start()
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=port))
+    gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
     # Import the classes used by PySpark
     java_import(gateway.jvm, "spark.api.java.*")
     java_import(gateway.jvm, "spark.api.python.*")
@@ -152,8 +152,8 @@ class RDD(object):
         into a list.
 
         >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
-        >>> rdd.glom().first()
-        [1, 2]
+        >>> sorted(rdd.glom().collect())
+        [[1, 2], [3, 4]]
         """
        def func(iterator): yield list(iterator)
        return self.mapPartitions(func)
@@ -211,10 +211,10 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.rdd().collect()
-        return list(self._collect_array_through_file(picklesInJava))
+        picklesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(picklesInJava))
 
-    def _collect_array_through_file(self, array):
+    def _collect_iterator_through_file(self, iterator):
        # Transferring lots of data through Py4J can be slow because
        # socket.readline() is inefficient.  Instead, we'll dump the data to a
        # file and read it back.
@@ -224,7 +224,7 @@ class RDD(object):
             try: os.unlink(tempFile.name)
             except: pass
         atexit.register(clean_up_file)
-        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
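
The transfer path above writes length-prefixed pickles from the JVM into a temporary file and reads them back with read_from_pickle_file. A rough self-contained sketch of that framing, assuming a 4-byte big-endian length header before each pickled record (the real helpers in pyspark.serializers may differ in detail):

    import io
    import pickle
    import struct

    def write_with_length(data, stream):
        # Write a 4-byte big-endian length, then the raw bytes.
        stream.write(struct.pack(">i", len(data)))
        stream.write(data)

    def read_from_pickle_file(stream):
        # Yield unpickled objects until the stream is exhausted.
        while True:
            header = stream.read(4)
            if len(header) < 4:
                return
            (length,) = struct.unpack(">i", header)
            yield pickle.loads(stream.read(length))

    buf = io.BytesIO()
    for obj in [1, "two", [3]]:
        write_with_length(pickle.dumps(obj), buf)
    buf.seek(0)
    print(list(read_from_pickle_file(buf)))  # [1, 'two', [3]]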
@@ -325,11 +325,18 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
         [2, 3]
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+        [2, 3, 4, 5, 6]
         """
-        picklesInJava = self._jrdd.rdd().take(num)
-        return list(self._collect_array_through_file(picklesInJava))
+        items = []
+        splits = self._jrdd.splits()
+        while len(items) < num and splits:
+            split = splits.pop(0)
+            iterator = self._jrdd.iterator(split)
+            items.extend(self._collect_iterator_through_file(iterator))
+        return items[:num]
 
     def first(self):
         """
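
A usage sketch mirroring the new doctests (the job name and batchSize=2 are illustrative; a small batch size simply forces several elements per pickled batch, the case the old take() mishandled):

    from pyspark.context import SparkContext

    sc = SparkContext("local", "take-example", batchSize=2)
    rdd = sc.parallelize([2, 3, 4, 5, 6], 3)
    print(rdd.take(2))   # [2, 3]
    print(rdd.take(10))  # [2, 3, 4, 5, 6]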