diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a18eee3a0fab30ec02c3617c9df7190bf4a36480..47f9203acecaba55a16739c394552a5c27994de1 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2587,8 +2587,8 @@ setMethod("saveAsTable", #' summary #' -#' Computes statistics for numeric columns. -#' If no columns are given, this function computes statistics for all numerical columns. +#' Computes statistics for numeric and string columns. +#' If no columns are given, this function computes statistics for all numerical or string columns. #' #' @param x A SparkDataFrame to be computed. #' @param col A string of name diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index e2a1da0f1ee73c6ecd2f9844400b1fd9dec17e6f..fdd6020db9d0294c7e48f0a06570d37835a6f94d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1824,11 +1824,11 @@ test_that("describe() and summarize() on a DataFrame", { expect_equal(collect(stats)[2, "age"], "24.5") expect_equal(collect(stats)[3, "age"], "7.7781745930520225") stats <- describe(df) - expect_equal(collect(stats)[4, "name"], NULL) + expect_equal(collect(stats)[4, "name"], "Andy") expect_equal(collect(stats)[5, "age"], "30") stats2 <- summary(df) - expect_equal(collect(stats2)[4, "name"], NULL) + expect_equal(collect(stats2)[4, "name"], "Andy") expect_equal(collect(stats2)[5, "age"], "30") # SPARK-16425: SparkR summary() fails on column of type logical diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index dd670a9b3db2ac86dff00ef4e3bf21ee93fe8584..ab41e88620b2cdaae42e5322dc256a956ddec8fa 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -751,15 +751,15 @@ class DataFrame(object): @since("1.3.1") def describe(self, *cols): - """Computes statistics for numeric columns. + """Computes statistics for numeric and string columns. This include count, mean, stddev, min, and max. If no columns are - given, this function computes statistics for all numerical columns. + given, this function computes statistics for all numerical or string columns. .. note:: This function is meant for exploratory data analysis, as we make no \ guarantee about the backward compatibility of the schema of the resulting DataFrame. - >>> df.describe().show() + >>> df.describe(['age']).show() +-------+------------------+ |summary| age| +-------+------------------+ @@ -769,7 +769,7 @@ class DataFrame(object): | min| 2| | max| 5| +-------+------------------+ - >>> df.describe(['age', 'name']).show() + >>> df.describe().show() +-------+------------------+-----+ |summary| age| name| +-------+------------------+-----+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ededf7f4fe5b9428bb467f4237f16c409132b67f..ed4ccdb4c8d4f304bd035635b651a3da43849853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -228,6 +228,15 @@ class Dataset[T] private[sql]( } } + private def aggregatableColumns: Seq[Expression] = { + schema.fields + .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) + .map { n => + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver) + .get + } + } + /** * Compose the string representing rows for output * @@ -1886,8 +1895,9 @@ class Dataset[T] private[sql]( } /** - * Computes statistics for numeric columns, including count, mean, stddev, min, and max. - * If no columns are given, this function computes statistics for all numerical columns. + * Computes statistics for numeric and string columns, including count, mean, stddev, min, and + * max. If no columns are given, this function computes statistics for all numerical or string + * columns. * * This function is meant for exploratory data analysis, as we make no guarantee about the * backward compatibility of the schema of the resulting Dataset. If you want to @@ -1920,7 +1930,7 @@ class Dataset[T] private[sql]( "max" -> ((child: Expression) => Max(child).toAggregateExpression())) val outputCols = - (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList + (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9d53be8e2bf86ef38d01329791448ec970717878..905da554f1cf130808d91eb31a0704c2144bbc21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -651,44 +651,44 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ("Amy", 24, 180)).toDF("name", "age", "height") val describeResult = Seq( - Row("count", "4", "4"), - Row("mean", "33.0", "178.0"), - Row("stddev", "19.148542155126762", "11.547005383792516"), - Row("min", "16", "164"), - Row("max", "60", "192")) + Row("count", "4", "4", "4"), + Row("mean", null, "33.0", "178.0"), + Row("stddev", null, "19.148542155126762", "11.547005383792516"), + Row("min", "Alice", "16", "164"), + Row("max", "David", "60", "192")) val emptyDescribeResult = Seq( - Row("count", "0", "0"), - Row("mean", null, null), - Row("stddev", null, null), - Row("min", null, null), - Row("max", null, null)) + Row("count", "0", "0", "0"), + Row("mean", null, null, null), + Row("stddev", null, null, null), + Row("min", null, null, null), + Row("max", null, null, null)) def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name) - val describeTwoCols = describeTestData.describe("age", "height") - assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height")) + val describeTwoCols = describeTestData.describe("name", "age", "height") + assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeTwoCols, describeResult) // All aggregate value should have been cast to string describeTwoCols.collect().foreach { row => - assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass) assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass) + assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass) } val describeAllCols = describeTestData.describe() - assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height")) + assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height")) checkAnswer(describeAllCols, describeResult) val describeOneCol = describeTestData.describe("age") assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age")) - checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} ) + checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} ) val describeNoCol = describeTestData.select("name").describe() - assert(getSchemaAsSeq(describeNoCol) === Seq("summary")) - checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} ) + assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name")) + checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} ) val emptyDescription = describeTestData.limit(0).describe() - assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height")) + assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height")) checkAnswer(emptyDescription, emptyDescribeResult) }