diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 765a4511b64bc9ce9cbd3c1fb28388606a09a9dc..b97c94dad834a966d194e2efafceb1af75132b42 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -422,6 +422,67 @@ class DataFrame(object): """ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + @since(1.3) + def repartition(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is hash partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + .. versionchanged:: 1.6 + Added optional arguments to specify the partitioning columns. Also made numPartitions + optional if partitioning columns are specified. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + >>> data = df.unionAll(df).repartition("age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 2|Alice| + | 5| Bob| + | 5| Bob| + +---+-----+ + >>> data = data.repartition(7, "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + >>> data.rdd.getNumPartitions() + 7 + >>> data = data.repartition("name", "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + else: + return DataFrame( + self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions, ) + cols + return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -589,6 +650,26 @@ class DataFrame(object): jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) + @since(1.6) + def sortWithinPartitions(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. + + >>> df.sortWithinPartitions("age", ascending=False).show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def sort(self, *cols, **kwargs): @@ -613,22 +694,7 @@ class DataFrame(object): >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ - if not cols: - raise ValueError("should sort by at least one column") - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - jcols = [_to_java_column(c) for c in cols] - ascending = kwargs.get('ascending', True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() - for asc, jc in zip(ascending, jcols)] - else: - raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) - - jdf = self._jdf.sort(self._jseq(jcols)) + jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) orderBy = sort @@ -650,6 +716,25 @@ class DataFrame(object): cols = cols[0] return self._jseq(cols, _to_java_column) + def _sort_cols(self, cols, kwargs): + """ Return a JVM Seq of Columns that describes the sort order + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + return self._jseq(jcols) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns.