From a6155135690433988aa0cbf22f260f52a235e9f5 Mon Sep 17 00:00:00 2001
From: wangzhenhua <wangzhenhua@huawei.com>
Date: Tue, 10 Jan 2017 22:34:44 -0800
Subject: [PATCH] [SPARK-19149][SQL] Unify two sets of statistics in LogicalPlan

## What changes were proposed in this pull request?

Currently we have two sets of statistics in LogicalPlan: the simple statistics and the statistics estimated by CBO. The computing logic and naming are quite confusing, so we need to unify these two sets of stats. (A usage sketch of the unified API is appended after the patch.)

## How was this patch tested?

Modified existing tests.

Author: wangzhenhua <wangzhenhua@huawei.com>
Author: Zhenhua Wang <wzh_zju@163.com>

Closes #16529 from wzhfy/unifyStats.
---
 .../sql/catalyst/optimizer/Optimizer.scala     |  2 +-
 .../plans/logical/LocalRelation.scala          |  6 +-
 .../catalyst/plans/logical/LogicalPlan.scala   | 48 +++++------
 .../plans/logical/basicLogicalOperators.scala  | 82 ++++++++++---------
 .../statsEstimation/AggregateEstimation.scala  |  7 +-
 .../statsEstimation/EstimationUtils.scala      |  5 +-
 .../statsEstimation/ProjectEstimation.scala    |  7 +-
 .../analysis/DecimalPrecisionSuite.scala       |  1 -
 .../SubstituteUnresolvedOrdinalsSuite.scala    |  1 -
 .../optimizer/AggregateOptimizeSuite.scala     |  2 +-
 .../optimizer/EliminateSortsSuite.scala        |  2 +-
 .../optimizer/JoinOptimizationSuite.scala      |  2 +-
 .../optimizer/LimitPushdownSuite.scala         |  8 +-
 .../RewriteDistinctAggregatesSuite.scala       |  2 +-
 .../spark/sql/catalyst/plans/PlanTest.scala    |  4 +
 .../statsEstimation/AggEstimationSuite.scala   |  2 +-
 .../ProjectEstimationSuite.scala               |  2 +-
 .../statsEstimation/StatsConfSuite.scala}      | 29 +++----
 .../StatsEstimationTestBase.scala              |  6 +-
 .../spark/sql/execution/ExistingRDD.scala      |  6 +-
 .../spark/sql/execution/SparkStrategies.scala  | 12 +--
 .../execution/columnar/InMemoryRelation.scala  |  4 +-
 .../datasources/LogicalRelation.scala          |  3 +-
 .../sql/execution/streaming/memory.scala       |  4 +-
 .../apache/spark/sql/CachedTableSuite.scala    |  2 +-
 .../org/apache/spark/sql/DatasetSuite.scala    |  2 +-
 .../org/apache/spark/sql/JoinSuite.scala       |  2 +-
 .../spark/sql/StatisticsCollectionSuite.scala  | 22 ++---
 .../columnar/InMemoryColumnarQuerySuite.scala  |  2 +-
 .../datasources/HadoopFsRelationSuite.scala    |  2 +-
 .../execution/streaming/MemorySinkSuite.scala  |  8 +-
 .../apache/spark/sql/test/SQLTestData.scala    |  3 +
 .../spark/sql/test/SharedSQLContext.scala      |  1 +
 .../spark/sql/hive/HiveMetastoreCatalog.scala  |  3 +-
 .../spark/sql/hive/MetastoreRelation.scala     |  3 +-
 .../spark/sql/hive/StatisticsSuite.scala       | 10 +--
 36 files changed, 161 insertions(+), 146 deletions(-)
 rename sql/{core/src/test/scala/org/apache/spark/sql/statsEstimation/StatsEstimationSuite.scala => catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala} (73%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d1f90e6a1a..cef17b8d25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -253,7 +253,7 @@ case class LimitPushDown(conf: CatalystConf) extends Rule[LogicalPlan] {
       case FullOuter =>
         (left.maxRows, right.maxRows) match {
           case (None, None) =>
-            if (left.planStats(conf).sizeInBytes >= right.planStats(conf).sizeInBytes) {
+            if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
               join.copy(left = maybePushLimit(exp, left))
             } else {
               join.copy(right = maybePushLimit(exp, right))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 91633f5124..1faabcfcb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -74,9 +74,9 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
     }
   }
 
-  override lazy val statistics =
+  override def computeStats(conf: CatalystConf): Statistics =
     Statistics(sizeInBytes =
-      (output.map(n => BigInt(n.dataType.defaultSize))).sum * data.length)
+      output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
 
   def toSQL(inlineTableName: String): String = {
     require(data.nonEmpty)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 4f634cb29d..9e5ba9ca8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -81,6 +81,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
     }
   }
 
+  /** A cache for the estimated statistics, such that it will only be computed once. */
+  private val statsCache = new ThreadLocal[Option[Statistics]] {
+    override protected def initialValue: Option[Statistics] = None
+  }
+
+  def stats(conf: CatalystConf): Statistics = statsCache.get.getOrElse {
+    statsCache.set(Some(computeStats(conf)))
+    statsCache.get.get
+  }
+
+  def invalidateStatsCache(): Unit = {
+    statsCache.set(None)
+    children.foreach(_.invalidateStatsCache())
+  }
+
   /**
    * Computes [[Statistics]] for this plan. The default implementation assumes the output
    * cardinality is the product of all child plan's cardinality, i.e. applies in the case
@@ -88,36 +103,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    *
    * [[LeafNode]]s must override this.
    */
-  def statistics: Statistics = {
+  protected def computeStats(conf: CatalystConf): Statistics = {
     if (children.isEmpty) {
       throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
     }
-    Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
+    Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
   }
 
-  /**
-   * Returns the default statistics or statistics estimated by cbo based on configuration.
-   */
-  final def planStats(conf: CatalystConf): Statistics = {
-    if (conf.cboEnabled) {
-      if (estimatedStats.isEmpty) {
-        estimatedStats = Some(cboStatistics(conf))
-      }
-      estimatedStats.get
-    } else {
-      statistics
-    }
-  }
-
-  /**
-   * Returns statistics estimated by cbo. If the plan doesn't override this, it returns the
-   * default statistics.
- */ - protected def cboStatistics(conf: CatalystConf): Statistics = statistics - - /** A cache for the estimated statistics, such that it will only be computed once. */ - private var estimatedStats: Option[Statistics] = None - /** * Returns the maximum number of rows that this plan may compute. * @@ -334,20 +326,20 @@ abstract class UnaryNode extends LogicalPlan { override protected def validConstraints: Set[Expression] = child.constraints - override def statistics: Statistics = { + override def computeStats(conf: CatalystConf): Statistics = { // There should be some overhead in Row object, the size should not be zero when there is // no columns, this help to prevent divide-by-zero error. val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8 val outputRowSize = output.map(_.dataType.defaultSize).sum + 8 // Assume there will be the same number of rows as child has. - var sizeInBytes = (child.statistics.sizeInBytes * outputRowSize) / childRowSize + var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize if (sizeInBytes == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). sizeInBytes = 1 } - child.statistics.copy(sizeInBytes = sizeInBytes) + child.stats(conf).copy(sizeInBytes = sizeInBytes) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b97c81ce01..9bdae5eac3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{CatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -55,8 +55,13 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend override def validConstraints: Set[Expression] = child.constraints.union(getAliasedConstraints(projectList)) - override lazy val statistics: Statistics = - ProjectEstimation.estimate(this).getOrElse(super.statistics) + override def computeStats(conf: CatalystConf): Statistics = { + if (conf.cboEnabled) { + ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf)) + } else { + super.computeStats(conf) + } + } } /** @@ -162,11 +167,11 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation } } - override lazy val statistics: Statistics = { - val leftSize = left.statistics.sizeInBytes - val rightSize = right.statistics.sizeInBytes + override def computeStats(conf: CatalystConf): Statistics = { + val leftSize = left.stats(conf).sizeInBytes + val rightSize = right.stats(conf).sizeInBytes val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize - val isBroadcastable = left.statistics.isBroadcastable || right.statistics.isBroadcastable + val isBroadcastable = left.stats(conf).isBroadcastable || right.stats(conf).isBroadcastable Statistics(sizeInBytes = sizeInBytes, isBroadcastable = isBroadcastable) } @@ -179,8 +184,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le override protected 
def validConstraints: Set[Expression] = leftConstraints - override lazy val statistics: Statistics = { - left.statistics.copy() + override def computeStats(conf: CatalystConf): Statistics = { + left.stats(conf).copy() } } @@ -218,8 +223,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { children.length > 1 && childrenResolved && allChildrenCompatible } - override lazy val statistics: Statistics = { - val sizeInBytes = children.map(_.statistics.sizeInBytes).sum + override def computeStats(conf: CatalystConf): Statistics = { + val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum Statistics(sizeInBytes = sizeInBytes) } @@ -327,14 +332,14 @@ case class Join( case _ => resolvedExceptNatural } - override lazy val statistics: Statistics = joinType match { + override def computeStats(conf: CatalystConf): Statistics = joinType match { case LeftAnti | LeftSemi => // LeftSemi and LeftAnti won't ever be bigger than left - left.statistics.copy() + left.stats(conf).copy() case _ => // make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - super.statistics.copy(isBroadcastable = false) + super.computeStats(conf).copy(isBroadcastable = false) } } @@ -345,7 +350,8 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output // set isBroadcastable to true so the child will be broadcasted - override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true) + override def computeStats(conf: CatalystConf): Statistics = + super.computeStats(conf).copy(isBroadcastable = true) } /** @@ -462,7 +468,7 @@ case class Range( override def newInstance(): Range = copy(output = output.map(_.newInstance())) - override lazy val statistics: Statistics = { + override def computeStats(conf: CatalystConf): Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) } @@ -495,11 +501,19 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override lazy val statistics: Statistics = AggregateEstimation.estimate(this).getOrElse { - if (groupingExpressions.isEmpty) { - super.statistics.copy(sizeInBytes = 1) + override def computeStats(conf: CatalystConf): Statistics = { + def simpleEstimation: Statistics = { + if (groupingExpressions.isEmpty) { + super.computeStats(conf).copy(sizeInBytes = 1) + } else { + super.computeStats(conf) + } + } + + if (conf.cboEnabled) { + AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation) } else { - super.statistics + simpleEstimation } } } @@ -600,8 +614,8 @@ case class Expand( override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - override lazy val statistics: Statistics = { - val sizeInBytes = super.statistics.sizeInBytes * projections.length + override def computeStats(conf: CatalystConf): Statistics = { + val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } @@ -671,7 +685,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN case _ => None } } - override lazy val statistics: Statistics = { + override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero @@ -680,7 +694,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } else 
{ (limit: Long) * output.map(a => a.dataType.defaultSize).sum } - child.statistics.copy(sizeInBytes = sizeInBytes) + child.stats(conf).copy(sizeInBytes = sizeInBytes) } } @@ -692,7 +706,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case _ => None } } - override lazy val statistics: Statistics = { + override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero @@ -701,7 +715,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } else { (limit: Long) * output.map(a => a.dataType.defaultSize).sum } - child.statistics.copy(sizeInBytes = sizeInBytes) + child.stats(conf).copy(sizeInBytes = sizeInBytes) } } @@ -735,14 +749,14 @@ case class Sample( override def output: Seq[Attribute] = child.output - override lazy val statistics: Statistics = { + override def computeStats(conf: CatalystConf): Statistics = { val ratio = upperBound - lowerBound // BigInt can't multiply with Double - var sizeInBytes = child.statistics.sizeInBytes * (ratio * 100).toInt / 100 + var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100 if (sizeInBytes == 0) { sizeInBytes = 1 } - child.statistics.copy(sizeInBytes = sizeInBytes) + child.stats(conf).copy(sizeInBytes = sizeInBytes) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil @@ -796,13 +810,5 @@ case class RepartitionByExpression( case object OneRowRelation extends LeafNode { override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil - - /** - * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of all child plan's cardinality, i.e. applies in the case - * of cartesian joins. - * - * [[LeafNode]]s must override this. - */ - override lazy val statistics: Statistics = Statistics(sizeInBytes = 1) + override def computeStats(conf: CatalystConf): Statistics = Statistics(sizeInBytes = 1) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 33ebc380d2..af673430c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} @@ -28,13 +29,13 @@ object AggregateEstimation { * Estimate the number of output rows based on column stats of group-by columns, and propagate * column stats for aggregate expressions. */ - def estimate(agg: Aggregate): Option[Statistics] = { - val childStats = agg.child.statistics + def estimate(conf: CatalystConf, agg: Aggregate): Option[Statistics] = { + val childStats = agg.child.stats(conf) // Check if we have column stats for all group-by columns. 
val colStatsExist = agg.groupingExpressions.forall { e => e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) } - if (rowCountsExist(agg.child) && colStatsExist) { + if (rowCountsExist(conf, agg.child) && colStatsExist) { // Multiply distinct counts of group-by columns. This is an upper bound, which assumes // the data contains all combinations of distinct values of group-by columns. var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index f099e32267..c7eb6f0d7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.types.StringType @@ -25,8 +26,8 @@ import org.apache.spark.sql.types.StringType object EstimationUtils { /** Check if each plan has rowCount in its statistics. */ - def rowCountsExist(plans: LogicalPlan*): Boolean = - plans.forall(_.statistics.rowCount.isDefined) + def rowCountsExist(conf: CatalystConf, plans: LogicalPlan*): Boolean = + plans.forall(_.stats(conf).rowCount.isDefined) /** Get column stats for output attributes. */ def getOutputMap(inputMap: AttributeMap[ColumnStat], output: Seq[Attribute]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index 6d63b09fd4..69c546b01b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -17,15 +17,16 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics} object ProjectEstimation { import EstimationUtils._ - def estimate(project: Project): Option[Statistics] = { - if (rowCountsExist(project.child)) { - val childStats = project.child.statistics + def estimate(conf: CatalystConf, project: Project): Option[Statistics] = { + if (rowCountsExist(conf, project.child)) { + val childStats = project.child.stats(conf) val inputAttrStats = childStats.attributeStats // Match alias with its child's column stat val aliasStats = project.expressions.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 66d9b4c8e3..6995faebfa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.types._ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { - private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) private val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index 3c429ebce1..88f68ebadc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.SimpleCatalystConf class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { - private lazy val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) private lazy val a = testRelation2.output(0) private lazy val b = testRelation2.output(1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index aecf59aee6..b45bd977cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor class AggregateOptimizeSuite extends PlanTest { - val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 7402918c1b..c5f9cc1852 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ class EliminateSortsSuite extends PlanTest { - val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) + override val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 087718b3ec..65dd6225ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -143,7 +143,7 @@ class JoinOptimizationSuite extends PlanTest { comparePlans(optimized, expected) val broadcastChildren = optimized.collect { - case Join(_, r, _, _) if r.statistics.sizeInBytes == 1 => r + case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r } assert(broadcastChildren.size == 1) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 9ec99835c6..0f3ba6c895 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -33,7 +33,7 @@ class LimitPushdownSuite extends PlanTest { Batch("Subqueries", Once, EliminateSubqueryAliases) :: Batch("Limit pushdown", FixedPoint(100), - LimitPushDown(SimpleCatalystConf(caseSensitiveAnalysis = true)), + LimitPushDown(conf), CombineLimits, ConstantFolding, BooleanSimplification) :: Nil @@ -111,7 +111,7 @@ class LimitPushdownSuite extends PlanTest { } test("full outer join where neither side is limited and both sides have same statistics") { - assert(x.statistics.sizeInBytes === y.statistics.sizeInBytes) + assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes) val originalQuery = x.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze @@ -120,7 +120,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) - assert(xBig.statistics.sizeInBytes > y.statistics.sizeInBytes) + assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes) val originalQuery = xBig.join(y, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze @@ -129,7 +129,7 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) - assert(x.statistics.sizeInBytes < yBig.statistics.sizeInBytes) + assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes) val originalQuery = x.join(yBig, FullOuter).limit(1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 5c1faaecdb..350a1c26fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRela import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { - val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + override val conf = SimpleCatalystConf(caseSensitiveAnalysis 
= false, groupByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 64e268703b..3b7e5e938a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ @@ -27,6 +28,9 @@ import org.apache.spark.sql.catalyst.util._ * Provides helper methods for comparing plans. */ abstract class PlanTest extends SparkFunSuite with PredicateHelper { + + protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) + /** * Since attribute references are given globally unique ids during analysis, * we must normalize them to check if two different queries are identical. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala index 42ce2f8c5e..ff79122e39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala @@ -130,6 +130,6 @@ class AggEstimationSuite extends StatsEstimationTestBase { rowCount = Some(expectedRowCount), attributeStats = expectedAttrStats) - assert(testAgg.statistics == expectedStats) + assert(testAgg.stats(conf) == expectedStats) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index 4a1bed84f8..a613f0f5d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -46,6 +46,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { sizeInBytes = 2 * getRowSize(project.output, expectedAttrStats), rowCount = Some(2), attributeStats = expectedAttrStats) - assert(project.statistics == expectedStats) + assert(project.stats(conf) == expectedStats) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/statsEstimation/StatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala similarity index 73% rename from sql/core/src/test/scala/org/apache/spark/sql/statsEstimation/StatsEstimationSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala index 78f2ce1d57..212d57a9bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/statsEstimation/StatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala @@ -15,17 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.statsEstimation +package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.IntegerType -class StatsEstimationSuite extends SharedSQLContext { - test("statistics for a plan based on the cbo switch") { +class StatsConfSuite extends StatsEstimationTestBase { + test("estimate statistics when the conf changes") { val expectedDefaultStats = Statistics( sizeInBytes = 40, @@ -42,26 +41,24 @@ class StatsEstimationSuite extends SharedSQLContext { isBroadcastable = false) val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - withSQLConf("spark.sql.cbo.enabled" -> "true") { - // Use the statistics estimated by cbo - assert(plan.planStats(spark.sessionState.conf) == expectedCboStats) - } - withSQLConf("spark.sql.cbo.enabled" -> "false") { - // Use the default statistics - assert(plan.planStats(spark.sessionState.conf) == expectedDefaultStats) - } + // Return the statistics estimated by cbo + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats) + // Invalidate statistics + plan.invalidateStatsCache() + // Return the simple statistics + assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats) } } /** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which has both - * default statistics and cbo estimated statistics. + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. */ private case class DummyLogicalPlan( defaultStats: Statistics, cboStats: Statistics) extends LogicalPlan { - override lazy val statistics = defaultStats - override def cboStatistics(conf: CatalystConf): Statistics = cboStats override def output: Seq[Attribute] = Nil override def children: Seq[LogicalPlan] = Nil + override def computeStats(conf: CatalystConf): Statistics = + if (conf.cboEnabled) cboStats else defaultStats } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index 0d81aa3f68..0635309cd9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.types.IntegerType @@ -25,6 +26,9 @@ import org.apache.spark.sql.types.IntegerType class StatsEstimationTestBase extends SparkFunSuite { + /** Enable stats estimation based on CBO. 
*/ + protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) + def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */ @@ -40,5 +44,5 @@ class StatsEstimationTestBase extends SparkFunSuite { */ protected case class StatsTestPlan(outputList: Seq[Attribute], stats: Statistics) extends LeafNode { override def output: Seq[Attribute] = outputList - override lazy val statistics = stats + override def computeStats(conf: CatalystConf): Statistics = stats } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index aab087cd98..49336f4248 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Encoder, Row, SparkSession} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -95,7 +95,7 @@ case class ExternalRDD[T]( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override lazy val statistics: Statistics = Statistics( + @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) @@ -170,7 +170,7 @@ case class LogicalRDD( override protected def stringArgs: Iterator[Any] = Iterator(output) - @transient override lazy val statistics: Statistics = Statistics( + @transient override def computeStats(conf: CatalystConf): Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 1257d1728c..fafb919670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -114,9 +114,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Matches a plan whose output should be small enough to be used in broadcast join. */ private def canBroadcast(plan: LogicalPlan): Boolean = { - plan.planStats(conf).isBroadcastable || - (plan.planStats(conf).sizeInBytes >= 0 && - plan.planStats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) + plan.stats(conf).isBroadcastable || + (plan.stats(conf).sizeInBytes >= 0 && + plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold) } /** @@ -126,7 +126,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * dynamic. 
*/ private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = { - plan.planStats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions + plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions } /** @@ -137,7 +137,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * use the size of bytes here as estimation. */ private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = { - a.planStats(conf).sizeInBytes * 3 <= b.planStats(conf).sizeInBytes + a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes } private def canBuildRight(joinType: JoinType): Boolean = joinType match { @@ -206,7 +206,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Join(left, right, joinType, condition) => val buildSide = - if (right.planStats(conf).sizeInBytes <= left.planStats(conf).sizeInBytes) { + if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) { BuildRight } else { BuildLeft diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 03cc04659b..37bd95e737 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -21,7 +21,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical @@ -69,7 +69,7 @@ case class InMemoryRelation( @transient val partitionStatistics = new PartitionStatistics(output) - override lazy val statistics: Statistics = { + override def computeStats(conf: CatalystConf): Statistics = { if (batchStats.value == 0L) { // Underlying columnar RDD hasn't been materialized, no useful statistics information // available, return the default statistics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 3fd40384d2..04a764bee2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} @@ -72,7 +73,7 @@ case class LogicalRelation( // expId can be different but the relation is still the same. 
override lazy val cleanArgs: Seq[Any] = Seq(relation) - @transient override lazy val statistics: Statistics = { + @transient override def computeStats(conf: CatalystConf): Statistics = { catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse( Statistics(sizeInBytes = relation.sizeInBytes)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 91da6b3846..6d34d51d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,6 +25,7 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} @@ -229,5 +230,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum - override def statistics: Statistics = Statistics(sizePerRow * sink.allData.size) + override def computeStats(conf: CatalystConf): Statistics = + Statistics(sizePerRow * sink.allData.size) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index fb4812adf1..339262a5a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -305,7 +305,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum - assert(cached.statistics.sizeInBytes === actualSizeInBytes) + assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c27b815dfa..731a28c237 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1115,7 +1115,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { // instead of Int for avoiding possible overflow. 
val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.statistics.sizeInBytes + val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 913b2ae976..f780fc0ec0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -32,7 +32,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { - df.queryExecution.optimizedPlan.statistics.sizeInBytes + df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes } test("equi-join is hash-join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 18abb18587..bd1ce8aa3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -59,7 +59,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = df1.join(df2, Seq("k"), "left") val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.statistics.sizeInBytes + g.stats(conf).sizeInBytes } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") @@ -106,9 +106,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.statistics.sizeInBytes > + assert(df.queryExecution.analyzed.stats(conf).sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.statistics.sizeInBytes > + assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) } @@ -120,14 +120,14 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = sql(s"""SELECT * FROM test limit $limit""") val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.statistics.sizeInBytes + g.stats(conf).sizeInBytes } assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizesGlobalLimit.head === BigInt(expected), s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.statistics.sizeInBytes + l.stats(conf).sizeInBytes } assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") assert(sizesLocalLimit.head === BigInt(expected), @@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils test("SPARK-18856: non-empty partitioned table should not report zero size") { withTable("ds_tbl", "hive_tbl") { spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.statistics + val stats = 
spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf) assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.statistics + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf) assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") } } @@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat)) // Check relation statistics - assert(relation.statistics.sizeInBytes == 0) - assert(relation.statistics.rowCount == Some(0)) - assert(relation.statistics.attributeStats.size == 1) - val (attribute, colStat) = relation.statistics.attributeStats.head + assert(relation.stats(conf).sizeInBytes == 0) + assert(relation.stats(conf).rowCount == Some(0)) + assert(relation.stats(conf).attributeStats.size == 1) + val (attribute, colStat) = relation.stats(conf).attributeStats.head assert(attribute.name == "c1") assert(colStat == emptyColStat) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index afeb47828e..f355a5200c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -123,7 +123,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { .toDF().createOrReplaceTempView("sizeTst") spark.catalog.cacheTable("sizeTst") assert( - spark.table("sizeTst").queryExecution.analyzed.statistics.sizeInBytes > + spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes > spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala index 89d57653ad..7679e854cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala @@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext { }) val totalSize = allFiles.map(_.length()).sum val df = spark.read.parquet(dir.toString) - assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize)) + assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 8f23f98f76..24a7b7740f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -216,13 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { // Before adding data, check output checkAnswer(sink.allData, Seq.empty) - assert(plan.statistics.sizeInBytes === 0) + 
assert(plan.stats(sqlConf).sizeInBytes === 0) sink.addBatch(0, 1 to 3) - assert(plan.statistics.sizeInBytes === 12) + plan.invalidateStatsCache() + assert(plan.stats(sqlConf).sizeInBytes === 12) sink.addBatch(1, 4 to 6) - assert(plan.statistics.sizeInBytes === 24) + plan.invalidateStatsCache() + assert(plan.stats(sqlConf).sizeInBytes === 24) } ignore("stress test") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 0cfe260e52..f9b3ff8405 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} +import org.apache.spark.sql.internal.SQLConf /** * A collection of sample data used in SQL tests. @@ -28,6 +29,8 @@ import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} private[sql] trait SQLTestData { self => protected def spark: SparkSession + protected def sqlConf: SQLConf = spark.sessionState.conf + // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { protected override def _sqlContext: SQLContext = self.spark.sqlContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 2239f10870..36dc2368fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0407cf6a1e..ee4589f855 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -232,7 +232,8 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log Some(partitionSchema)) val logicalRelation = cached.getOrElse { - val sizeInBytes = metastoreRelation.statistics.sizeInBytes.toLong + val sizeInBytes = + metastoreRelation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong val fileCatalog = { val catalog = new CatalogFileIndex( sparkSession, metastoreRelation.catalogTable, sizeInBytes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala index 2e60cba09d..7254f73f41 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/MetastoreRelation.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.hive.ql.metadata.{Partition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog._ import 
org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference, Expression}
@@ -112,7 +113,7 @@ private[hive] case class MetastoreRelation(
     new HiveTable(tTable)
   }
 
-  @transient override lazy val statistics: Statistics = {
+  @transient override def computeStats(conf: CatalystConf): Statistics = {
     catalogTable.stats.map(_.toPlanStats(output)).getOrElse(Statistics(
       sizeInBytes = {
         val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index b040f26d28..0053aa1642 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -69,7 +69,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       assert(properties.get("totalSize").toLong <= 0, "external table totalSize must be <= 0")
       assert(properties.get("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0")
 
-      val sizeInBytes = relation.statistics.sizeInBytes
+      val sizeInBytes = relation.stats(conf).sizeInBytes
       assert(sizeInBytes === BigInt(file1.length() + file2.length()))
     }
   } finally {
@@ -80,7 +80,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
 
   test("analyze MetastoreRelations") {
     def queryTotalSize(tableName: String): BigInt =
-      spark.sessionState.catalog.lookupRelation(TableIdentifier(tableName)).statistics.sizeInBytes
+      spark.sessionState.catalog.lookupRelation(TableIdentifier(tableName)).stats(conf).sizeInBytes
 
     // Non-partitioned table
     sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
@@ -481,7 +481,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
   test("estimates the size of a test MetastoreRelation") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
-      mr.statistics.sizeInBytes
+      mr.stats(conf).sizeInBytes
     }
     assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
     assert(sizes(0).equals(BigInt(5812)),
@@ -501,7 +501,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
 
       // Assert src has a size smaller than the threshold.
       val sizes = df.queryExecution.analyzed.collect {
-        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
+        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes
       }
       assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold &&
         sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
@@ -557,7 +557,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       val sizes = df.queryExecution.analyzed.collect {
         case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass
           .isAssignableFrom(r.getClass) =>
-          r.statistics.sizeInBytes
+          r.stats(conf).sizeInBytes
       }
       assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold &&
         sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
--
GitLab
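For readers skimming the diff, the sketch below illustrates how the unified API behaves after this change. It is a minimal, illustrative example closely modeled on the `DummyLogicalPlan` test added in `StatsConfSuite`; `DummyStatsPlan`, `StatsUnificationSketch`, and the literal statistics values are hypothetical names chosen for illustration, and the Catalyst classes referenced (`LogicalPlan`, `Statistics`, `CatalystConf`, `SimpleCatalystConf`) are internal APIs as of this patch, not a stable public interface.

```scala
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}

// Hypothetical leaf plan, modeled on DummyLogicalPlan in StatsConfSuite: it computes
// different statistics depending on whether CBO is enabled in the passed conf.
case class DummyStatsPlan(simpleStats: Statistics, cboStats: Statistics) extends LogicalPlan {
  override def output: Seq[Attribute] = Nil
  override def children: Seq[LogicalPlan] = Nil
  // After this patch, computeStats(conf) is the single hook a node implements;
  // callers always go through stats(conf), which caches the result.
  override def computeStats(conf: CatalystConf): Statistics =
    if (conf.cboEnabled) cboStats else simpleStats
}

object StatsUnificationSketch {
  def main(args: Array[String]): Unit = {
    val simpleStats = Statistics(sizeInBytes = 40)
    val cboStats = Statistics(sizeInBytes = 4, rowCount = Some(1))
    val plan = DummyStatsPlan(simpleStats, cboStats)
    val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true)

    // The first call computes and caches the statistics.
    assert(plan.stats(conf) == cboStats)
    // The cached value is returned even if a different conf is passed later ...
    assert(plan.stats(conf.copy(cboEnabled = false)) == cboStats)
    // ... until the cache is explicitly invalidated.
    plan.invalidateStatsCache()
    assert(plan.stats(conf.copy(cboEnabled = false)) == simpleStats)
  }
}
```

The cache-plus-explicit-invalidation design means `stats(conf)` can be called repeatedly during planning without re-estimation; the trade-off, visible in the `MemorySinkSuite` changes above, is that a plan whose underlying data changes must call `invalidateStatsCache()` before its statistics are recomputed.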