Commit 7ec3595d authored by Josh Rosen

Fix bug (introduced by batching) in PySpark take()

parent fbadb1cd
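The bug, in short: with batching enabled, each element of the Java-side RDD is a pickled *batch* of Python objects rather than a single object, so the old code's rdd().take(num) asked the JVM for num batches and could hand back far more values than requested. A minimal sketch of the failure mode (standalone illustration with made-up data, not the real PySpark internals):

import pickle

# With batchSize > 1, the JVM-side RDD stores pickled batches of elements:
jrdd_elements = [pickle.dumps([2, 3, 4, 5, 6])]    # one batch holding five values

# The old take(2) asked Java for two *elements* of the batched RDD...
taken = jrdd_elements[:2]                           # ...which is the whole batch
print([x for b in taken for x in pickle.loads(b)])  # [2, 3, 4, 5, 6], not [2, 3]

The fix below rewrites take() to pull whole partitions into Python one at a time and truncate to num on the Python side, where element boundaries are known.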
@@ -194,7 +194,7 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeArrayToPickleFile[T](items: Array[T], filename: String) {
+  def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
       writeAsPickle(item, file)
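With writeIteratorToPickleFile accepting a java.util.Iterator instead of an Array, callers can stream any JVM collection into the pickle file without first materializing it as an array. Roughly how the Python side drives this through Py4J (a sketch assuming a live SparkContext sc and an RDD rdd, with an illustrative output path; not a standalone program):

java_list = rdd._jrdd.collect()       # Py4J proxy for a java.util.List
java_iter = java_list.iterator()      # Py4J proxy for a java.util.Iterator
# The iterator proxy is handed straight back to the JVM, so the pickled
# batches are written by Java code and never cross the Py4J socket:
sc._writeIteratorToPickleFile(java_iter, "/tmp/batches.pickle")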
@@ -19,8 +19,8 @@ class SparkContext(object):
 
     gateway = launch_gateway()
     jvm = gateway.jvm
-    readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    writeArrayToPickleFile = jvm.PythonRDD.writeArrayToPickleFile
+    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
+    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
             environment=None, batchSize=1024):
@@ -94,7 +94,7 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self.readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)
 
     def textFile(self, name, minSplits=None):
@@ -30,7 +30,7 @@ def launch_gateway():
             sys.stderr.write(line)
     EchoOutputThread(proc.stdout).start()
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=port))
+    gateway = JavaGateway(GatewayClient(port=port), auto_convert=False)
     # Import the classes used by PySpark
     java_import(gateway.jvm, "spark.api.java.*")
     java_import(gateway.jvm, "spark.api.python.*")
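On the auto_convert=False flag: in Py4J, auto_convert=True makes the gateway silently copy Python lists, sets, and dicts into Java collections whenever they are passed to a Java method. Pinning it to False here presumably keeps the boundary explicit, so only genuine Java object proxies (such as the java.util.Iterator and the list of splits used below) travel between Python and the JVM. A hypothetical standalone illustration (the port number is made up):

from py4j.java_gateway import JavaGateway, GatewayClient

# With auto_convert=False (Py4J's historical default, made explicit here),
# a plain Python list passed to a Java method is rejected rather than
# silently copied into a java.util.ArrayList behind the caller's back.
gateway = JavaGateway(GatewayClient(port=25333), auto_convert=False)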
@@ -152,8 +152,8 @@ class RDD(object):
         into a list.
 
         >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
-        >>> rdd.glom().first()
-        [1, 2]
+        >>> sorted(rdd.glom().collect())
+        [[1, 2], [3, 4]]
         """
         def func(iterator): yield list(iterator)
         return self.mapPartitions(func)
@@ -211,10 +211,10 @@ class RDD(object):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.rdd().collect()
-        return list(self._collect_array_through_file(picklesInJava))
+        picklesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(picklesInJava))
 
-    def _collect_array_through_file(self, array):
+    def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
         # socket.readline() is inefficient. Instead, we'll dump the data to a
         # file and read it back.
@@ -224,7 +224,7 @@ class RDD(object):
             try: os.unlink(tempFile.name)
             except: pass
         atexit.register(clean_up_file)
-        self.ctx.writeArrayToPickleFile(array, tempFile.name)
+        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
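The comment above states the rationale: streaming records over the Py4J socket is slow because socket.readline() is invoked per item, so the JVM instead dumps everything to a local temporary file that Python reads back in bulk. A generic, self-contained sketch of that spill-through-a-file shape (the helper names here are invented for illustration, not the actual PySpark method):

import os
import tempfile

def collect_through_file(jvm_write, read_items):
    """jvm_write(path) asks the JVM to write all items to path;
    read_items(f) deserializes them back out of the open file."""
    fd, path = tempfile.mkstemp()
    os.close(fd)
    try:
        jvm_write(path)                  # bulk write on the JVM side
        with open(path, 'rb') as f:
            for item in read_items(f):   # bulk read on the Python side
                yield item
    finally:
        os.unlink(path)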
@@ -325,11 +325,18 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
         [2, 3]
-        """
-        picklesInJava = self._jrdd.rdd().take(num)
-        return list(self._collect_array_through_file(picklesInJava))
+        >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
+        [2, 3, 4, 5, 6]
+        """
+        items = []
+        splits = self._jrdd.splits()
+        while len(items) < num and splits:
+            split = splits.pop(0)
+            iterator = self._jrdd.iterator(split)
+            items.extend(self._collect_iterator_through_file(iterator))
+        return items[:num]
 
     def first(self):
         """
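After the rewrite, take() reads one partition at a time through the same temp-file path and truncates to num on the Python side, so the batch size can no longer inflate the result. Expected behavior, doctest-style (assuming a live SparkContext sc; these examples are not from the commit itself):

>>> sc.parallelize(range(100), 4).take(3)   # stops after the first partition
[0, 1, 2]
>>> sc.parallelize([1, 2], 1).take(5)       # never returns more than exist
[1, 2]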