From bff6a46359131a8f9bc38b93149b22baa7c711cd Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Sat, 25 Aug 2012 18:00:25 -0700
Subject: [PATCH] Add pipe(), saveAsTextFile(), sc.union() to Python API.
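
pipe() streams the elements of each partition through an external
command, feeding the subprocess from a background thread; sc.union()
wraps JavaSparkContext.union(), converting the Python list of RDDs
with a Py4J ListConverter; and saveAsTextFile() bypasses the pickle
serializer, shipping UTF-8 bytes to the JVM where a new BytesToString
helper decodes them before the text output is written.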

---
 .../scala/spark/api/python/PythonRDD.scala    |  8 ++++--
 pyspark/pyspark/context.py                    | 14 ++++++-----
 pyspark/pyspark/rdd.py                        | 25 +++++++++++++++++--
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index b9091fd436..4d3bdb3963 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -9,6 +9,7 @@ import spark._
 import api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import broadcast.Broadcast
 import scala.collection
+import java.nio.charset.Charset
 
 trait PythonRDDBase {
   def compute[T](split: Split, envVars: Map[String, String],
@@ -238,9 +239,12 @@ private object Pickle {
   val MARK : Byte = '('
   val APPENDS : Byte = 'e'
 }
-class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
-  Array[Byte]), Array[Byte]] {
 
+private class ExtractValue extends spark.api.java.function.Function[(Array[Byte],
+  Array[Byte]), Array[Byte]] {
   override def call(pair: (Array[Byte], Array[Byte])) : Array[Byte] = pair._2
+}
 
+/**
+ * Decodes the UTF-8 byte arrays produced by Python tasks into Strings so
+ * that, for example, saveAsTextFile() can write them via the Java API.
+ */
+private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
+  override def call(arr: Array[Byte]) : String = new String(arr, Charset.forName("UTF-8"))
 }
diff --git a/pyspark/pyspark/context.py b/pyspark/pyspark/context.py
index b8490019e3..04932c93f2 100644
--- a/pyspark/pyspark/context.py
+++ b/pyspark/pyspark/context.py
@@ -7,6 +7,8 @@ from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import dump_pickle, write_with_length
 from pyspark.rdd import RDD
 
+from py4j.java_collections import ListConverter
+
 
 class SparkContext(object):
 
@@ -39,12 +41,6 @@ class SparkContext(object):
         self._jsc = None
 
     def parallelize(self, c, numSlices=None):
-        """
-        >>> sc = SparkContext("local", "test")
-        >>> rdd = sc.parallelize([(1, 2), (3, 4)])
-        >>> rdd.collect()
-        [(1, 2), (3, 4)]
-        """
         numSlices = numSlices or self.defaultParallelism
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands.  As an alternative, serialized
@@ -62,6 +58,12 @@ class SparkContext(object):
         jrdd = self._jsc.textFile(name, minSplits)
         return RDD(jrdd, self)
 
+    def union(self, rdds):
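+        """
+        Build the union of a list of RDDs.
+
+        A sketch of intended usage (like the other doctests in this tree,
+        it assumes a SparkContext named sc is already defined):
+
+        >>> rdd = sc.parallelize([1, 1, 2, 3])
+        >>> sc.union([rdd, rdd]).collect()
+        [1, 1, 2, 3, 1, 1, 2, 3]
+        """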
+        first = rdds[0]._jrdd
+        rest = [x._jrdd for x in rdds[1:]]
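+        # Py4J cannot pass a Python list directly; convert it into a
+        # java.util.List before calling the JVM's union().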
+        rest = ListConverter().convert(rest, self.gateway._gateway_client)
+        return RDD(self._jsc.union(first, rest), self)
+
     def broadcast(self, value):
         jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
         return Broadcast(jbroadcast.uuid().toString(), value, jbroadcast,
diff --git a/pyspark/pyspark/rdd.py b/pyspark/pyspark/rdd.py
index 21e822ba9f..8477f6dd02 100644
--- a/pyspark/pyspark/rdd.py
+++ b/pyspark/pyspark/rdd.py
@@ -1,6 +1,9 @@
 from base64 import standard_b64encode as b64enc
 from collections import Counter
 from itertools import chain, ifilter, imap
+import shlex
+from subprocess import Popen, PIPE
+from threading import Thread
 
 from pyspark import cloudpickle
 from pyspark.serializers import dump_pickle, load_pickle
@@ -118,7 +121,20 @@ class RDD(object):
         """
         return self.map(lambda x: (f(x), x)).groupByKey(numSplits)
 
-    # TODO: pipe
+    def pipe(self, command, env={}):
+        """
+        >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
+        ['1', '2', '3']
+        """
+        def func(iterator):
+            pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
+            def pipe_objs(out):
+                for obj in iterator:
+                    out.write(str(obj).rstrip('\n') + '\n')
+                out.close()
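+            # Feed stdin from a separate thread so that reading the child's
+            # stdout below can proceed concurrently; writing all input up
+            # front could deadlock once the OS pipe buffers fill.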
+            Thread(target=pipe_objs, args=[pipe.stdin]).start()
+            return (x.rstrip('\n') for x in pipe.stdout)
+        return self.mapPartitions(func)
 
     def foreach(self, f):
         """
@@ -206,7 +222,12 @@ class RDD(object):
         """
         return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))
 
-    # TODO: saveAsTextFile
+    def saveAsTextFile(self, path):
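+        """
+        Save this RDD as a text file, using string representations of
+        elements.
+        """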
+        def func(iterator):
+            return (str(x).encode("utf-8") for x in iterator)
+        encoded = PipelinedRDD(self, func)
+        # Bypass pickling so the UTF-8 bytes reach the JVM unchanged, where
+        # BytesToString decodes them before the text file is written.
+        encoded._bypass_serializer = True
+        encoded._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
 
     # Pair functions
 
-- 
GitLab
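
A minimal usage sketch of the three new APIs (assuming a local master
and a writable /tmp output path; the app name and paths below are
illustrative only):

    from pyspark.context import SparkContext

    sc = SparkContext("local", "demo")
    a = sc.parallelize([1, 2, 3])
    b = sc.parallelize([4, 5, 6])
    print sc.union([a, b]).collect()   # [1, 2, 3, 4, 5, 6]
    print a.pipe('cat').collect()      # ['1', '2', '3']
    sc.union([a, b]).saveAsTextFile('/tmp/pyspark-union-demo')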