diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5326b45b50a8b01b81f47701aaa05bf80607a5b3..dfb51192c69bc1d0debc290462531bc65910d743 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -224,15 +224,6 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def aggregatableColumns: Seq[Expression] = {
-    schema.fields
-      .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
-      .map { n =>
-        queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver)
-          .get
-      }
-  }
-
   /**
    * Compose the string representing rows for output
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 436e18fdb5ff5c331d9deecd2845334e906f3fd1..a75cfb3600225d8bb86ebf7a701266ac9a939869 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.execution.stat
 
+import java.util.Locale
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -228,90 +231,68 @@ object StatFunctions extends Logging {
     val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
     val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
 
-    val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
-    val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
-      val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
-      val percentiles = pStrings.map { p =>
-        try {
-          p.stripSuffix("%").toDouble / 100.0
-        } catch {
-          case e: NumberFormatException =>
-            throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
-        }
+    val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
+      try {
+        p.stripSuffix("%").toDouble / 100.0
+      } catch {
+        case e: NumberFormatException =>
+          throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
       }
-      require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
-      (percentiles, pStrings, rest)
-    } else {
-      (Seq(), Seq(), selectedStatistics)
     }
+    require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
-
-    // The list of summary statistics to compute, in the form of expressions.
-    val availableStatistics = Map[String, Expression => Expression](
-      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
-      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
-      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
-      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
-      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
-
-    val statisticFns = remainingAggregates.map { agg =>
-      require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
-      agg -> availableStatistics(agg)
-    }
-
-    def percentileAgg(child: Expression): Expression =
-      new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
-        .toAggregateExpression()
-
-    val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList
-
-    val ret: Seq[Row] = if (outputCols.nonEmpty) {
-      var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
-        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
-      }
-      if (hasPercentiles) {
-        aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
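+    // Each percentile statistic builds the same ApproximatePercentile over all requested
+    // percentiles and pulls out its own entry of the resulting array by index.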
+    var percentileIndex = 0
+    val statisticFns = selectedStatistics.map { stats =>
+      if (stats.endsWith("%")) {
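+        // Capture the current index in a val so the closure below keeps its own
+        // position instead of seeing later increments of percentileIndex.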
+        val index = percentileIndex
+        percentileIndex += 1
+        (child: Expression) =>
+          GetArrayItem(
+            new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(),
+            Literal(index))
+      } else {
+        stats.toLowerCase(Locale.ROOT) match {
+          case "count" => (child: Expression) => Count(child).toAggregateExpression()
+          case "mean" => (child: Expression) => Average(child).toAggregateExpression()
+          case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression()
+          case "min" => (child: Expression) => Min(child).toAggregateExpression()
+          case "max" => (child: Expression) => Max(child).toAggregateExpression()
+          case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic")
+        }
       }
+    }
 
-      val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
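+    // Summarize only numeric and string columns, the same filter the removed
+    // Dataset.aggregatableColumns applied.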
+    val selectedCols = ds.logicalPlan.output
+      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
 
-      // Pivot the data so each summary is one row
-      val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
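+    // One aggregate per (statistic, column) pair, flattened statistic-major: all columns
+    // of the first statistic, then all columns of the second, and so on.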
+    val aggExprs = statisticFns.flatMap { func =>
+      selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
+    }
 
-      val basicStats = if (hasPercentiles) grouped.tail else grouped
+    // If there are no selected columns, we don't need to run this aggregate, so make it a lazy val.
+    lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
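+    // All aggregates were cast to StringType above, which is why the loop below reads the
+    // collected values back with getUTF8String.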
 
-      val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
-      }
+    // We will have one row for each selected statistic in the result.
+    val result = Array.fill[InternalRow](selectedStatistics.length) {
+      // Each row holds the statistic name followed by the statistic values of the selected columns.
+      new GenericInternalRow(selectedCols.length + 1)
+    }
 
-      if (hasPercentiles) {
-        def nullSafeString(x: Any) = if (x == null) null else x.toString
-        val percentileRows = grouped.head
-          .map {
-            case a: Seq[Any] => a
-            case _ => Seq.fill(percentiles.length)(null: Any)
-          }
-          .transpose
-          .zip(percentileNames)
-          .map { case (values: Seq[Any], name) =>
-            Row(name :: values.map(nullSafeString).toList: _*)
-          }
-        (rows ++ percentileRows)
-          .sortWith((left, right) =>
-            selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
-      } else {
-        rows
+    var rowIndex = 0
+    while (rowIndex < result.length) {
+      val statsName = selectedStatistics(rowIndex)
+      result(rowIndex).update(0, UTF8String.fromString(statsName))
+      for (colIndex <- selectedCols.indices) {
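+        // aggResult is a single flat row in statistic-major order, so this (statistic,
+        // column) pair lives at offset rowIndex * selectedCols.length + colIndex.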
+        val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
+        result(rowIndex).update(colIndex + 1, statsValue)
       }
-    } else {
-      // If there are no output columns, just output a single column that contains the stats.
-      selectedStatistics.map(Row(_))
+      rowIndex += 1
     }
 
     // All columns are string type
-    val schema = StructType(
-      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
-    // `toArray` forces materialization to make the seq serializable
-    Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
-  }
+    val output = AttributeReference("summary", StringType)() +:
+      selectedCols.map(c => AttributeReference(c.name, StringType)())
 
+    Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2c7051bf431c3968bae5a357be14c418816b0817..b2219b4eb8c172cbd9d1a21c273ce4fd4678de34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -770,7 +770,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val fooE = intercept[IllegalArgumentException] {
       person2.summary("foo")
     }
-    assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")
+    assert(fooE.getMessage === "foo is not a recognised statistic")
 
     val parseE = intercept[IllegalArgumentException] {
       person2.summary("foo%")