From 142df4834bc33dc7b84b626c6ee3508ab1abe015 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun <dongjoon@apache.org>
Date: Fri, 8 Jul 2016 14:36:50 -0700
Subject: [PATCH] [SPARK-16429][SQL] Include `StringType` columns in
 `describe()`

## What changes were proposed in this pull request?

Currently, Spark's `describe` already supports `StringType` columns when they are named explicitly. However, `describe()` called without arguments returns a dataset containing only the numeric columns. This PR makes the no-argument form, `describe()`, include `StringType` columns as well.

**Background**
```scala
scala> spark.read.json("examples/src/main/resources/people.json").describe("age", "name").show()
+-------+------------------+-------+
|summary|               age|   name|
+-------+------------------+-------+
|  count|                 2|      3|
|   mean|              24.5|   null|
| stddev|7.7781745930520225|   null|
|    min|                19|   Andy|
|    max|                30|Michael|
+-------+------------------+-------+
```

**Before**
```scala
scala> spark.read.json("examples/src/main/resources/people.json").describe().show()
+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|                 2|
|   mean|              24.5|
| stddev|7.7781745930520225|
|    min|                19|
|    max|                30|
+-------+------------------+
```

**After**
```scala
scala> spark.read.json("examples/src/main/resources/people.json").describe().show()
+-------+------------------+-------+
|summary|               age|   name|
+-------+------------------+-------+
|  count|                 2|      3|
|   mean|              24.5|   null|
| stddev|7.7781745930520225|   null|
|    min|                19|   Andy|
|    max|                30|Michael|
+-------+------------------+-------+
```

## How was this patch tested?

Pass the Jenkins tests with an updated testcase.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #14095 from dongjoon-hyun/SPARK-16429.
---
 R/pkg/R/DataFrame.R                           |  4 +--
 R/pkg/inst/tests/testthat/test_sparkSQL.R     |  4 +--
 python/pyspark/sql/dataframe.py               |  8 ++---
 .../scala/org/apache/spark/sql/Dataset.scala  | 16 +++++++--
 .../org/apache/spark/sql/DataFrameSuite.scala | 36 +++++++++----------
 5 files changed, 39 insertions(+), 29 deletions(-)

diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a18eee3a0f..47f9203ace 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2587,8 +2587,8 @@ setMethod("saveAsTable",
 
 #' summary
 #'
-#' Computes statistics for numeric columns.
-#' If no columns are given, this function computes statistics for all numerical columns.
+#' Computes statistics for numeric and string columns.
+#' If no columns are given, this function computes statistics for all numerical or string columns.
 #'
 #' @param x A SparkDataFrame to be computed.
 #' @param col A string of name
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index e2a1da0f1e..fdd6020db9 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1824,11 +1824,11 @@ test_that("describe() and summarize() on a DataFrame", {
   expect_equal(collect(stats)[2, "age"], "24.5")
   expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
   stats <- describe(df)
-  expect_equal(collect(stats)[4, "name"], NULL)
+  expect_equal(collect(stats)[4, "name"], "Andy")
   expect_equal(collect(stats)[5, "age"], "30")
 
   stats2 <- summary(df)
-  expect_equal(collect(stats2)[4, "name"], NULL)
+  expect_equal(collect(stats2)[4, "name"], "Andy")
   expect_equal(collect(stats2)[5, "age"], "30")
 
   # SPARK-16425: SparkR summary() fails on column of type logical
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index dd670a9b3d..ab41e88620 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -751,15 +751,15 @@ class DataFrame(object):
 
     @since("1.3.1")
     def describe(self, *cols):
-        """Computes statistics for numeric columns.
+        """Computes statistics for numeric and string columns.
 
         This include count, mean, stddev, min, and max. If no columns are
-        given, this function computes statistics for all numerical columns.
+        given, this function computes statistics for all numerical or string columns.
 
         .. note:: This function is meant for exploratory data analysis, as we make no \
         guarantee about the backward compatibility of the schema of the resulting DataFrame.
 
-        >>> df.describe().show()
+        >>> df.describe(['age']).show()
         +-------+------------------+
         |summary|               age|
         +-------+------------------+
@@ -769,7 +769,7 @@ class DataFrame(object):
         |    min|                 2|
         |    max|                 5|
         +-------+------------------+
-        >>> df.describe(['age', 'name']).show()
+        >>> df.describe().show()
         +-------+------------------+-----+
         |summary|               age| name|
         +-------+------------------+-----+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ededf7f4fe..ed4ccdb4c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -228,6 +228,15 @@ class Dataset[T] private[sql](
     }
   }
 
+  private def aggregatableColumns: Seq[Expression] = {
+    schema.fields
+      .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
+      .map { n =>
+        queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver)
+          .get
+      }
+  }
+
   /**
    * Compose the string representing rows for output
    *
@@ -1886,8 +1895,9 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
-   * If no columns are given, this function computes statistics for all numerical columns.
+   * Computes statistics for numeric and string columns, including count, mean, stddev, min, and
+   * max. If no columns are given, this function computes statistics for all numerical or string
+   * columns.
    *
    * This function is meant for exploratory data analysis, as we make no guarantee about the
    * backward compatibility of the schema of the resulting Dataset. If you want to
@@ -1920,7 +1930,7 @@ class Dataset[T] private[sql](
       "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
 
     val outputCols =
-      (if (cols.isEmpty) numericColumns.map(usePrettyExpression(_).sql) else cols).toList
+      (if (cols.isEmpty) aggregatableColumns.map(usePrettyExpression(_).sql) else cols).toList
 
     val ret: Seq[Row] = if (outputCols.nonEmpty) {
       val aggExprs = statistics.flatMap { case (_, colToAgg) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9d53be8e2b..905da554f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -651,44 +651,44 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       ("Amy", 24, 180)).toDF("name", "age", "height")
 
     val describeResult = Seq(
-      Row("count", "4", "4"),
-      Row("mean", "33.0", "178.0"),
-      Row("stddev", "19.148542155126762", "11.547005383792516"),
-      Row("min", "16", "164"),
-      Row("max", "60", "192"))
+      Row("count", "4", "4", "4"),
+      Row("mean", null, "33.0", "178.0"),
+      Row("stddev", null, "19.148542155126762", "11.547005383792516"),
+      Row("min", "Alice", "16", "164"),
+      Row("max", "David", "60", "192"))
 
     val emptyDescribeResult = Seq(
-      Row("count", "0", "0"),
-      Row("mean", null, null),
-      Row("stddev", null, null),
-      Row("min", null, null),
-      Row("max", null, null))
+      Row("count", "0", "0", "0"),
+      Row("mean", null, null, null),
+      Row("stddev", null, null, null),
+      Row("min", null, null, null),
+      Row("max", null, null, null))
 
     def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
-    val describeTwoCols = describeTestData.describe("age", "height")
-    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))
+    val describeTwoCols = describeTestData.describe("name", "age", "height")
+    assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "name", "age", "height"))
     checkAnswer(describeTwoCols, describeResult)
     // All aggregate value should have been cast to string
     describeTwoCols.collect().foreach { row =>
-      assert(row.get(1).isInstanceOf[String], "expected string but found " + row.get(1).getClass)
       assert(row.get(2).isInstanceOf[String], "expected string but found " + row.get(2).getClass)
+      assert(row.get(3).isInstanceOf[String], "expected string but found " + row.get(3).getClass)
     }
 
     val describeAllCols = describeTestData.describe()
-    assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "age", "height"))
+    assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
     checkAnswer(describeAllCols, describeResult)
 
     val describeOneCol = describeTestData.describe("age")
     assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
-    checkAnswer(describeOneCol, describeResult.map { case Row(s, d, _) => Row(s, d)} )
+    checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
 
     val describeNoCol = describeTestData.select("name").describe()
-    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
-    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _) => Row(s)} )
+    assert(getSchemaAsSeq(describeNoCol) === Seq("summary", "name"))
+    checkAnswer(describeNoCol, describeResult.map { case Row(s, n, _, _) => Row(s, n)} )
 
     val emptyDescription = describeTestData.limit(0).describe()
-    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "age", "height"))
+    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
     checkAnswer(emptyDescription, emptyDescribeResult)
   }
 
-- 
GitLab