diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5326b45b50a8b01b81f47701aaa05bf80607a5b3..dfb51192c69bc1d0debc290462531bc65910d743 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -224,15 +224,6 @@ class Dataset[T] private[sql](
     }
   }
 
-  private[sql] def aggregatableColumns: Seq[Expression] = {
-    schema.fields
-      .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType])
-      .map { n =>
-        queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver)
-          .get
-      }
-  }
-
   /**
    * Compose the string representing rows for output
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 436e18fdb5ff5c331d9deecd2845334e906f3fd1..a75cfb3600225d8bb86ebf7a701266ac9a939869 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql.execution.stat
 
+import java.util.Locale
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.expressions.{Cast, CreateArray, Expression, GenericInternalRow, Literal}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{usePrettyExpression, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -228,90 +231,68 @@ object StatFunctions extends Logging {
     val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
     val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
 
-    val hasPercentiles = selectedStatistics.exists(_.endsWith("%"))
-    val (percentiles, percentileNames, remainingAggregates) = if (hasPercentiles) {
-      val (pStrings, rest) = selectedStatistics.partition(a => a.endsWith("%"))
-      val percentiles = pStrings.map { p =>
-        try {
-          p.stripSuffix("%").toDouble / 100.0
-        } catch {
-          case e: NumberFormatException =>
-            throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
+    val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
+      try {
+        p.stripSuffix("%").toDouble / 100.0
+      } catch {
+        case e: NumberFormatException =>
+          throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
       }
-      require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
-      (percentiles, pStrings, rest)
-    } else {
-      (Seq(), Seq(), selectedStatistics)
     }
+    require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
-
-    // The list of summary statistics to compute, in the form of expressions.
-    val availableStatistics = Map[String, Expression => Expression](
-      "count" -> ((child: Expression) => Count(child).toAggregateExpression()),
-      "mean" -> ((child: Expression) => Average(child).toAggregateExpression()),
-      "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()),
-      "min" -> ((child: Expression) => Min(child).toAggregateExpression()),
-      "max" -> ((child: Expression) => Max(child).toAggregateExpression()))
-
-    val statisticFns = remainingAggregates.map { agg =>
-      require(availableStatistics.contains(agg), s"$agg is not a recognised statistic")
-      agg -> availableStatistics(agg)
-    }
-
-    def percentileAgg(child: Expression): Expression =
-      new ApproximatePercentile(child, CreateArray(percentiles.map(Literal(_))))
-        .toAggregateExpression()
-
-    val outputCols = ds.aggregatableColumns.map(usePrettyExpression(_).sql).toList
-
-    val ret: Seq[Row] = if (outputCols.nonEmpty) {
-      var aggExprs = statisticFns.toList.flatMap { case (_, colToAgg) =>
-        outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c))
-      }
-      if (hasPercentiles) {
-        aggExprs = outputCols.map(c => Column(percentileAgg(Column(c).expr)).as(c)) ++ aggExprs
+    var percentileIndex = 0
+    val statisticFns = selectedStatistics.map { stats =>
+      if (stats.endsWith("%")) {
+        val index = percentileIndex
+        percentileIndex += 1
+        (child: Expression) =>
+          GetArrayItem(
+            new ApproximatePercentile(child, Literal.create(percentiles)).toAggregateExpression(),
+            Literal(index))
+      } else {
+        stats.toLowerCase(Locale.ROOT) match {
+          case "count" => (child: Expression) => Count(child).toAggregateExpression()
+          case "mean" => (child: Expression) => Average(child).toAggregateExpression()
+          case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression()
+          case "min" => (child: Expression) => Min(child).toAggregateExpression()
+          case "max" => (child: Expression) => Max(child).toAggregateExpression()
+          case _ => throw new IllegalArgumentException(s"$stats is not a recognised statistic")
+        }
       }
+    }
 
-      val row = ds.groupBy().agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+    val selectedCols = ds.logicalPlan.output
+      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
 
-      // Pivot the data so each summary is one row
-      val grouped: Seq[Seq[Any]] = row.grouped(outputCols.size).toSeq
+    val aggExprs = statisticFns.flatMap { func =>
+      selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
+    }
 
-      val basicStats = if (hasPercentiles) grouped.tail else grouped
+    // If there are no selected columns, we don't need to run this aggregate, so make it a lazy val.
+    lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
 
-      val rows = basicStats.zip(statisticFns).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
-      }
+    // We will have one row for each selected statistic in the result.
+    val result = Array.fill[InternalRow](selectedStatistics.length) {
+      // Each row has the statistic name, plus the statistic values of each selected column.
+      new GenericInternalRow(selectedCols.length + 1)
+    }
 
-      if (hasPercentiles) {
-        def nullSafeString(x: Any) = if (x == null) null else x.toString
-        val percentileRows = grouped.head
-          .map {
-            case a: Seq[Any] => a
-            case _ => Seq.fill(percentiles.length)(null: Any)
-          }
-          .transpose
-          .zip(percentileNames)
-          .map { case (values: Seq[Any], name) =>
-            Row(name :: values.map(nullSafeString).toList: _*)
-          }
-        (rows ++ percentileRows)
-          .sortWith((left, right) =>
-            selectedStatistics.indexOf(left(0)) < selectedStatistics.indexOf(right(0)))
-      } else {
-        rows
+    var rowIndex = 0
+    while (rowIndex < result.length) {
+      val statsName = selectedStatistics(rowIndex)
+      result(rowIndex).update(0, UTF8String.fromString(statsName))
+      for (colIndex <- selectedCols.indices) {
+        val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
+        result(rowIndex).update(colIndex + 1, statsValue)
       }
-    } else {
-      // If there are no output columns, just output a single column that contains the stats.
-      selectedStatistics.map(Row(_))
+      rowIndex += 1
     }
 
     // All columns are string type
-    val schema = StructType(
-      StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes
-    // `toArray` forces materialization to make the seq serializable
-    Dataset.ofRows(ds.sparkSession, LocalRelation.fromExternalRows(schema, ret.toArray.toSeq))
-  }
+    val output = AttributeReference("summary", StringType)() +:
+      selectedCols.map(c => AttributeReference(c.name, StringType)())
+    Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2c7051bf431c3968bae5a357be14c418816b0817..b2219b4eb8c172cbd9d1a21c273ce4fd4678de34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -770,7 +770,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val fooE = intercept[IllegalArgumentException] {
       person2.summary("foo")
     }
-    assert(fooE.getMessage === "requirement failed: foo is not a recognised statistic")
+    assert(fooE.getMessage === "foo is not a recognised statistic")
 
     val parseE = intercept[IllegalArgumentException] {
      person2.summary("foo%")
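
A minimal, hypothetical spark-shell sketch of the user-visible behavior around this refactoring; the data and column names are invented, and percentile values are approximate since they come from ApproximatePercentile:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Invented data; summary() only covers numeric and string columns.
val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("n", "s")

// Default statistics, matching defaultStatistics above:
// count, mean, stddev, min, 25%, 50%, 75%, max.
df.summary().show()

// Named statistics and arbitrary percentiles can be mixed in one call; after
// this patch each "NN%" entry becomes a GetArrayItem over an ApproximatePercentile.
df.summary("count", "33%", "max").show()

// Statistic names are now matched via toLowerCase(Locale.ROOT), and an unknown
// name fails with the plain message the updated test asserts:
// "foo is not a recognised statistic" (no "requirement failed:" prefix).
df.summary("foo")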