From b3872e00d155939e40366debda635fc3fb12cc73 Mon Sep 17 00:00:00 2001
From: Vladimir Vladimirov <vladimir.vladimirov@magnetic.com>
Date: Fri, 6 Feb 2015 13:55:02 -0800
Subject: [PATCH] SPARK-5633 pyspark saveAsTextFile support for compression
 codec

See https://issues.apache.org/jira/browse/SPARK-5633 for details

Author: Vladimir Vladimirov <vladimir.vladimirov@magnetic.com>

Closes #4403 from smartkiwi/master and squashes the following commits:

94c014e [Vladimir Vladimirov] SPARK-5633 pyspark saveAsTextFile support for compression codec
---
 python/pyspark/rdd.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6e029bf7f1..bd4f16e058 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -1366,10 +1366,14 @@ class RDD(object):
             ser = BatchedSerializer(PickleSerializer(), batchSize)
         self._reserialize(ser)._jrdd.saveAsObjectFile(path)
 
-    def saveAsTextFile(self, path):
+    def saveAsTextFile(self, path, compressionCodecClass=None):
         """
         Save this RDD as a text file, using string representations of elements.
 
+        @param path: path to text file
+        @param compressionCodecClass: (None by default) string i.e.
+            "org.apache.hadoop.io.compress.GzipCodec"
+
         >>> tempFile = NamedTemporaryFile(delete=True)
         >>> tempFile.close()
         >>> sc.parallelize(range(10)).saveAsTextFile(tempFile.name)
@@ -1385,6 +1389,16 @@ class RDD(object):
         >>> sc.parallelize(['', 'foo', '', 'bar', '']).saveAsTextFile(tempFile2.name)
         >>> ''.join(sorted(input(glob(tempFile2.name + "/part-0000*"))))
         '\\n\\n\\nbar\\nfoo\\n'
+
+        Using compressionCodecClass
+
+        >>> tempFile3 = NamedTemporaryFile(delete=True)
+        >>> tempFile3.close()
+        >>> codec = "org.apache.hadoop.io.compress.GzipCodec"
+        >>> sc.parallelize(['foo', 'bar']).saveAsTextFile(tempFile3.name, codec)
+        >>> from fileinput import input, hook_compressed
+        >>> ''.join(sorted(input(glob(tempFile3.name + "/part*.gz"), openhook=hook_compressed)))
+        'bar\\nfoo\\n'
         """
         def func(split, iterator):
             for x in iterator:
@@ -1395,7 +1409,11 @@ class RDD(object):
                 yield x
         keyed = self.mapPartitionsWithIndex(func)
         keyed._bypass_serializer = True
-        keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
+        if compressionCodecClass:
+            compressionCodec = self.ctx._jvm.java.lang.Class.forName(compressionCodecClass)
+            keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path, compressionCodec)
+        else:
+            keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)
 
     # Pair functions
 
-- 
GitLab