Skip to content
Snippets Groups Projects
Commit 15c2bd01 authored by Zhenhua Wang's avatar Zhenhua Wang Committed by Reynold Xin
Browse files

[SPARK-19020][SQL] Cardinality estimation of aggregate operator

## What changes were proposed in this pull request?

Support cardinality estimation for the aggregate operator.

## How was this patch tested?

Add test cases

Author: Zhenhua Wang <wzh_zju@163.com>
Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #16431 from wzhfy/aggEstimation.
parent 3ccabdfb
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
......@@ -495,7 +495,7 @@ case class Aggregate(
child.constraints.union(getAliasedConstraints(nonAgg))
}
override lazy val statistics: Statistics = {
override lazy val statistics: Statistics = AggregateEstimation.estimate(this).getOrElse {
if (groupingExpressions.isEmpty) {
super.statistics.copy(sizeInBytes = 1)
} else {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
object AggregateEstimation {
  import EstimationUtils._

  /**
   * Estimate the number of output rows based on column stats of group-by columns, and propagate
   * column stats for aggregate expressions.
   *
   * @param agg the aggregate operator to estimate
   * @return the estimated statistics, or None when `rowCountsExist` fails for the child or when
   *         some group-by expression is not a plain attribute with column stats in the child's
   *         statistics
   */
  def estimate(agg: Aggregate): Option[Statistics] = {
    val childStats = agg.child.statistics

    // Look up column stats for every group-by expression. An entry is None when the expression
    // is not a plain attribute, or when the child carries no column stats for that attribute.
    val groupByColStats = agg.groupingExpressions.map {
      case attr: Attribute => childStats.attributeStats.get(attr)
      case _ => None
    }

    if (rowCountsExist(agg.child) && groupByColStats.forall(_.isDefined)) {
      val outputRows: BigInt = if (agg.groupingExpressions.isEmpty) {
        // Without group-by columns the output is always a single row holding the values of the
        // aggregate functions, even when the child produces no rows. It must therefore not be
        // capped by the child's row count, which may be 0.
        BigInt(1)
      } else {
        // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
        // the data contains all combinations of distinct values of group-by columns.
        val ndvProduct = groupByColStats.flatten.foldLeft(BigInt(1)) {
          (product, colStat) => product * colStat.distinctCount
        }
        // Another upper bound for the number of output rows: it must not be larger than the
        // child's number of rows. The row count was checked by `rowCountsExist` above.
        ndvProduct.min(childStats.rowCount.get)
      }

      val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
      Some(Statistics(
        sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats),
        rowCount = Some(outputRows),
        attributeStats = outputAttrStats,
        isBroadcastable = childStats.isBroadcastable))
    } else {
      None
    }
  }
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.statsEstimation
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
class AggEstimationSuite extends StatsEstimationTestBase {

  /** Column stats shared by all test tables, keyed by the attribute they describe. */
  private val columnInfo: Map[Attribute, ColumnStat] = Map(
    attr("key11") -> ColumnStat(
      distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4),
    attr("key12") -> ColumnStat(
      distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4),
    attr("key21") -> ColumnStat(
      distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4),
    attr("key22") -> ColumnStat(
      distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, avgLen = 4, maxLen = 4),
    attr("key31") -> ColumnStat(
      distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, avgLen = 4, maxLen = 4),
    attr("key32") -> ColumnStat(
      distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, avgLen = 4, maxLen = 4))

  /** Column name -> attribute lookup. */
  private val nameToAttr: Map[String, Attribute] =
    columnInfo.keys.map(attribute => attribute.name -> attribute).toMap

  /** Column name -> (attribute, column stat) lookup. */
  private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
    columnInfo.map { case entry @ (attribute, _) => attribute.name -> entry }

  test("empty group-by column") {
    // table1 is assumed to hold 2 records: (1, 10), (2, 10)
    val table1 = tableOf(Seq("key11", "key12"), rowCount = 2)
    checkAggStats(child = table1, colNames = Nil, expectedRowCount = 1)
  }

  test("there's a primary key in group-by columns") {
    val groupBy = Seq("key11", "key12")
    // table1 is assumed to hold 2 records: (1, 10), (2, 10)
    val table1 = tableOf(groupBy, rowCount = 2)
    // key11 is a primary key, so row count = ndv of key11 = child's row count
    val expected = table1.stats.rowCount.get
    checkAggStats(child = table1, colNames = groupBy, expectedRowCount = expected)
  }

  test("the product of ndv's of group-by columns is too large") {
    val groupBy = Seq("key21", "key22")
    // table2 is assumed to hold 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
    val table2 = tableOf(groupBy, rowCount = 4)
    // The child's row count acts as the upper bound
    val expected = table2.stats.rowCount.get
    checkAggStats(child = table2, colNames = groupBy, expectedRowCount = expected)
  }

  test("data contains all combinations of distinct values of group-by columns.") {
    val groupBy = Seq("key31", "key32")
    // table3 is assumed to hold 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
    val table3 = tableOf(groupBy, rowCount = 6)
    // Row count = product of the ndv's of both columns
    val ndv31 = nameToColInfo("key31")._2.distinctCount
    val ndv32 = nameToColInfo("key32")._2.distinctCount
    checkAggStats(child = table3, colNames = groupBy, expectedRowCount = ndv31 * ndv32)
  }

  /**
   * Build a two-column test plan over the given columns with `rowCount` rows, sized at
   * (4 + 4) bytes per row, and column stats taken from `columnInfo`.
   */
  private def tableOf(colNames: Seq[String], rowCount: BigInt): StatsTestPlan =
    StatsTestPlan(
      outputList = colNames.map(nameToAttr),
      stats = Statistics(
        sizeInBytes = rowCount * (4 + 4),
        rowCount = Some(rowCount),
        attributeStats = AttributeMap(colNames.map(nameToColInfo))))

  /**
   * Aggregate `child` grouped by `colNames` (plus a count(1) column) and assert that the
   * resulting statistics carry `expectedRowCount` rows and the stats of the group-by columns.
   */
  private def checkAggStats(
      child: LogicalPlan,
      colNames: Seq[String],
      expectedRowCount: BigInt): Unit = {
    val groupByAttrs = colNames.map(nameToAttr)
    val countAlias = Alias(Count(Literal(1)), "cnt")()
    val aggregate = Aggregate(
      groupingExpressions = groupByAttrs,
      aggregateExpressions = groupByAttrs :+ countAlias,
      child = child)

    val attrStats = AttributeMap(colNames.map(nameToColInfo))
    val rowSize = getRowSize(aggregate.output, attrStats)
    val expectedStats = Statistics(
      sizeInBytes = expectedRowCount * rowSize,
      rowCount = Some(expectedRowCount),
      attributeStats = attrStats)

    assert(aggregate.statistics == expectedStats)
  }
}
......@@ -18,12 +18,15 @@
package org.apache.spark.sql.catalyst.statsEstimation
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.types.IntegerType
class StatsEstimationTestBase extends SparkFunSuite {
def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
/** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */
def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
: AttributeMap[ColumnStat] = {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment