Commit 039ed9fe authored by wangzhenhua, committed by gatorsmile

[SPARK-19271][SQL] Change non-cbo estimation of aggregate

## What changes were proposed in this pull request?

Change the non-cbo estimation behavior of Aggregate:
- If `groupingExpressions` is empty, the row count is known to be exactly 1, so the corresponding output size can be computed directly;
- otherwise, estimation falls back to UnaryNode's computeStats method, which should not propagate rowCount and attributeStats in the returned Statistics, because that method does not estimate them (see the sketch below).
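
To make the two paths concrete, here is a minimal, self-contained sketch of the logic in plain Scala. The names (`SimpleStats`, `aggregateStats`) are made up for illustration and are not the Catalyst API; the 8-byte per-row overhead matches the one assumed by `EstimationUtils.getOutputSize`.

```scala
// Simplified model of the two non-cbo estimation paths for Aggregate.
case class SimpleStats(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

def aggregateStats(
    groupingIsEmpty: Boolean,
    childRowSize: Int,   // 8-byte row overhead + child column default sizes
    outputRowSize: Int,  // 8-byte row overhead + output column default sizes
    childStats: SimpleStats): SimpleStats = {
  if (groupingIsEmpty) {
    // A global aggregate produces exactly one row, so both the row count
    // and the output size are known precisely.
    SimpleStats(sizeInBytes = outputRowSize, rowCount = Some(1))
  } else {
    // UnaryNode's fallback: scale the child's size by the row-width ratio.
    // rowCount and attributeStats are left empty, since nothing here
    // actually estimates them.
    SimpleStats(sizeInBytes = childStats.sizeInBytes * outputRowSize / childRowSize)
  }
}
```

For example, `aggregateStats(groupingIsEmpty = true, childRowSize = 12, outputRowSize = 16, SimpleStats(48))` returns `SimpleStats(16, Some(1))` — the same numbers asserted by the new test below.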

## How was this patch tested?

Added a test case ("non-cbo estimation") to AggregateEstimationSuite.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #16631 from wzhfy/aggNoCbo.
parent 0bf605c2
```diff
@@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan {
       sizeInBytes = 1
     }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    // Don't propagate rowCount and attributeStats, since they are not estimated here.
+    Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
   }
 }
```
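
The size heuristic just above this hunk is untouched: `sizeInBytes` is still the child's size scaled by the ratio of output row width to child row width (the `childSize * outputRowSize / childRowSize` that the new test comment references). The fix only replaces the `copy`, which silently carried the child's `rowCount` and `attributeStats` into a node that never estimated them, with a fresh `Statistics` holding just the size and the broadcast flag.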
```diff
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
```
```diff
@@ -541,7 +541,10 @@ case class Aggregate(
   override def computeStats(conf: CatalystConf): Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
-        super.computeStats(conf).copy(sizeInBytes = 1)
+        Statistics(
+          sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
+          rowCount = Some(1),
+          isBroadcastable = child.stats(conf).isBroadcastable)
       } else {
         super.computeStats(conf)
       }
```
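
Before this change, the global-aggregate path hard-coded `sizeInBytes = 1` while inheriting everything else from `super.computeStats`. Now the size reflects the single output row a global aggregate actually produces: the 8-byte row overhead plus the output columns' default sizes, e.g. 8 + 8 = 16 bytes for a lone Long-typed `count` result.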
```diff
@@ -53,7 +53,7 @@ object AggregateEstimation {
       val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
       Some(Statistics(
-        sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
+        sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
         attributeStats = outputAttrStats,
         isBroadcastable = childStats.isBroadcastable))
```
```diff
@@ -37,8 +37,8 @@ object EstimationUtils {
   def getOutputSize(
       attributes: Seq[Attribute],
-      attrStats: AttributeMap[ColumnStat],
-      outputRowCount: BigInt): BigInt = {
+      outputRowCount: BigInt,
+      attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // We assign a generic overhead for a Row object, the actual overhead is different for different
     // Row format.
     val sizePerRow = 8 + attributes.map { attr =>
```
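
Swapping the parameter order is what enables the new call site in `Aggregate.computeStats` above: with `attrStats` last, it can default to an empty `AttributeMap(Nil)`, so non-cbo callers that have no column statistics can pass just the attributes and the row count, as in `getOutputSize(output, outputRowCount = 1)`.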
```diff
@@ -36,7 +36,7 @@ object ProjectEstimation {
       val outputAttrStats =
         getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
       Some(childStats.copy(
-        sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get),
+        sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
         attributeStats = outputAttrStats))
     } else {
       None
```
```diff
@@ -90,6 +90,28 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       expectedOutputRowCount = 0)
   }
 
+  test("non-cbo estimation") {
+    val attributes = Seq("key12").map(nameToAttr)
+    val child = StatsTestPlan(
+      outputList = attributes,
+      rowCount = 4,
+      // rowCount * (overhead + column size)
+      size = Some(4 * (8 + 4)),
+      attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
+
+    val noGroupAgg = Aggregate(groupingExpressions = Nil,
+      aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+    assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // overhead + count result size
+      Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+
+    val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
+    assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+      Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+  }
+
   private def checkAggStats(
       tableColumns: Seq[String],
       tableRowCount: BigInt,
@@ -107,7 +129,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
     val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
 
     val expectedStats = Statistics(
-      sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount),
+      sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, expectedAttrStats),
       rowCount = Some(expectedOutputRowCount),
       attributeStats = expectedAttrStats)
```
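
The expected values in the new test can be checked by hand, assuming Catalyst's default sizes of 4 bytes for `IntegerType` and 8 bytes for `LongType`, plus the generic 8-byte row overhead:

```scala
// Worked arithmetic behind the test's expected Statistics (plain Scala).
val childSize    = 4 * (8 + 4)                        // 48: four rows of one Int column
val noGroupSize  = 1 * (8 + 8)                        // 16: one row holding the Long count
val hasGroupSize = childSize * (8 + 4 + 8) / (8 + 4)  // 80: childSize * outputRowSize / childRowSize
```

The no-grouping case also carries `rowCount = Some(1)`; the grouped case carries no row count, because it comes from `UnaryNode.computeStats`.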
```diff
@@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite {
   protected case class StatsTestPlan(
       outputList: Seq[Attribute],
       rowCount: BigInt,
-      attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
+      attributeStats: AttributeMap[ColumnStat],
+      size: Option[BigInt] = None) extends LeafNode {
     override def output: Seq[Attribute] = outputList
     override def computeStats(conf: CatalystConf): Statistics = Statistics(
-      // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value
-      sizeInBytes = Int.MaxValue,
+      // If sizeInBytes is useless in testing, we just use a fake value
+      sizeInBytes = size.getOrElse(Int.MaxValue),
       rowCount = Some(rowCount),
       attributeStats = attributeStats)
   }
```