From d7b69946cb21cd2781c9ad3e691e54b28efbbf3d Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Fri, 15 May 2015 20:09:15 -0700
Subject: [PATCH] [SPARK-7543] [SQL] [PySpark] split dataframe.py into multiple files

dataframe.py is split into column.py, group.py and dataframe.py:

```
 360 column.py
1223 dataframe.py
 183 group.py
```

Public imports through `pyspark.sql` are unchanged (a short usage sketch is appended after the patch).

Author: Davies Liu <davies@databricks.com>

Closes #6201 from davies/split_df and squashes the following commits:

fc8f5ab [Davies Liu] split dataframe.py into multiple files
---
 python/pyspark/sql/__init__.py  |   5 +-
 python/pyspark/sql/column.py    | 360 +++++++++++++++++++++++++
 python/pyspark/sql/dataframe.py | 449 +-------------------------------
 python/pyspark/sql/functions.py |   2 +-
 python/pyspark/sql/group.py     | 183 +++++++++++++
 python/run-tests                |   2 +
 6 files changed, 552 insertions(+), 449 deletions(-)
 create mode 100644 python/pyspark/sql/column.py
 create mode 100644 python/pyspark/sql/group.py

diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 7192c89b3d..19805e291e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -55,8 +55,9 @@ del modname, sys
 
 from pyspark.sql.types import Row
 from pyspark.sql.context import SQLContext, HiveContext
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
-from pyspark.sql.dataframe import DataFrameStatFunctions
+from pyspark.sql.column import Column
+from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
+from pyspark.sql.group import GroupedData
 
 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
new file mode 100644
index 0000000000..fc7ad674da
--- /dev/null
+++ b/python/pyspark/sql/column.py
@@ -0,0 +1,360 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+if sys.version >= '3':
+    basestring = str
+    long = int
+
+from pyspark.context import SparkContext
+from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.types import *
+
+__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions",
+           "DataFrameStatFunctions"]
+
+
+def _create_column_from_literal(literal):
+    sc = SparkContext._active_spark_context
+    return sc._jvm.functions.lit(literal)
+
+
+def _create_column_from_name(name):
+    sc = SparkContext._active_spark_context
+    return sc._jvm.functions.col(name)
+
+
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
+def _to_seq(sc, cols, converter=None):
+    """
+    Convert a list of Column (or names) into a JVM Seq of Column.
+ + An optional `converter` could be used to convert items in `cols` + into JVM Column objects. + """ + if converter: + cols = [converter(c) for c in cols] + return sc._jvm.PythonUtils.toSeq(cols) + + +def _unary_op(name, doc="unary operator"): + """ Create a method for given unary operator """ + def _(self): + jc = getattr(self._jc, name)() + return Column(jc) + _.__doc__ = doc + return _ + + +def _func_op(name, doc=''): + def _(self): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +def _bin_op(name, doc="binary operator"): + """ Create a method for given binary operator + """ + def _(self, other): + jc = other._jc if isinstance(other, Column) else other + njc = getattr(self._jc, name)(jc) + return Column(njc) + _.__doc__ = doc + return _ + + +def _reverse_op(name, doc="binary operator"): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + jother = _create_column_from_literal(other) + jc = getattr(jother, name)(self._jc) + return Column(jc) + _.__doc__ = doc + return _ + + +class Column(object): + + """ + A column in a DataFrame. + + :class:`Column` instances can be created by:: + + # 1. Select a column out of a DataFrame + + df.colName + df["colName"] + + # 2. Create from an expression + df.colName + 1 + 1 / df.colName + """ + + def __init__(self, jc): + self._jc = jc + + # arithmetic operators + __neg__ = _func_op("negate") + __add__ = _bin_op("plus") + __sub__ = _bin_op("minus") + __mul__ = _bin_op("multiply") + __div__ = _bin_op("divide") + __truediv__ = _bin_op("divide") + __mod__ = _bin_op("mod") + __radd__ = _bin_op("plus") + __rsub__ = _reverse_op("minus") + __rmul__ = _bin_op("multiply") + __rdiv__ = _reverse_op("divide") + __rtruediv__ = _reverse_op("divide") + __rmod__ = _reverse_op("mod") + + # logistic operators + __eq__ = _bin_op("equalTo") + __ne__ = _bin_op("notEqual") + __lt__ = _bin_op("lt") + __le__ = _bin_op("leq") + __ge__ = _bin_op("geq") + __gt__ = _bin_op("gt") + + # `and`, `or`, `not` cannot be overloaded in Python, + # so use bitwise operators as boolean operators + __and__ = _bin_op('and') + __or__ = _bin_op('or') + __invert__ = _func_op('not') + __rand__ = _bin_op("and") + __ror__ = _bin_op("or") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("apply") + + # bitwise operators + bitwiseOR = _bin_op("bitwiseOR") + bitwiseAND = _bin_op("bitwiseAND") + bitwiseXOR = _bin_op("bitwiseXOR") + + def getItem(self, key): + """An expression that gets an item at position `ordinal` out of a list, + or gets an item by key out of a dict. + + >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) + >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + >>> df.select(df.l[0], df.d["key"]).show() + +----+------+ + |l[0]|d[key]| + +----+------+ + | 1| value| + +----+------+ + """ + return self[key] + + def getField(self, name): + """An expression that gets a field by name in a StructField. 
+ + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() + >>> df.select(df.r.getField("b")).show() + +----+ + |r[b]| + +----+ + | b| + +----+ + >>> df.select(df.r.a).show() + +----+ + |r[a]| + +----+ + | 1| + +----+ + """ + return self[name] + + def __getattr__(self, item): + if item.startswith("__"): + raise AttributeError(item) + return self.getField(item) + + # string methods + rlike = _bin_op("rlike") + like = _bin_op("like") + startswith = _bin_op("startsWith") + endswith = _bin_op("endsWith") + + @ignore_unicode_prefix + def substr(self, startPos, length): + """ + Return a :class:`Column` which is a substring of the column + + :param startPos: start position (int or Column) + :param length: length of the substring (int or Column) + + >>> df.select(df.name.substr(1, 3).alias("col")).collect() + [Row(col=u'Ali'), Row(col=u'Bob')] + """ + if type(startPos) != type(length): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + jc = self._jc.substr(startPos, length) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, length._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc) + + __getslice__ = substr + + @ignore_unicode_prefix + def inSet(self, *cols): + """ A boolean expression that is evaluated to true if the value of this + expression is contained by the evaluated values of the arguments. + + >>> df[df.name.inSet("Bob", "Mike")].collect() + [Row(age=5, name=u'Bob')] + >>> df[df.age.inSet([1, 2, 3])].collect() + [Row(age=2, name=u'Alice')] + """ + if len(cols) == 1 and isinstance(cols[0], (list, set)): + cols = cols[0] + cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] + sc = SparkContext._active_spark_context + jc = getattr(self._jc, "in")(_to_seq(sc, cols)) + return Column(jc) + + # order + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") + + isNull = _unary_op("isNull", "True if the current expression is null.") + isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") + + def alias(self, *alias): + """Returns this column aliased with a new name or names (in the case of expressions that + return more than one column, such as explode). 
+ + >>> df.select(df.age.alias("age2")).collect() + [Row(age2=2), Row(age2=5)] + """ + + if len(alias) == 1: + return Column(getattr(self._jc, "as")(alias[0])) + else: + sc = SparkContext._active_spark_context + return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) + + @ignore_unicode_prefix + def cast(self, dataType): + """ Convert the column into type `dataType` + + >>> df.select(df.age.cast("string").alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + >>> df.select(df.age.cast(StringType()).alias('ages')).collect() + [Row(ages=u'2'), Row(ages=u'5')] + """ + if isinstance(dataType, basestring): + jc = self._jc.cast(dataType) + elif isinstance(dataType, DataType): + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + jdt = ssql_ctx.parseDataType(dataType.json()) + jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) + return Column(jc) + + @ignore_unicode_prefix + def between(self, lowerBound, upperBound): + """ A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + """ + return (self >= lowerBound) & (self <= upperBound) + + @ignore_unicode_prefix + def when(self, condition, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + """ + sc = SparkContext._active_spark_context + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = sc._jvm.functions.when(condition._jc, v) + return Column(jc) + + @ignore_unicode_prefix + def otherwise(self, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param value: a literal value, or a :class:`Column` expression. 
+ """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(value) + return Column(jc) + + def __repr__(self): + return 'Column<%s>' % self._jc.toString().encode('utf8') + + +def _test(): + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + import pyspark.sql.column + globs = pyspark.sql.column.__dict__.copy() + sc = SparkContext('local[4]', 'PythonTest') + globs['sc'] = sc + globs['sqlContext'] = SQLContext(sc) + globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ + .toDF(StructType([StructField('age', IntegerType()), + StructField('name', StringType())])) + + (failure_count, test_count) = doctest.testmod( + pyspark.sql.column, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2ed95ac8e2..96d927b9ba 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -25,17 +25,15 @@ if sys.version >= '3': else: from itertools import imap as map -from pyspark.context import SparkContext from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.types import * from pyspark.sql.types import _create_cls, _parse_datatype_json_string +from pyspark.sql.column import Column, _to_seq, _to_java_column - -__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -757,6 +755,7 @@ class DataFrame(object): [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ jdf = self._jdf.groupBy(self._jcols(*cols)) + from pyspark.sql.group import GroupedData return GroupedData(jdf, self.sql_ctx) def agg(self, *exprs): @@ -1141,169 +1140,6 @@ class SchemaRDD(DataFrame): """ -def dfapi(f): - def _api(self): - name = f.__name__ - jdf = getattr(self._jdf, name)() - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -def df_varargs_api(f): - def _api(self, *args): - name = f.__name__ - jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) - return DataFrame(jdf, self.sql_ctx) - _api.__name__ = f.__name__ - _api.__doc__ = f.__doc__ - return _api - - -class GroupedData(object): - """ - A set of methods for aggregations on a :class:`DataFrame`, - created by :func:`DataFrame.groupBy`. - """ - - def __init__(self, jdf, sql_ctx): - self._jdf = jdf - self.sql_ctx = sql_ctx - - @ignore_unicode_prefix - def agg(self, *exprs): - """Compute aggregates and returns the result as a :class:`DataFrame`. - - The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. - - If ``exprs`` is a single :class:`dict` mapping from string to string, then the key - is the column to perform aggregation on, and the value is the aggregate function. - - Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. - - :param exprs: a dict mapping from column name (string) to aggregate functions (string), - or a list of :class:`Column`. 
- - >>> gdf = df.groupBy(df.name) - >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] - - >>> from pyspark.sql import functions as F - >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] - """ - assert exprs, "exprs should not be empty" - if len(exprs) == 1 and isinstance(exprs[0], dict): - jdf = self._jdf.agg(exprs[0]) - else: - # Columns - assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" - jdf = self._jdf.agg(exprs[0]._jc, - _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) - return DataFrame(jdf, self.sql_ctx) - - @dfapi - def count(self): - """Counts the number of records for each group. - - >>> df.groupBy(df.age).count().collect() - [Row(age=2, count=1), Row(age=5, count=1)] - """ - - @df_varargs_api - def mean(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def avg(self, *cols): - """Computes average values for each numeric columns for each group. - - :func:`mean` is an alias for :func:`avg`. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] - >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] - """ - - @df_varargs_api - def max(self, *cols): - """Computes the max value for each numeric columns for each group. - - >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] - >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] - """ - - @df_varargs_api - def min(self, *cols): - """Computes the min value for each numeric column for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] - >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] - """ - - @df_varargs_api - def sum(self, *cols): - """Compute the sum for each numeric columns for each group. - - :param cols: list of column names (string). Non-numeric columns are ignored. - - >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] - >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] - """ - - -def _create_column_from_literal(literal): - sc = SparkContext._active_spark_context - return sc._jvm.functions.lit(literal) - - -def _create_column_from_name(name): - sc = SparkContext._active_spark_context - return sc._jvm.functions.col(name) - - -def _to_java_column(col): - if isinstance(col, Column): - jcol = col._jc - else: - jcol = _create_column_from_name(col) - return jcol - - -def _to_seq(sc, cols, converter=None): - """ - Convert a list of Column (or names) into a JVM Seq of Column. - - An optional `converter` could be used to convert items in `cols` - into JVM Column objects. - """ - if converter: - cols = [converter(c) for c in cols] - return sc._jvm.PythonUtils.toSeq(cols) - - def _to_scala_map(sc, jm): """ Convert a dict into a JVM Map. 
@@ -1311,282 +1147,6 @@ def _to_scala_map(sc, jm): return sc._jvm.PythonUtils.toScalaMap(jm) -def _unary_op(name, doc="unary operator"): - """ Create a method for given unary operator """ - def _(self): - jc = getattr(self._jc, name)() - return Column(jc) - _.__doc__ = doc - return _ - - -def _func_op(name, doc=''): - def _(self): - sc = SparkContext._active_spark_context - jc = getattr(sc._jvm.functions, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -def _bin_op(name, doc="binary operator"): - """ Create a method for given binary operator - """ - def _(self, other): - jc = other._jc if isinstance(other, Column) else other - njc = getattr(self._jc, name)(jc) - return Column(njc) - _.__doc__ = doc - return _ - - -def _reverse_op(name, doc="binary operator"): - """ Create a method for binary operator (this object is on right side) - """ - def _(self, other): - jother = _create_column_from_literal(other) - jc = getattr(jother, name)(self._jc) - return Column(jc) - _.__doc__ = doc - return _ - - -class Column(object): - - """ - A column in a DataFrame. - - :class:`Column` instances can be created by:: - - # 1. Select a column out of a DataFrame - - df.colName - df["colName"] - - # 2. Create from an expression - df.colName + 1 - 1 / df.colName - """ - - def __init__(self, jc): - self._jc = jc - - # arithmetic operators - __neg__ = _func_op("negate") - __add__ = _bin_op("plus") - __sub__ = _bin_op("minus") - __mul__ = _bin_op("multiply") - __div__ = _bin_op("divide") - __truediv__ = _bin_op("divide") - __mod__ = _bin_op("mod") - __radd__ = _bin_op("plus") - __rsub__ = _reverse_op("minus") - __rmul__ = _bin_op("multiply") - __rdiv__ = _reverse_op("divide") - __rtruediv__ = _reverse_op("divide") - __rmod__ = _reverse_op("mod") - - # logistic operators - __eq__ = _bin_op("equalTo") - __ne__ = _bin_op("notEqual") - __lt__ = _bin_op("lt") - __le__ = _bin_op("leq") - __ge__ = _bin_op("geq") - __gt__ = _bin_op("gt") - - # `and`, `or`, `not` cannot be overloaded in Python, - # so use bitwise operators as boolean operators - __and__ = _bin_op('and') - __or__ = _bin_op('or') - __invert__ = _func_op('not') - __rand__ = _bin_op("and") - __ror__ = _bin_op("or") - - # container operators - __contains__ = _bin_op("contains") - __getitem__ = _bin_op("apply") - - # bitwise operators - bitwiseOR = _bin_op("bitwiseOR") - bitwiseAND = _bin_op("bitwiseAND") - bitwiseXOR = _bin_op("bitwiseXOR") - - def getItem(self, key): - """An expression that gets an item at position `ordinal` out of a list, - or gets an item by key out of a dict. - - >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"]) - >>> df.select(df.l.getItem(0), df.d.getItem("key")).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - >>> df.select(df.l[0], df.d["key"]).show() - +----+------+ - |l[0]|d[key]| - +----+------+ - | 1| value| - +----+------+ - """ - return self[key] - - def getField(self, name): - """An expression that gets a field by name in a StructField. 
- - >>> from pyspark.sql import Row - >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF() - >>> df.select(df.r.getField("b")).show() - +----+ - |r[b]| - +----+ - | b| - +----+ - >>> df.select(df.r.a).show() - +----+ - |r[a]| - +----+ - | 1| - +----+ - """ - return self[name] - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - return self.getField(item) - - # string methods - rlike = _bin_op("rlike") - like = _bin_op("like") - startswith = _bin_op("startsWith") - endswith = _bin_op("endsWith") - - @ignore_unicode_prefix - def substr(self, startPos, length): - """ - Return a :class:`Column` which is a substring of the column - - :param startPos: start position (int or Column) - :param length: length of the substring (int or Column) - - >>> df.select(df.name.substr(1, 3).alias("col")).collect() - [Row(col=u'Ali'), Row(col=u'Bob')] - """ - if type(startPos) != type(length): - raise TypeError("Can not mix the type") - if isinstance(startPos, (int, long)): - jc = self._jc.substr(startPos, length) - elif isinstance(startPos, Column): - jc = self._jc.substr(startPos._jc, length._jc) - else: - raise TypeError("Unexpected type: %s" % type(startPos)) - return Column(jc) - - __getslice__ = substr - - @ignore_unicode_prefix - def inSet(self, *cols): - """ A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - """ - if len(cols) == 1 and isinstance(cols[0], (list, set)): - cols = cols[0] - cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols] - sc = SparkContext._active_spark_context - jc = getattr(self._jc, "in")(_to_seq(sc, cols)) - return Column(jc) - - # order - asc = _unary_op("asc", "Returns a sort expression based on the" - " ascending order of the given column name.") - desc = _unary_op("desc", "Returns a sort expression based on the" - " descending order of the given column name.") - - isNull = _unary_op("isNull", "True if the current expression is null.") - isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") - - def alias(self, *alias): - """Returns this column aliased with a new name or names (in the case of expressions that - return more than one column, such as explode). 
- - >>> df.select(df.age.alias("age2")).collect() - [Row(age2=2), Row(age2=5)] - """ - - if len(alias) == 1: - return Column(getattr(self._jc, "as")(alias[0])) - else: - sc = SparkContext._active_spark_context - return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias)))) - - @ignore_unicode_prefix - def cast(self, dataType): - """ Convert the column into type `dataType` - - >>> df.select(df.age.cast("string").alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - >>> df.select(df.age.cast(StringType()).alias('ages')).collect() - [Row(ages=u'2'), Row(ages=u'5')] - """ - if isinstance(dataType, basestring): - jc = self._jc.cast(dataType) - elif isinstance(dataType, DataType): - sc = SparkContext._active_spark_context - ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) - jdt = ssql_ctx.parseDataType(dataType.json()) - jc = self._jc.cast(jdt) - else: - raise TypeError("unexpected type: %s" % type(dataType)) - return Column(jc) - - @ignore_unicode_prefix - def between(self, lowerBound, upperBound): - """ A boolean expression that is evaluated to true if the value of this - expression is between the given columns. - """ - return (self >= lowerBound) & (self <= upperBound) - - @ignore_unicode_prefix - def when(self, condition, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param condition: a boolean :class:`Column` expression. - :param value: a literal value, or a :class:`Column` expression. - - """ - sc = SparkContext._active_spark_context - if not isinstance(condition, Column): - raise TypeError("condition should be a Column") - v = value._jc if isinstance(value, Column) else value - jc = sc._jvm.functions.when(condition._jc, v) - return Column(jc) - - @ignore_unicode_prefix - def otherwise(self, value): - """Evaluates a list of conditions and returns one of multiple possible result expressions. - If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. - - See :func:`pyspark.sql.functions.when` for example usage. - - :param value: a literal value, or a :class:`Column` expression. - """ - v = value._jc if isinstance(value, Column) else value - jc = self._jc.otherwise(value) - return Column(jc) - - def __repr__(self): - return 'Column<%s>' % self._jc.toString().encode('utf8') - - class DataFrameNaFunctions(object): """Functionality for working with missing data in :class:`DataFrame`. 
""" @@ -1646,9 +1206,6 @@ def _test(): .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() - globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), - Row(name='Bob', age=5, height=85)]).toDF() - globs['df4'] = sc.parallelize([Row(name='Alice', age=10, height=80), Row(name='Bob', age=5, height=None), Row(name='Tom', age=None, height=None), diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6cd6974b0e..8d0e766ecd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,7 +27,7 @@ from pyspark import SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType -from pyspark.sql.dataframe import Column, _to_java_column, _to_seq +from pyspark.sql.column import Column, _to_java_column, _to_seq __all__ = [ diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py new file mode 100644 index 0000000000..9f7c743c05 --- /dev/null +++ b/python/pyspark/sql/group.py @@ -0,0 +1,183 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.rdd import ignore_unicode_prefix +from pyspark.sql.column import Column, _to_seq +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.types import * + +__all__ = ["GroupedData"] + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +def df_varargs_api(f): + def _api(self, *args): + name = f.__name__ + jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedData(object): + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by :func:`DataFrame.groupBy`. + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + @ignore_unicode_prefix + def agg(self, *exprs): + """Compute aggregates and returns the result as a :class:`DataFrame`. + + The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`. + + If ``exprs`` is a single :class:`dict` mapping from string to string, then the key + is the column to perform aggregation on, and the value is the aggregate function. + + Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions. + + :param exprs: a dict mapping from column name (string) to aggregate functions (string), + or a list of :class:`Column`. 
+ + >>> gdf = df.groupBy(df.name) + >>> gdf.agg({"*": "count"}).collect() + [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + + >>> from pyspark.sql import functions as F + >>> gdf.agg(F.min(df.age)).collect() + [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] + """ + assert exprs, "exprs should not be empty" + if len(exprs) == 1 and isinstance(exprs[0], dict): + jdf = self._jdf.agg(exprs[0]) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" + jdf = self._jdf.agg(exprs[0]._jc, + _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]])) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """Counts the number of records for each group. + + >>> df.groupBy(df.age).count().collect() + [Row(age=2, count=1), Row(age=5, count=1)] + """ + + @df_varargs_api + def mean(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def avg(self, *cols): + """Computes average values for each numeric columns for each group. + + :func:`mean` is an alias for :func:`avg`. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age)=3.5, AVG(height)=82.5)] + """ + + @df_varargs_api + def max(self, *cols): + """Computes the max value for each numeric columns for each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age)=5, MAX(height)=85)] + """ + + @df_varargs_api + def min(self, *cols): + """Computes the min value for each numeric column for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. + + >>> df.groupBy().min('age').collect() + [Row(MIN(age)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age)=2, MIN(height)=80)] + """ + + @df_varargs_api + def sum(self, *cols): + """Compute the sum for each numeric columns for each group. + + :param cols: list of column names (string). Non-numeric columns are ignored. 
+
+        >>> df.groupBy().sum('age').collect()
+        [Row(SUM(age)=7)]
+        >>> df3.groupBy().sum('age', 'height').collect()
+        [Row(SUM(age)=7, SUM(height)=165)]
+        """
+
+
+def _test():
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import Row, SQLContext
+    import pyspark.sql.group
+    globs = pyspark.sql.group.__dict__.copy()
+    sc = SparkContext('local[4]', 'PythonTest')
+    globs['sc'] = sc
+    globs['sqlContext'] = SQLContext(sc)
+    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
+        .toDF(StructType([StructField('age', IntegerType()),
+                          StructField('name', StringType())]))
+    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+                                   Row(name='Bob', age=5, height=85)]).toDF()
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.group, globs=globs,
+        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
+    globs['sc'].stop()
+    if failure_count:
+        exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/run-tests b/python/run-tests
index f2757a3967..ffde2fb24b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -72,7 +72,9 @@ function run_sql_tests() {
     echo "Run sql tests ..."
     run_test "pyspark/sql/_types.py"
     run_test "pyspark/sql/context.py"
+    run_test "pyspark/sql/column.py"
     run_test "pyspark/sql/dataframe.py"
+    run_test "pyspark/sql/group.py"
     run_test "pyspark/sql/functions.py"
     run_test "pyspark/sql/tests.py"
 }
-- 
GitLab
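For reference, a minimal usage sketch (not part of the patch itself) showing that downstream import paths keep working after the split: the names exported from `pyspark.sql` are re-exports of the new `pyspark.sql.column` and `pyspark.sql.group` modules. It assumes a PySpark build that includes this patch; the `local[4]` master and the `'ImportCheck'` app name are arbitrary, and the sample rows mirror the doctest data above.

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext, Column, DataFrame, GroupedData   # package-level re-exports
from pyspark.sql.column import Column as ColumnFromModule            # module added by this patch
from pyspark.sql.group import GroupedData as GroupedDataFromModule   # module added by this patch

# The package-level names and the new module-level names are the same class objects.
assert Column is ColumnFromModule
assert GroupedData is GroupedDataFromModule

sc = SparkContext('local[4]', 'ImportCheck')
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name'])
assert isinstance(df, DataFrame)
assert isinstance(df.age + 1, Column)                # expressions still build Column objects
assert isinstance(df.groupBy(df.name), GroupedData)  # groupBy still returns GroupedData

sc.stop()
```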