From 039ed9fe8a2fdcd99e0561af64cda8fe3406bc12 Mon Sep 17 00:00:00 2001
From: wangzhenhua <wangzhenhua@huawei.com>
Date: Thu, 19 Jan 2017 22:18:47 -0800
Subject: [PATCH] [SPARK-19271][SQL] Change non-cbo estimation of aggregate

## What changes were proposed in this pull request?

Change the non-cbo estimation behavior of Aggregate (sketched below):
- If groupingExpressions is empty, the aggregate returns exactly one row, so both the row count (= 1) and the corresponding output size are known;
- otherwise, estimation falls back to UnaryNode's computeStats method, which should not propagate rowCount and attributeStats in Statistics because they are not estimated on that path.
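
A simplified sketch of the resulting logic (the actual patch keeps this in a `simpleEstimation` helper inside `computeStats`, used when the CBO estimate is unavailable):

```scala
// Simplified sketch of the non-cbo path in Aggregate.computeStats:
override def computeStats(conf: CatalystConf): Statistics = {
  if (groupingExpressions.isEmpty) {
    // A global aggregate produces exactly one row, so the row count and
    // the output size are both known precisely.
    Statistics(
      sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
      rowCount = Some(1),
      isBroadcastable = child.stats(conf).isBroadcastable)
  } else {
    // UnaryNode.computeStats scales the child's sizeInBytes by the ratio of
    // output row width to child row width, and (after this patch) drops
    // rowCount and attributeStats because they are not estimated there.
    super.computeStats(conf)
  }
}
```

To support the one-row case without column statistics, getOutputSize now takes outputRowCount before attrStats, and attrStats defaults to an empty map.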

## How was this patch tested?

Added a test case in AggregateEstimationSuite covering both the empty and non-empty grouping cases; the expected sizes are worked out below.
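
For reference, the sizes asserted in the new test follow EstimationUtils.getOutputSize (8 bytes of per-row overhead plus the per-column sizes, per the comments in the test itself):

```scala
// child:       4 rows * (8 overhead + 4 bytes for key12)        = 48
// no grouping: 1 row  * (8 overhead + 8 bytes for count result) = 16
// grouped (UnaryNode fallback):
//              childSize * outputRowSize / childRowSize
//            = 48 * (8 + 4 + 8) / (8 + 4)                       = 80
```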

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #16631 from wzhfy/aggNoCbo.
---
 .../catalyst/plans/logical/LogicalPlan.scala  |  3 ++-
 .../plans/logical/basicLogicalOperators.scala |  7 ++++--
 .../statsEstimation/AggregateEstimation.scala |  2 +-
 .../statsEstimation/EstimationUtils.scala     |  4 ++--
 .../statsEstimation/ProjectEstimation.scala   |  2 +-
 .../AggregateEstimationSuite.scala            | 24 ++++++++++++++++++-
 .../StatsEstimationTestBase.scala             |  7 +++---
 7 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 0587a59214..93550e1fc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan {
       sizeInBytes = 1
     }
 
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    // Don't propagate rowCount and attributeStats, since they are not estimated here.
+    Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 3bd314315d..432097d621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -541,7 +541,10 @@ case class Aggregate(
   override def computeStats(conf: CatalystConf): Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
-        super.computeStats(conf).copy(sizeInBytes = 1)
+        Statistics(
+          sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
+          rowCount = Some(1),
+          isBroadcastable = child.stats(conf).isBroadcastable)
       } else {
         super.computeStats(conf)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 21e94fc941..ce74554c17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -53,7 +53,7 @@ object AggregateEstimation {
 
       val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
       Some(Statistics(
-        sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
+        sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
         attributeStats = outputAttrStats,
         isBroadcastable = childStats.isBroadcastable))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index cf4452d0fd..e8b794212c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -37,8 +37,8 @@ object EstimationUtils {
 
   def getOutputSize(
       attributes: Seq[Attribute],
-      attrStats: AttributeMap[ColumnStat],
-      outputRowCount: BigInt): BigInt = {
+      outputRowCount: BigInt,
+      attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
    // We assign a generic overhead for a Row object; the actual overhead differs across
    // Row formats.
     val sizePerRow = 8 + attributes.map { attr =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 50b869ab3a..e9084ad8b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -36,7 +36,7 @@ object ProjectEstimation {
       val outputAttrStats =
         getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
       Some(childStats.copy(
-        sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get),
+        sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
         attributeStats = outputAttrStats))
     } else {
       None
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 41a4bc359e..c0b9515ca7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -90,6 +90,28 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       expectedOutputRowCount = 0)
   }
 
+  test("non-cbo estimation") {
+    val attributes = Seq("key12").map(nameToAttr)
+    val child = StatsTestPlan(
+      outputList = attributes,
+      rowCount = 4,
+      // rowCount * (overhead + column size)
+      size = Some(4 * (8 + 4)),
+      attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
+
+    val noGroupAgg = Aggregate(groupingExpressions = Nil,
+      aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+    assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // overhead + count result size
+      Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+
+    val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
+    assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+      Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+  }
+
   private def checkAggStats(
       tableColumns: Seq[String],
       tableRowCount: BigInt,
@@ -107,7 +129,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
 
     val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
     val expectedStats = Statistics(
-      sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount),
+      sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, expectedAttrStats),
       rowCount = Some(expectedOutputRowCount),
       attributeStats = expectedAttrStats)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index e6adb6700a..a5fac4ba6f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite {
 protected case class StatsTestPlan(
     outputList: Seq[Attribute],
     rowCount: BigInt,
-    attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
+    attributeStats: AttributeMap[ColumnStat],
+    size: Option[BigInt] = None) extends LeafNode {
   override def output: Seq[Attribute] = outputList
   override def computeStats(conf: CatalystConf): Statistics = Statistics(
-    // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value
-    sizeInBytes = Int.MaxValue,
+    // If a test doesn't care about sizeInBytes, size is omitted and a fake value is used
+    sizeInBytes = size.getOrElse(Int.MaxValue),
     rowCount = Some(rowCount),
     attributeStats = attributeStats)
 }
-- 
GitLab