Commit bff6a463 authored by Josh Rosen

Add pipe(), saveAsTextFile(), sc.union() to Python API.

parent 200d248d
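A minimal end-to-end sketch of the three additions named in the commit message, assuming the SparkContext defined in pyspark/context.py is importable as pyspark.context and is constructed the same way as in the doctests below; the output directory is a hypothetical temp path, not anything from this commit:

from pyspark.context import SparkContext

sc = SparkContext("local", "usage-sketch")

# sc.union(): merge several Python RDDs into one RDD.
rdd1 = sc.parallelize([1, 2])
rdd2 = sc.parallelize([3, 4])
combined = sc.union([rdd1, rdd2])
print(combined.collect())                           # [1, 2, 3, 4]

# RDD.pipe(): stream each element through an external command, one per line.
print(combined.pipe('cat').collect())               # ['1', '2', '3', '4']

# RDD.saveAsTextFile(): write one str()-formatted element per output line.
combined.saveAsTextFile("/tmp/pyspark-union-demo")  # hypothetical output path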
@@ -9,6 +9,7 @@ import spark._
import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import broadcast.Broadcast
import scala.collection
import java.nio.charset.Charset
trait PythonRDDBase {
def compute[T](split: Split, envVars: Map[String, String],
@@ -238,9 +239,12 @@ private object Pickle {
val MARK : Byte = '('
val APPENDS : Byte = 'e'
}
class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
Array[Byte]), Array[Byte]] {
private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
Array[Byte]), Array[Byte]] {
override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
}
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
@@ -7,6 +7,8 @@ from pyspark.java_gateway import launch_gateway
from pyspark.serializers import dump_pickle, write_with_length
from pyspark.rdd import RDD
from py4j.java_collections import ListConverter
class SparkContext(object):
@@ -39,12 +41,6 @@ class SparkContext(object):
self._jsc = None
def parallelize(self, c, numSlices=None):
"""
>>> sc = SparkContext("local", "test")
>>> rdd = sc.parallelize([(1, 2), (3, 4)])
>>> rdd.collect()
[(1, 2), (3, 4)]
"""
numSlices = numSlices or self.defaultParallelism
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
@@ -62,6 +58,12 @@ class SparkContext(object):
jrdd = self._jsc.textFile(name, minSplits)
return RDD(jrdd, self)
def union(self, rdds):
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
rest = ListConverter().convert(rest, self.gateway._gateway_client)
return RDD(self._jsc.union(first, rest), self)
def broadcast(self, value):
jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
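SparkContext.union() unwraps the Java RDD backing each Python RDD, converts the tail of the list to a java.util.List with Py4J's ListConverter, and passes them to the JVM union(first, rest) overload, wrapping the result in a new Python RDD. A small usage sketch, assuming sc is a SparkContext("local", "test") as in the surrounding doctests:

parts = [sc.parallelize(range(i, i + 3)) for i in (0, 3, 6)]
merged = sc.union(parts)
# Per-RDD element order is preserved: [0, 1, 2, 3, 4, 5, 6, 7, 8]
print(merged.collect())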
......
from base64 import standard_b64encode as b64enc
from collections import Counter
from itertools import chain, ifilter, imap
import shlex
from subprocess import Popen, PIPE
from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import dump_pickle, load_pickle
@@ -118,7 +121,20 @@ class RDD(object):
"""
return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
# TODO: pipe
def pipe(self, command, env={}):
"""
>>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
['1', '2', '3']
"""
def func(iterator):
pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
def pipe_objs(out):
for obj in iterator:
out.write(str(obj).rstrip('\n') + '\n')
out.close()
Thread(target=pipe_objs, args=[pipe.stdin]).start()
return (x.rstrip('\n') for x in pipe.stdout)
return self.mapPartitions(func)
def foreach(self, f):
"""
@@ -206,7 +222,12 @@ class RDD(object):
"""
return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
# TODO: saveAsTextFile
def saveAsTextFile(self, path):
def func(iterator):
return (str(x).encode("utf-8") for x in iterator)
keyed = PipelinedRDD(self, func)
keyed._bypass_serializer = True
keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
# Pair functions
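saveAsTextFile() UTF-8-encodes each element on the Python side and sets _bypass_serializer so the bytes are not pickled again; the JVM side then maps the raw byte arrays through the new BytesToString helper before delegating to the underlying Java saveAsTextFile. A minimal sketch, assuming a local sc and a hypothetical output directory:

out = "/tmp/pyspark-save-demo"                      # hypothetical output path
# Each element becomes one UTF-8 line in the Hadoop-style part-files under `out`.
sc.parallelize(['a', 'b', 'c']).saveAsTextFile(out)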