From f53a820721fe0525c275e2bb4415c20909c42dc3 Mon Sep 17 00:00:00 2001
From: zero323 <zero323@users.noreply.github.com>
Date: Mon, 8 May 2017 10:58:27 +0800
Subject: [PATCH] [SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy

## What changes were proposed in this pull request?

Adds Python wrappers for `DataFrameWriter.bucketBy` and `DataFrameWriter.sortBy`
([SPARK-16931](https://issues.apache.org/jira/browse/SPARK-16931)).

## How was this patch tested?

Unit tests covering the new feature.

__Note__: Based on the work of GregBowyer (f49b9a23468f7af32cb53d2b654272757c151725)

CC HyukjinKwon

Author: zero323 <zero323@users.noreply.github.com>
Author: Greg Bowyer <gbowyer@fastmail.co.uk>

Closes #17077 from zero323/SPARK-16931.
---
 python/pyspark/sql/readwriter.py | 57 ++++++++++++++++++++++++++++++++
 python/pyspark/sql/tests.py      | 54 ++++++++++++++++++++++++++++++
 2 files changed, 111 insertions(+)

diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 960fb882cf..90ce8f81eb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -563,6 +563,63 @@ class DataFrameWriter(OptionUtils):
         self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
         return self
 
+    @since(2.3)
+    def bucketBy(self, numBuckets, col, *cols):
+        """Buckets the output by the given columns. If specified,
+        the output is laid out on the file system similar to Hive's bucketing scheme.
+
+        :param numBuckets: the number of buckets to save
+        :param col: a name of a column, or a list of names.
+        :param cols: additional names (optional). If `col` is a list it should be empty.
+
+        .. note:: Applicable for file-based data sources in combination with
+                  :py:meth:`DataFrameWriter.saveAsTable`.
+
+        >>> (df.write.format('parquet')
+        ...     .bucketBy(100, 'year', 'month')
+        ...     .mode("overwrite")
+        ...     .saveAsTable('bucketed_table'))
+        """
+        if not isinstance(numBuckets, int):
+            raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
+
+        if isinstance(col, (list, tuple)):
+            if cols:
+                raise ValueError("col is a {0} but cols are not empty".format(type(col)))
+
+            col, cols = col[0], col[1:]
+
+        if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
+            raise TypeError("all names should be `str`")
+
+        self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols))
+        return self
+
+    @since(2.3)
+    def sortBy(self, col, *cols):
+        """Sorts the output in each bucket by the given columns on the file system.
+
+        :param col: a name of a column, or a list of names.
+        :param cols: additional names (optional). If `col` is a list it should be empty.
+
+        >>> (df.write.format('parquet')
+        ...     .bucketBy(100, 'year', 'month')
+        ...     .sortBy('day')
+        ...     .mode("overwrite")
+        ...     .saveAsTable('sorted_bucketed_table'))
+        """
+        if isinstance(col, (list, tuple)):
+            if cols:
+                raise ValueError("col is a {0} but cols are not empty".format(type(col)))
+
+            col, cols = col[0], col[1:]
+
+        if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
+            raise TypeError("all names should be `str`")
+
+        self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols))
+        return self
+
     @since(1.4)
     def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
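The doctest snippets above assume an existing `df`. For completeness, here is a minimal, self-contained sketch of how the new wrappers would be used end to end; the session setup, sample data, and the `bucketed_example` table name are illustrative and not part of the patch:

```python
from pyspark.sql import SparkSession

# Any existing SparkSession works; a local one is created here only for illustration.
spark = SparkSession.builder.master("local[2]").appName("bucketBy-sketch").getOrCreate()

df = spark.createDataFrame(
    [(2017, 5, 8, 1.0), (2017, 5, 9, 2.0), (2016, 12, 1, 3.0)],
    ["year", "month", "day", "value"])

# bucketBy/sortBy only take effect together with saveAsTable on file-based
# sources (see the .. note:: in the docstring); a plain .save(path) ignores them.
(df.write.format("parquet")
    .bucketBy(4, "year", "month")
    .sortBy("day")
    .mode("overwrite")
    .saveAsTable("bucketed_example"))
```
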
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7983bc536f..e3fe01eae2 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -211,6 +211,12 @@ class SQLTests(ReusedPySparkTestCase):
         sqlContext2 = SQLContext(self.sc)
         self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
 
+    def tearDown(self):
+        super(SQLTests, self).tearDown()
+
+        # tear down test_bucketed_write state
+        self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")
+
     def test_row_should_be_read_only(self):
         row = Row(a=1, b=2)
         self.assertEqual(1, row.a)
@@ -2196,6 +2202,54 @@ class SQLTests(ReusedPySparkTestCase):
         df = self.spark.createDataFrame(data, schema=schema)
         df.collect()
 
+    def test_bucketed_write(self):
+        data = [
+            (1, "foo", 3.0), (2, "foo", 5.0),
+            (3, "bar", -1.0), (4, "bar", 6.0),
+        ]
+        df = self.spark.createDataFrame(data, ["x", "y", "z"])
+
+        def count_bucketed_cols(names, table="pyspark_bucket"):
+            """Given a sequence of column names and a table name,
+            query the catalog and return the number of columns which are
+            used for bucketing
+            """
+            cols = self.spark.catalog.listColumns(table)
+            num = len([c for c in cols if c.name in names and c.isBucket])
+            return num
+
+        # Test write with one bucketing column
+        df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x"]), 1)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with two bucketing columns
+        df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with bucket and sort
+        df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x"]), 1)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with a list of columns
+        df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
+        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with bucket and sort with a list of columns
+        (df.write.bucketBy(2, "x")
+            .sortBy(["y", "z"])
+            .mode("overwrite").saveAsTable("pyspark_bucket"))
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
+        # Test write with bucket and sort with multiple columns
+        (df.write.bucketBy(2, "x")
+            .sortBy("y", "z")
+            .mode("overwrite").saveAsTable("pyspark_bucket"))
+        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
-- 
GitLab
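As a closing aside, the `count_bucketed_cols` helper in the test relies only on the public catalog API, so the same verification can be done standalone. A small sketch under the same assumptions as the earlier example (the `bucketed_example` table name is illustrative, not defined by the patch):

```python
def count_bucketed_cols(spark, names, table):
    """Return how many of `names` are registered as bucketing columns of `table`.

    Uses spark.catalog.listColumns, whose entries expose an `isBucket` flag.
    """
    cols = spark.catalog.listColumns(table)
    return len([c for c in cols if c.name in names and c.isBucket])

# Example, assuming the 'bucketed_example' table from the earlier sketch:
# count_bucketed_cols(spark, ["year", "month"], "bucketed_example")  # -> 2
```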