diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0587a592145be0771dd2c64321fb0a802105b95a..93550e1fc32ab611a09cad742818ddca63fd17bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + // Don't propagate rowCount and attributeStats, since they are not estimated here. + Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bd314315d27930d3ae15c2d84ab65de12ba18da..432097d6218d07f29b16f45e9f167fc071be69c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -541,7 +541,10 @@ case class Aggregate( override def computeStats(conf: CatalystConf): Statistics = { def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { - super.computeStats(conf).copy(sizeInBytes = 1) + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), + rowCount = Some(1), + isBroadcastable = child.stats(conf).isBroadcastable) } else { super.computeStats(conf) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 21e94fc941a5c2fe90af5704347f2d184781d9fd..ce74554c17010961733b94c2a8d5f50eb1b5e9bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -53,7 +53,7 @@ object AggregateEstimation { val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output) Some(Statistics( - sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows), + sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = childStats.isBroadcastable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index cf4452d0fdfba694cb07d616735dfac58226c5b1..e8b794212c10d6c4fcd7d6fcb7f9d6d8cfaf3355 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -37,8 +37,8 @@ object EstimationUtils { def getOutputSize( attributes: Seq[Attribute], - attrStats: AttributeMap[ColumnStat], - outputRowCount: BigInt): BigInt = { + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. val sizePerRow = 8 + attributes.map { attr => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index 50b869ab3ae89d93bba1cf20a0b3899e6c390964..e9084ad8b859c9a439fa3dd9232f741feaad6798 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -36,7 +36,7 @@ object ProjectEstimation { val outputAttrStats = getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output) Some(childStats.copy( - sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get), + sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats), attributeStats = outputAttrStats)) } else { None diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 41a4bc359e4ce32623932230c6dbbeb25f44a0d1..c0b9515ca7cd0d65d27dd068ae9e3f201e458c68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -90,6 +90,28 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { expectedOutputRowCount = 0) } + test("non-cbo estimation") { + val attributes = Seq("key12").map(nameToAttr) + val child = StatsTestPlan( + outputList = attributes, + rowCount = 4, + // rowCount * (overhead + column size) + size = Some(4 * (8 + 4)), + attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) + + val noGroupAgg = Aggregate(groupingExpressions = Nil, + aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) + assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) == + // overhead + count result size + Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) + + val hasGroupAgg = Aggregate(groupingExpressions = attributes, + aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) + assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) == + // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize + Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + } + private def checkAggStats( tableColumns: Seq[String], tableRowCount: BigInt, @@ -107,7 +129,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo)) val expectedStats = Statistics( - sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount), + sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, expectedAttrStats), rowCount = Some(expectedOutputRowCount), attributeStats = expectedAttrStats) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index e6adb6700a6566700a78c0a695fe4b9a9afce9f4..a5fac4ba6f03c3eec94d7256ec37ddd3e8cea66f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite { protected case class StatsTestPlan( outputList: Seq[Attribute], rowCount: BigInt, - attributeStats: AttributeMap[ColumnStat]) extends LeafNode { + attributeStats: AttributeMap[ColumnStat], + size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList override def computeStats(conf: CatalystConf): Statistics = Statistics( - // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value - sizeInBytes = Int.MaxValue, + // If sizeInBytes is useless in testing, we just use a fake value + sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), attributeStats = attributeStats) }