diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index c043ed9c431b70f55a14e14cee6e447f121457db..b63bef9193332c489523bcf868a6a0fa9009d152 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -436,7 +435,7 @@ case class CatalogRelation(
       createTime = -1
     ))
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     // For data source tables, we will create a `LogicalRelation` and won't call this method, for
     // hive serde tables, we will always generate a statistics.
     // TODO: unify the table stats generation.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 51eca6ca33760f9fd22ae01771242f5d7940ed53..3a7543e2141e999ca1f2fdbc277f8b313cafe130 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -58,7 +58,7 @@ case class CostBasedJoinReorder(conf: SQLConf) extends Rule[LogicalPlan] with Pr
     // Do reordering if the number of items is appropriate and join conditions exist.
     // We also need to check if costs of all items can be evaluated.
     if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty &&
-        items.forall(_.stats(conf).rowCount.isDefined)) {
+        items.forall(_.stats.rowCount.isDefined)) {
       JoinReorderDP.search(conf, items, conditions, output)
     } else {
       plan
@@ -322,7 +322,7 @@ object JoinReorderDP extends PredicateHelper with Logging {
     /** Get the cost of the root node of this plan tree. */
     def rootCost(conf: SQLConf): Cost = {
       if (itemIds.size > 1) {
-        val rootStats = plan.stats(conf)
+        val rootStats = plan.stats
         Cost(rootStats.rowCount.get, rootStats.sizeInBytes)
       } else {
         // If the plan is a leaf item, it has zero cost.
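Reviewer note: the pattern being migrated here is a parameterless `stats` that memoizes `computeStats` (see the LogicalPlan.scala hunks later in this diff). A minimal standalone sketch of that caching discipline, with illustrative names only (not the Spark source):

```scala
// Sketch: memoized statistics with explicit invalidation.
case class Statistics(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

abstract class PlanNode {
  private var statsCache: Option[Statistics] = None

  // Parameterless accessor: computes once, then serves the cached value.
  final def stats: Statistics = statsCache.getOrElse {
    statsCache = Some(computeStats)
    statsCache.get
  }

  // Called when configuration changes and the cached stats may be stale.
  final def invalidateStatsCache(): Unit = { statsCache = None }

  protected def computeStats: Statistics
}
```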
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3ab70fb90470c35ec973c4132b26f94d462a38ec..b410312030c5dc204b5d2f3a00ca3eefbe0f29f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -317,7 +317,7 @@ case class LimitPushDown(conf: SQLConf) extends Rule[LogicalPlan] {
         case FullOuter =>
           (left.maxRows, right.maxRows) match {
             case (None, None) =>
-              if (left.stats(conf).sizeInBytes >= right.stats(conf).sizeInBytes) {
+              if (left.stats.sizeInBytes >= right.stats.sizeInBytes) {
                 join.copy(left = maybePushLimit(exp, left))
               } else {
                 join.copy(right = maybePushLimit(exp, right))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala
index 97ee9988386dd5d1c700f4f5274724566039e1f0..ca729127e7d1d3c093a849ba7079fba216e27c73 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/StarSchemaDetection.scala
@@ -82,7 +82,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
       // Find if the input plans are eligible for star join detection.
       // An eligible plan is a base table access with valid statistics.
       val foundEligibleJoin = input.forall {
-        case PhysicalOperation(_, _, t: LeafNode) if t.stats(conf).rowCount.isDefined => true
+        case PhysicalOperation(_, _, t: LeafNode) if t.stats.rowCount.isDefined => true
         case _ => false
       }
@@ -181,7 +181,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
         val leafCol = findLeafNodeCol(column, plan)
         leafCol match {
           case Some(col) if t.outputSet.contains(col) =>
-            val stats = t.stats(conf)
+            val stats = t.stats
             stats.rowCount match {
               case Some(rowCount) if rowCount >= 0 =>
                 if (stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)) {
@@ -237,7 +237,7 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
         val leafCol = findLeafNodeCol(column, plan)
         leafCol match {
           case Some(col) if t.outputSet.contains(col) =>
-            val stats = t.stats(conf)
+            val stats = t.stats
             stats.attributeStats.nonEmpty && stats.attributeStats.contains(col)
           case None => false
         }
@@ -296,11 +296,11 @@ case class StarSchemaDetection(conf: SQLConf) extends PredicateHelper {
    */
   private def getTableAccessCardinality(
       input: LogicalPlan): Option[BigInt] = input match {
-    case PhysicalOperation(_, cond, t: LeafNode) if t.stats(conf).rowCount.isDefined =>
-      if (conf.cboEnabled && input.stats(conf).rowCount.isDefined) {
-        Option(input.stats(conf).rowCount.get)
+    case PhysicalOperation(_, cond, t: LeafNode) if t.stats.rowCount.isDefined =>
+      if (conf.cboEnabled && input.stats.rowCount.isDefined) {
+        Option(input.stats.rowCount.get)
       } else {
-        Option(t.stats(conf).rowCount.get)
+        Option(t.stats.rowCount.get)
       }
     case _ => None
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 9cd5dfd21b1607f201644f24768e0fc88085fb34..dc2add64b68b75ee0cedfcc1f1fec334f8e473e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -21,7 +21,6 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StructField, StructType}
 
 object LocalRelation {
@@ -67,7 +66,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
     }
   }
 
-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
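Aside: the `LocalRelation` estimate above is pure arithmetic — row width is the sum of the column types' default sizes, total size is width times row count. A REPL-style sketch (the 4/8 widths match Spark's defaults for Int/Long, everything else here is illustrative):

```scala
// Sketch of the LocalRelation size estimate.
val columnDefaultSizes = Seq(4, 8)  // e.g. one IntegerType and one LongType column
val rowCount = 1000
val sizeInBytes = BigInt(columnDefaultSizes.sum) * rowCount
assert(sizeInBytes == BigInt(12000))
```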
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 95b4165f6b10d3bcef79f65c537b6c7b05d601ad..0c098ac0209e81766fe0d166159d922081f19428 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -90,8 +89,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
    * first time. If the configuration changes, the cache can be invalidated by calling
    * [[invalidateStatsCache()]].
    */
-  final def stats(conf: SQLConf): Statistics = statsCache.getOrElse {
-    statsCache = Some(computeStats(conf))
+  final def stats: Statistics = statsCache.getOrElse {
+    statsCache = Some(computeStats)
     statsCache.get
   }
@@ -108,11 +107,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with QueryPlanConstrai
    *
    * [[LeafNode]]s must override this.
    */
-  protected def computeStats(conf: SQLConf): Statistics = {
+  protected def computeStats: Statistics = {
     if (children.isEmpty) {
       throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
     }
-    Statistics(sizeInBytes = children.map(_.stats(conf).sizeInBytes).product)
+    Statistics(sizeInBytes = children.map(_.stats.sizeInBytes).product)
   }
 
   override def verboseStringWithSuffix: String = {
@@ -333,13 +332,13 @@ abstract class UnaryNode extends LogicalPlan {
 
   override protected def validConstraints: Set[Expression] = child.constraints
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     // There should be some overhead in Row object, the size should not be zero when there is
     // no columns, this help to prevent divide-by-zero error.
     val childRowSize = child.output.map(_.dataType.defaultSize).sum + 8
     val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
     // Assume there will be the same number of rows as child has.
-    var sizeInBytes = (child.stats(conf).sizeInBytes * outputRowSize) / childRowSize
+    var sizeInBytes = (child.stats.sizeInBytes * outputRowSize) / childRowSize
     if (sizeInBytes == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
       // (product of children).
@@ -347,7 +346,7 @@ abstract class UnaryNode extends LogicalPlan {
     }
 
     // Don't propagate rowCount and attributeStats, since they are not estimated here.
-    Statistics(sizeInBytes = sizeInBytes, hints = child.stats(conf).hints)
+    Statistics(sizeInBytes = sizeInBytes, hints = child.stats.hints)
   }
 }
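Aside: the `UnaryNode.computeStats` hunk above assumes the same row count as the child and rescales the size by the ratio of row widths. A self-contained sketch of just that arithmetic (illustrative names, not the Spark source):

```scala
// Sketch of the UnaryNode size scaling.
def scaleSize(childSize: BigInt, childColSizes: Seq[Int], outColSizes: Seq[Int]): BigInt = {
  val childRowSize = childColSizes.sum + 8   // +8: per-row overhead; also prevents divide-by-zero
  val outputRowSize = outColSizes.sum + 8
  val size = childSize * outputRowSize / childRowSize
  if (size == 0) BigInt(1) else size         // zero would zero out a BinaryNode's product
}
```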
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 6e88b7a57dc334e8f34cae30f2ae60dd2afe152a..d8f89b108e63fbc1c6866e92aedfd85518889f05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.RandomSampler
@@ -65,11 +64,11 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
   override def validConstraints: Set[Expression] =
     child.constraints.union(getAliasedConstraints(projectList))
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (conf.cboEnabled) {
-      ProjectEstimation.estimate(conf, this).getOrElse(super.computeStats(conf))
+      ProjectEstimation.estimate(this).getOrElse(super.computeStats)
     } else {
-      super.computeStats(conf)
+      super.computeStats
     }
   }
 }
@@ -139,11 +138,11 @@ case class Filter(condition: Expression, child: LogicalPlan)
     child.constraints.union(predicates.toSet)
   }
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (conf.cboEnabled) {
-      FilterEstimation(this, conf).estimate.getOrElse(super.computeStats(conf))
+      FilterEstimation(this).estimate.getOrElse(super.computeStats)
     } else {
-      super.computeStats(conf)
+      super.computeStats
    }
   }
 }
@@ -192,13 +191,13 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
     }
   }
 
-  override def computeStats(conf: SQLConf): Statistics = {
-    val leftSize = left.stats(conf).sizeInBytes
-    val rightSize = right.stats(conf).sizeInBytes
+  override def computeStats: Statistics = {
+    val leftSize = left.stats.sizeInBytes
+    val rightSize = right.stats.sizeInBytes
     val sizeInBytes = if (leftSize < rightSize) leftSize else rightSize
     Statistics(
       sizeInBytes = sizeInBytes,
-      hints = left.stats(conf).hints.resetForJoin())
+      hints = left.stats.hints.resetForJoin())
   }
 }
@@ -209,8 +208,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
 
   override protected def validConstraints: Set[Expression] = leftConstraints
 
-  override def computeStats(conf: SQLConf): Statistics = {
-    left.stats(conf).copy()
+  override def computeStats: Statistics = {
+    left.stats.copy()
   }
 }
@@ -248,8 +247,8 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
     children.length > 1 && childrenResolved && allChildrenCompatible
   }
 
-  override def computeStats(conf: SQLConf): Statistics = {
-    val sizeInBytes = children.map(_.stats(conf).sizeInBytes).sum
+  override def computeStats: Statistics = {
+    val sizeInBytes = children.map(_.stats.sizeInBytes).sum
     Statistics(sizeInBytes = sizeInBytes)
   }
@@ -357,20 +356,20 @@ case class Join(
     case _ => resolvedExceptNatural
   }
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     def simpleEstimation: Statistics = joinType match {
       case LeftAnti | LeftSemi =>
         // LeftSemi and LeftAnti won't ever be bigger than left
-        left.stats(conf)
+        left.stats
       case _ =>
         // Make sure we don't propagate isBroadcastable in other joins, because
         // they could explode the size.
-        val stats = super.computeStats(conf)
+        val stats = super.computeStats
         stats.copy(hints = stats.hints.resetForJoin())
     }
 
     if (conf.cboEnabled) {
-      JoinEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+      JoinEstimation.estimate(this).getOrElse(simpleEstimation)
     } else {
       simpleEstimation
     }
@@ -523,7 +522,7 @@ case class Range(
 
   override def newInstance(): Range = copy(output = output.map(_.newInstance()))
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val sizeInBytes = LongType.defaultSize * numElements
     Statistics( sizeInBytes = sizeInBytes )
   }
@@ -556,20 +555,20 @@ case class Aggregate(
     child.constraints.union(getAliasedConstraints(nonAgg))
   }
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
         Statistics(
           sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
           rowCount = Some(1),
-          hints = child.stats(conf).hints)
+          hints = child.stats.hints)
       } else {
-        super.computeStats(conf)
+        super.computeStats
       }
     }
 
     if (conf.cboEnabled) {
-      AggregateEstimation.estimate(conf, this).getOrElse(simpleEstimation)
+      AggregateEstimation.estimate(this).getOrElse(simpleEstimation)
     } else {
       simpleEstimation
     }
@@ -672,8 +671,8 @@ case class Expand(
 
   override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references))
 
-  override def computeStats(conf: SQLConf): Statistics = {
-    val sizeInBytes = super.computeStats(conf).sizeInBytes * projections.length
+  override def computeStats: Statistics = {
+    val sizeInBytes = super.computeStats.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
@@ -743,9 +742,9 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
       case _ => None
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
     // Don't propagate column stats, because we don't know the distribution after a limit operation
     Statistics(
@@ -763,9 +762,9 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
       case _ => None
     }
   }
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     if (limit == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
      // (product of children).
@@ -832,9 +831,9 @@ case class Sample(
 
   override def output: Seq[Attribute] = child.output
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     val ratio = upperBound - lowerBound
-    val childStats = child.stats(conf)
+    val childStats = child.stats
     var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
     if (sizeInBytes == 0) {
       sizeInBytes = 1
@@ -898,7 +897,7 @@ case class RepartitionByExpression(
 case object OneRowRelation extends LeafNode {
   override def maxRows: Option[Long] = Some(1)
   override def output: Seq[Attribute] = Nil
-  override def computeStats(conf: SQLConf): Statistics = Statistics(sizeInBytes = 1)
+  override def computeStats: Statistics = Statistics(sizeInBytes = 1)
 }
 
 /** A logical plan for `dropDuplicates`. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index e49970df804576b13fd84f3f823b3970c75cb769..8479c702d756167ef8b5ce094261f05a47099db5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A general hint for the child that is not yet resolved. This node is generated by the parser and
@@ -44,8 +43,8 @@ case class ResolvedHint(child: LogicalPlan, hints: HintInfo = HintInfo())
 
   override lazy val canonicalized: LogicalPlan = child.canonicalized
 
-  override def computeStats(conf: SQLConf): Statistics = {
-    val stats = child.stats(conf)
+  override def computeStats: Statistics = {
+    val stats = child.stats
     stats.copy(hints = hints)
   }
 }
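Aside: the `GlobalLimit` hunk above shows the limit estimate verbatim — the output row count is the child's row count capped by the limit, falling back to the limit itself when the child's count is unknown. Restated as a self-contained sketch:

```scala
// Sketch of the GlobalLimit row-count estimate.
def limitRowCount(limit: Int, childRowCount: Option[BigInt]): BigInt =
  childRowCount.map(_.min(BigInt(limit))).getOrElse(BigInt(limit))

assert(limitRowCount(2, Some(BigInt(10))) == BigInt(2))   // limit < child's rowCount
assert(limitRowCount(20, Some(BigInt(10))) == BigInt(10)) // limit > child's rowCount
assert(limitRowCount(5, None) == BigInt(5))               // child's rowCount unknown
```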
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index a0c23198451a84f900e99139b259630e3a176536..c41fac4015ec0661d35c65ed0dad6f48ad571052 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 
 object AggregateEstimation {
@@ -29,13 +28,13 @@ object AggregateEstimation {
    * Estimate the number of output rows based on column stats of group-by columns, and propagate
    * column stats for aggregate expressions.
    */
-  def estimate(conf: SQLConf, agg: Aggregate): Option[Statistics] = {
-    val childStats = agg.child.stats(conf)
+  def estimate(agg: Aggregate): Option[Statistics] = {
+    val childStats = agg.child.stats
     // Check if we have column stats for all group-by columns.
     val colStatsExist = agg.groupingExpressions.forall { e =>
       e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
     }
-    if (rowCountsExist(conf, agg.child) && colStatsExist) {
+    if (rowCountsExist(agg.child) && colStatsExist) {
       // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
       // the data contains all combinations of distinct values of group-by columns.
       var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index e5fcdf9039be9bf1ad7ae0f8e876d7685cbd46e9..9c34a9b7aa756dbedd6b89115e03c9ed912f0240 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -21,15 +21,14 @@ import scala.math.BigDecimal.RoundingMode
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, _}
 
 
 object EstimationUtils {
 
   /** Check if each plan has rowCount in its statistics. */
-  def rowCountsExist(conf: SQLConf, plans: LogicalPlan*): Boolean =
-    plans.forall(_.stats(conf).rowCount.isDefined)
+  def rowCountsExist(plans: LogicalPlan*): Boolean =
+    plans.forall(_.stats.rowCount.isDefined)
 
   /** Check if each attribute has column stat in the corresponding statistics. */
   def columnStatsExist(statsAndAttr: (Statistics, Attribute)*): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index df190867189ecaba197a3f32f7331cc6da277a8e..5a3bee7b9e4493df56ef357f1829587ca76160d6 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -25,12 +25,11 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
-case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging {
+case class FilterEstimation(plan: Filter) extends Logging {
 
-  private val childStats = plan.child.stats(catalystConf)
+  private val childStats = plan.child.stats
 
   private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 8ef905c45d50d0b1adf5a3943856a282f758c5f3..f48196997a24d382c7c06ea2c279f7fe4d9bc289 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Statistics}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
-import org.apache.spark.sql.internal.SQLConf
 
 object JoinEstimation extends Logging {
@@ -34,12 +33,12 @@ object JoinEstimation extends Logging {
    * Estimate statistics after join. Return `None` if the join type is not supported, or we don't
    * have enough statistics for estimation.
    */
-  def estimate(conf: SQLConf, join: Join): Option[Statistics] = {
+  def estimate(join: Join): Option[Statistics] = {
     join.joinType match {
       case Inner | Cross | LeftOuter | RightOuter | FullOuter =>
-        InnerOuterEstimation(conf, join).doEstimate()
+        InnerOuterEstimation(join).doEstimate()
       case LeftSemi | LeftAnti =>
-        LeftSemiAntiEstimation(conf, join).doEstimate()
+        LeftSemiAntiEstimation(join).doEstimate()
       case _ =>
         logDebug(s"[CBO] Unsupported join type: ${join.joinType}")
         None
@@ -47,16 +46,16 @@ }
 }
 
-case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
+case class InnerOuterEstimation(join: Join) extends Logging {
 
-  private val leftStats = join.left.stats(conf)
-  private val rightStats = join.right.stats(conf)
+  private val leftStats = join.left.stats
+  private val rightStats = join.right.stats
 
   /**
    * Estimate output size and number of rows after a join operator, and update output column stats.
   */
   def doEstimate(): Option[Statistics] = join match {
-    case _ if !rowCountsExist(conf, join.left, join.right) =>
+    case _ if !rowCountsExist(join.left, join.right) =>
       None
 
     case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) =>
@@ -273,13 +272,13 @@ case class InnerOuterEstimation(conf: SQLConf, join: Join) extends Logging {
   }
 }
 
-case class LeftSemiAntiEstimation(conf: SQLConf, join: Join) {
+case class LeftSemiAntiEstimation(join: Join) {
   def doEstimate(): Option[Statistics] = {
     // TODO: It's error-prone to estimate cardinalities for LeftSemi and LeftAnti based on basic
     // column stats. Now we just propagate the statistics from left side. We should do more
     // accurate estimation when advanced stats (e.g. histograms) are available.
-    if (rowCountsExist(conf, join.left)) {
-      val leftStats = join.left.stats(conf)
+    if (rowCountsExist(join.left)) {
+      val leftStats = join.left.stats
       // Propagate the original column stats for cartesian product
       val outputRows = leftStats.rowCount.get
       Some(Statistics(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index d700cd3b20f7db37103525caa83014e6b198bca1..489eb904ffd05a97862c0aa6e66bfd156e6b0783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap}
 import org.apache.spark.sql.catalyst.plans.logical.{Project, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 
 object ProjectEstimation {
   import EstimationUtils._
 
-  def estimate(conf: SQLConf, project: Project): Option[Statistics] = {
-    if (rowCountsExist(conf, project.child)) {
-      val childStats = project.child.stats(conf)
+  def estimate(project: Project): Option[Statistics] = {
+    if (rowCountsExist(project.child)) {
+      val childStats = project.child.stats
       val inputAttrStats = childStats.attributeStats
       // Match alias with its child's column stat
       val aliasStats = project.expressions.collect {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 105407d43bf3935e5f4c60df63d164e36a4c2a4d..a6584aa5fbba7850194d4543125ac993db10ec73 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -142,7 +142,7 @@ class JoinOptimizationSuite extends PlanTest {
     comparePlans(optimized, expected)
 
     val broadcastChildren = optimized.collect {
-      case Join(_, r, _, _) if r.stats(conf).sizeInBytes == 1 => r
+      case Join(_, r, _, _) if r.stats.sizeInBytes == 1 => r
     }
     assert(broadcastChildren.size == 1)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
index fb34c82de468b5ecade1953ac1cd71dd86bdf10c..d8302dfc9462dc8ad4e375c4c26a30f7a9ed8eda 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -112,7 +112,7 @@ class LimitPushdownSuite extends PlanTest {
   }
 
   test("full outer join where neither side is limited and both sides have same statistics") {
-    assert(x.stats(conf).sizeInBytes === y.stats(conf).sizeInBytes)
+    assert(x.stats.sizeInBytes === y.stats.sizeInBytes)
     val originalQuery = x.join(y, FullOuter).limit(1)
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze
@@ -121,7 +121,7 @@ class LimitPushdownSuite extends PlanTest {
 
   test("full outer join where neither side is limited and left side has larger statistics") {
     val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x)
-    assert(xBig.stats(conf).sizeInBytes > y.stats(conf).sizeInBytes)
+    assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes)
     val originalQuery = xBig.join(y, FullOuter).limit(1)
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze
@@ -130,7 +130,7 @@ class LimitPushdownSuite extends PlanTest {
 
   test("full outer join where neither side is limited and right side has larger statistics") {
     val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y)
-    assert(x.stats(conf).sizeInBytes < yBig.stats(conf).sizeInBytes)
+    assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes)
     val originalQuery = x.join(yBig, FullOuter).limit(1)
     val optimized = Optimize.execute(originalQuery.analyze)
     val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 38483a298cef04ee33953d1c76174f89c87a18be..30ddf03bd3c4f3bdfafda6c71ed4b5c34a71834f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -100,17 +100,23 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       size = Some(4 * (8 + 4)),
       attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
 
-    val noGroupAgg = Aggregate(groupingExpressions = Nil,
-      aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
-    assert(noGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
-      // overhead + count result size
-      Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
-
-    val hasGroupAgg = Aggregate(groupingExpressions = attributes,
-      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
-    assert(hasGroupAgg.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) ==
-      // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
-      Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
+    try {
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+      val noGroupAgg = Aggregate(groupingExpressions = Nil,
+        aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+      assert(noGroupAgg.stats ==
+        // overhead + count result size
+        Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+
+      val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+        aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
+      assert(hasGroupAgg.stats ==
+        // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+        Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+    } finally {
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
+    }
   }
 
   private def checkAggStats(
@@ -134,6 +140,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       rowCount = Some(expectedOutputRowCount),
       attributeStats = expectedAttrStats)
 
-    assert(testAgg.stats(conf) == expectedStats)
+    assert(testAgg.stats == expectedStats)
   }
 }
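Aside: the `AggregateEstimation` comment above describes the row-count upper bound — multiply the distinct counts of the group-by columns, which assumes the data contains every combination of their distinct values. A self-contained sketch of just that product (Spark additionally refines the bound; only the product is shown here):

```scala
// Sketch of the aggregate row-count upper bound.
def maxGroupCount(distinctCounts: Seq[BigInt]): BigInt =
  distinctCounts.foldLeft(BigInt(1))(_ * _)

assert(maxGroupCount(Seq(BigInt(4), BigInt(10))) == BigInt(40))
```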
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index 833f5a71994f7d4ad973d28f1e39b3c93c390a64..e9ed36feec48cc4ea10b4a99eb7018c744494237 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -57,16 +57,16 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
     val localLimit = LocalLimit(Literal(2), plan)
     val globalLimit = GlobalLimit(Literal(2), plan)
     // LocalLimit's stats is just its child's stats except column stats
-    checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
     checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
   }
 
   test("limit estimation: limit > child's rowCount") {
     val localLimit = LocalLimit(Literal(20), plan)
     val globalLimit = GlobalLimit(Literal(20), plan)
-    checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(localLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
     // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
-    checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+    checkStats(globalLimit, plan.stats.copy(attributeStats = AttributeMap(Nil)))
   }
 
   test("limit estimation: limit = 0") {
@@ -113,12 +113,19 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    // Invalidate statistics
-    plan.invalidateStatsCache()
-    assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> true)) == expectedStatsCboOn)
-
-    plan.invalidateStatsCache()
-    assert(plan.stats(conf.copy(SQLConf.CBO_ENABLED -> false)) == expectedStatsCboOff)
+    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
+    try {
+      // Invalidate statistics
+      plan.invalidateStatsCache()
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
+      assert(plan.stats == expectedStatsCboOn)
+
+      plan.invalidateStatsCache()
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+      assert(plan.stats == expectedStatsCboOff)
+    } finally {
+      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
+    }
   }
 
   /** Check estimated stats when it's the same whether cbo is turned on or off. */
@@ -135,6 +142,6 @@ private case class DummyLogicalPlan(
     cboStats: Statistics) extends LogicalPlan {
   override def output: Seq[Attribute] = Nil
   override def children: Seq[LogicalPlan] = Nil
-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     if (conf.cboEnabled) cboStats else defaultStats
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 2fa53a6466ef2b6b230481828e20c57d1555992f..455037e6c9952578c135ba302dd7c570f602b022 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -620,7 +620,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       rowCount = Some(expectedRowCount),
       attributeStats = expectedAttributeMap)
 
-    val filterStats = filter.stats(conf)
+    val filterStats = filter.stats
     assert(filterStats.sizeInBytes == expectedStats.sizeInBytes)
     assert(filterStats.rowCount == expectedStats.rowCount)
    val rowCountValue = filterStats.rowCount.getOrElse(0)
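Aside: the `checkStats` helper in BasicStatsEstimationSuite (above) now checks a plan twice, once per CBO setting, invalidating the stats cache in between and restoring the original flag in `finally`. Sketched generically, with the conf accessors abstracted into closures (illustrative names only):

```scala
// Sketch of the check-under-both-CBO-settings pattern.
def checkBoth[S](computeStats: () => S, invalidate: () => Unit,
    setCbo: Boolean => Unit, expectedOn: S, expectedOff: S): Unit = {
  invalidate(); setCbo(true)
  assert(computeStats() == expectedOn)
  invalidate(); setCbo(false)
  assert(computeStats() == expectedOff)
}
```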
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
index 2d6b6e8e21f34f6f17b73b4cf24542110b26b8e5..097c78eb27fca3d59ec33ee2642f54ea8b2b0eb9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala
@@ -77,7 +77,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       // Keep the column stat from both sides unchanged.
       attributeStats = AttributeMap(
         Seq("key-1-5", "key-5-9", "key-1-2", "key-2-4").map(nameToColInfo)))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("disjoint inner join") {
@@ -90,7 +90,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = 1,
       rowCount = Some(0),
       attributeStats = AttributeMap(Nil))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("disjoint left outer join") {
@@ -106,7 +106,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         // Null count for right side columns = left row count
         Seq(nameToAttr("key-1-2") -> nullColumnStat(nameToAttr("key-1-2").dataType, 5),
           nameToAttr("key-2-4") -> nullColumnStat(nameToAttr("key-2-4").dataType, 5))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("disjoint right outer join") {
@@ -122,7 +122,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         // Null count for left side columns = right row count
         Seq(nameToAttr("key-1-5") -> nullColumnStat(nameToAttr("key-1-5").dataType, 3),
           nameToAttr("key-5-9") -> nullColumnStat(nameToAttr("key-5-9").dataType, 3))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("disjoint full outer join") {
@@ -140,7 +140,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         nameToAttr("key-5-9") -> columnInfo(nameToAttr("key-5-9")).copy(nullCount = 3),
         nameToAttr("key-1-2") -> columnInfo(nameToAttr("key-1-2")).copy(nullCount = 5),
         nameToAttr("key-2-4") -> columnInfo(nameToAttr("key-2-4")).copy(nullCount = 5))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("inner join") {
@@ -161,7 +161,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToAttr("key-1-5") -> joinedColStat, nameToAttr("key-1-2") -> joinedColStat,
           nameToAttr("key-5-9") -> colStatForkey59, nameToColInfo("key-2-4"))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("inner join with multiple equi-join keys") {
@@ -183,7 +183,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToAttr("key-1-2") -> joinedColStat1, nameToAttr("key-1-2") -> joinedColStat1,
           nameToAttr("key-2-4") -> joinedColStat2, nameToAttr("key-2-3") -> joinedColStat2)))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("left outer join") {
@@ -201,7 +201,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-3"),
           nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat)))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("right outer join") {
@@ -219,7 +219,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       attributeStats = AttributeMap(
         Seq(nameToColInfo("key-1-2"), nameToAttr("key-2-4") -> joinedColStat,
           nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("full outer join") {
@@ -234,7 +234,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       // Keep the column stat from both sides unchanged.
       attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"),
         nameToColInfo("key-1-2"), nameToColInfo("key-2-3"))))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 
   test("left semi/anti join") {
@@ -248,7 +248,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
         sizeInBytes = 3 * (8 + 4 * 2),
         rowCount = Some(3),
         attributeStats = AttributeMap(Seq(nameToColInfo("key-1-2"), nameToColInfo("key-2-4"))))
-      assert(join.stats(conf) == expectedStats)
+      assert(join.stats == expectedStats)
     }
   }
@@ -306,7 +306,7 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
           sizeInBytes = 1 * (8 + 2 * getColSize(key1, columnInfo1(key1))),
           rowCount = Some(1),
           attributeStats = AttributeMap(Seq(key1 -> columnInfo1(key1), key2 -> columnInfo1(key1))))
-        assert(join.stats(conf) == expectedStats)
+        assert(join.stats == expectedStats)
       }
     }
   }
@@ -323,6 +323,6 @@ class JoinEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = 1,
       rowCount = Some(0),
       attributeStats = AttributeMap(Nil))
-    assert(join.stats(conf) == expectedStats)
+    assert(join.stats == expectedStats)
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
index a5c4d22a29386a27d0a97d672c5e2dc8c1b7e190..cda54fa9d64f4919aa271fbb8e0a9b45503989eb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala
@@ -45,7 +45,7 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = 2 * (8 + 4 + 4),
       rowCount = Some(2),
       attributeStats = expectedAttrStats)
-    assert(proj.stats(conf) == expectedStats)
+    assert(proj.stats == expectedStats)
   }
 
   test("project on empty table") {
@@ -131,6 +131,6 @@ class ProjectEstimationSuite extends StatsEstimationTestBase {
       sizeInBytes = expectedSize,
       rowCount = Some(expectedRowCount),
      attributeStats = projectAttrMap)
-    assert(proj.stats(conf) == expectedStats)
+    assert(proj.stats == expectedStats)
   }
 }
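Aside: the expected size in ProjectEstimationSuite above, `2 * (8 + 4 + 4)`, is row count times (8-byte row overhead plus the projected columns' widths). The same arithmetic as a self-contained sketch:

```scala
// Sketch of the projected-output size arithmetic.
def projectedSize(rowCount: BigInt, colWidths: Seq[Int]): BigInt =
  rowCount * (8 + colWidths.sum)

assert(projectedSize(BigInt(2), Seq(4, 4)) == BigInt(32))  // matches 2 * (8 + 4 + 4)
```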
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index 263f4e18803d55b2e8cc585079c05ac2f1ade03c..eaa33e44a6a5a7725056c30187fc55701b2ce476 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -21,14 +21,24 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, CBO_ENABLED}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 
 trait StatsEstimationTestBase extends SparkFunSuite {
 
-  /** Enable stats estimation based on CBO. */
-  protected val conf = new SQLConf().copy(CASE_SENSITIVE -> true, CBO_ENABLED -> true)
+  var originalValue: Boolean = false
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Enable stats estimation based on CBO.
+    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
+    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
+  }
+
+  override def afterAll(): Unit = {
+    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
+    super.afterAll()
+  }
 
   def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
     // For UTF8String: base + offset + numBytes
@@ -55,7 +65,7 @@ case class StatsTestPlan(
     attributeStats: AttributeMap[ColumnStat],
     size: Option[BigInt] = None) extends LeafNode {
   override def output: Seq[Attribute] = outputList
-  override def computeStats(conf: SQLConf): Statistics = Statistics(
+  override def computeStats: Statistics = Statistics(
     // If sizeInBytes is useless in testing, we just use a fake value
     sizeInBytes = size.getOrElse(Int.MaxValue),
     rowCount = Some(rowCount),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 3d1b481a53e757c43bcf105551fd247462ba457f..66f66a289a0654ae55b874cdd5f40f4229ede28e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
@@ -89,7 +88,7 @@ case class ExternalRDD[T](
 
   override protected def stringArgs: Iterator[Any] = Iterator(output)
 
-  @transient override def computeStats(conf: SQLConf): Statistics = Statistics(
+  @transient override def computeStats: Statistics = Statistics(
     // TODO: Instead of returning a default value here, find a way to return a meaningful size
     // estimate for RDDs. See PR 1238 for more discussions.
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
@@ -157,7 +156,7 @@ case class LogicalRDD(
 
   override protected def stringArgs: Iterator[Any] = Iterator(output)
 
-  @transient override def computeStats(conf: SQLConf): Statistics = Statistics(
+  @transient override def computeStats: Statistics = Statistics(
     // TODO: Instead of returning a default value here, find a way to return a meaningful size
     // estimate for RDDs. See PR 1238 for more discussions.
     sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 34998cbd615521b3109bf45f8ade54ea1be7a2cf..c7cac332a0377d0a939b8c569ca0159d8770d303 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -221,7 +221,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
 
   def stringWithStats: String = {
     // trigger to compute stats for logical plans
-    optimizedPlan.stats(sparkSession.sessionState.conf)
+    optimizedPlan.stats
 
     // only show optimized logical plan and physical plan
     s"""== Optimized Logical Plan ==
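Aside: `StatsEstimationTestBase` (above) replaces its private `SQLConf` with a save/set/restore of the CBO flag around the whole suite. The lifecycle, sketched with a hypothetical `Flag` object standing in for the conf accessor (not a Spark API):

```scala
// Sketch of the beforeAll/afterAll save-and-restore discipline.
object Flag { var cboEnabled: Boolean = false }  // stand-in for SQLConf.get

trait CboTestLifecycle {
  private var saved: Boolean = false
  def beforeAll(): Unit = { saved = Flag.cboEnabled; Flag.cboEnabled = true }
  def afterAll(): Unit = { Flag.cboEnabled = saved }
}
```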
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ea86f6e00fefa02330bad7a214ceca10f7634ee4..a57d5abb90c0eae5bc1aaef9eb8ab518369776ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -114,9 +113,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * Matches a plan whose output should be small enough to be used in broadcast join.
      */
    private def canBroadcast(plan: LogicalPlan): Boolean = {
-      plan.stats(conf).hints.broadcast ||
-        (plan.stats(conf).sizeInBytes >= 0 &&
-          plan.stats(conf).sizeInBytes <= conf.autoBroadcastJoinThreshold)
+      plan.stats.hints.broadcast ||
+        (plan.stats.sizeInBytes >= 0 &&
+          plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
    }
 
    /**
@@ -126,7 +125,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * dynamic.
      */
    private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
-      plan.stats(conf).sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
+      plan.stats.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
    }
 
    /**
@@ -137,7 +136,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * use the size of bytes here as estimation.
      */
    private def muchSmaller(a: LogicalPlan, b: LogicalPlan): Boolean = {
-      a.stats(conf).sizeInBytes * 3 <= b.stats(conf).sizeInBytes
+      a.stats.sizeInBytes * 3 <= b.stats.sizeInBytes
    }
 
    private def canBuildRight(joinType: JoinType): Boolean = joinType match {
@@ -206,7 +205,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Join(left, right, joinType, condition) =>
         val buildSide =
-          if (right.stats(conf).sizeInBytes <= left.stats(conf).sizeInBytes) {
+          if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
             BuildRight
           } else {
             BuildLeft
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 456a8f3b20f3055cda32d289e9907dc36d919de0..2972132336de02f016350a7bcde5f1dad8bba6d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.LongAccumulator
@@ -70,7 +69,7 @@ case class InMemoryRelation(
 
   @transient val partitionStatistics = new PartitionStatistics(output)
 
-  override def computeStats(conf: SQLConf): Statistics = {
+  override def computeStats: Statistics = {
     if (batchStats.value == 0L) {
       // Underlying columnar RDD hasn't been materialized, no useful statistics information
       // available, return the default statistics.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 3813f953e06a3613d362ebeb47e9b1a04595efc5..c1b2895f1747e9abe42ff82143d9c92e60aefee5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -20,7 +20,6 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.util.Utils
@@ -46,7 +45,7 @@ case class LogicalRelation(
   // Only care about relation when canonicalizing.
   override def preCanonicalized: LogicalPlan = copy(catalogTable = None)
 
-  @transient override def computeStats(conf: SQLConf): Statistics = {
+  @transient override def computeStats: Statistics = {
     catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse(
       Statistics(sizeInBytes = relation.sizeInBytes))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7eaa803a9ecb4f006879bed22e91cb73fe4050cf..a5dac469f85b6ae713e2648ca39164f2f95ca3be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -230,6 +229,6 @@ case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode
 
   private val sizePerRow = sink.schema.toAttributes.map(_.dataType.defaultSize).sum
 
-  override def computeStats(conf: SQLConf): Statistics =
+  override def computeStats: Statistics =
     Statistics(sizePerRow * sink.allData.size)
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 8532a5b5bc8eb8b0d00893d386df2ba2098defbc..506cc2548e260a6073924bdbbb5b8534b0a0f407 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -313,7 +313,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
     spark.table("testData").queryExecution.withCachedData.collect {
       case cached: InMemoryRelation =>
         val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum
-        assert(cached.stats(sqlConf).sizeInBytes === actualSizeInBytes)
+        assert(cached.stats.sizeInBytes === actualSizeInBytes)
     }
   }
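Aside: the SparkStrategies.scala hunks above show the two size-based planning predicates this migration touches. Restated over plain numbers as a self-contained sketch (names mirror the diff; the signatures here are illustrative):

```scala
// Sketch of the broadcast/build-side predicates.
def canBroadcast(sizeInBytes: BigInt, broadcastHint: Boolean, threshold: Long): Boolean =
  broadcastHint || (sizeInBytes >= 0 && sizeInBytes <= threshold)

// "a is much smaller than b": at most a third of b's estimated size.
def muchSmaller(a: BigInt, b: BigInt): Boolean = a * 3 <= b
```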
val ds = (0 to 10000).map( i => (i, Seq((i, Seq((i, "This is really not that long of a string")))))).toDS() - val sizeInBytes = ds.logicalPlan.stats(sqlConf).sizeInBytes + val sizeInBytes = ds.logicalPlan.stats.sizeInBytes // sizeInBytes is 2404280404, before the fix, it overflows to a negative number assert(sizeInBytes > 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 1a66aa85f5a02bfeea2c32c0aac0f0d6191261e1..895ca196a7a518fe79cc2b4d11517d90cdeb811f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -33,7 +33,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() def statisticSizeInByte(df: DataFrame): BigInt = { - df.queryExecution.optimizedPlan.stats(sqlConf).sizeInBytes + df.queryExecution.optimizedPlan.stats.sizeInBytes } test("equi-join is hash-join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 601324f2c017267dc771d0d9ce993ccc1f3fb909..9824062f969b3bedf35cd48adb651bcc2f725066 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = df1.join(df2, Seq("k"), "left") val sizes = df.queryExecution.analyzed.collect { case g: Join => - g.stats(conf).sizeInBytes + g.stats.sizeInBytes } assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}") @@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared test("SPARK-15392: DataFrame created from RDD should not be broadcasted") { val rdd = sparkContext.range(1, 100).map(i => Row(i, i)) val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType)) - assert(df.queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) - assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes > + assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes > spark.sessionState.conf.autoBroadcastJoinThreshold) } @@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils test("SPARK-18856: non-empty partitioned table should not report zero size") { withTable("ds_tbl", "hive_tbl") { spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl") - val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf) + val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") { sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)") sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1") - val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf) + val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.") } } @@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 601324f2c017267dc771d0d9ce993ccc1f3fb909..9824062f969b3bedf35cd48adb651bcc2f725066 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -60,7 +60,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
     val df = df1.join(df2, Seq("k"), "left")
 
     val sizes = df.queryExecution.analyzed.collect { case g: Join =>
-      g.stats(conf).sizeInBytes
+      g.stats.sizeInBytes
     }
     assert(sizes.size === 1, s"number of Join nodes is wrong:\n ${df.queryExecution}")
 
@@ -107,9 +107,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
   test("SPARK-15392: DataFrame created from RDD should not be broadcasted") {
     val rdd = sparkContext.range(1, 100).map(i => Row(i, i))
     val df = spark.createDataFrame(rdd, new StructType().add("a", LongType).add("b", LongType))
-    assert(df.queryExecution.analyzed.stats(conf).sizeInBytes >
+    assert(df.queryExecution.analyzed.stats.sizeInBytes >
       spark.sessionState.conf.autoBroadcastJoinThreshold)
-    assert(df.selectExpr("a").queryExecution.analyzed.stats(conf).sizeInBytes >
+    assert(df.selectExpr("a").queryExecution.analyzed.stats.sizeInBytes >
       spark.sessionState.conf.autoBroadcastJoinThreshold)
   }
 
@@ -250,13 +250,13 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
   test("SPARK-18856: non-empty partitioned table should not report zero size") {
     withTable("ds_tbl", "hive_tbl") {
       spark.range(100).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("ds_tbl")
-      val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats(conf)
+      val stats = spark.table("ds_tbl").queryExecution.optimizedPlan.stats
       assert(stats.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
 
       if (spark.conf.get(StaticSQLConf.CATALOG_IMPLEMENTATION) == "hive") {
         sql("CREATE TABLE hive_tbl(i int) PARTITIONED BY (j int)")
         sql("INSERT INTO hive_tbl PARTITION(j=1) SELECT 1")
-        val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats(conf)
+        val stats2 = spark.table("hive_tbl").queryExecution.optimizedPlan.stats
         assert(stats2.sizeInBytes > 0, "non-empty partitioned table should not report zero size.")
       }
     }
@@ -296,10 +296,10 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils
     assert(catalogTable.stats.get.colStats == Map("c1" -> emptyColStat))
 
     // Check relation statistics
-    assert(relation.stats(conf).sizeInBytes == 0)
-    assert(relation.stats(conf).rowCount == Some(0))
-    assert(relation.stats(conf).attributeStats.size == 1)
-    val (attribute, colStat) = relation.stats(conf).attributeStats.head
+    assert(relation.stats.sizeInBytes == 0)
+    assert(relation.stats.rowCount == Some(0))
+    assert(relation.stats.attributeStats.size == 1)
+    val (attribute, colStat) = relation.stats.attributeStats.head
     assert(attribute.name == "c1")
     assert(colStat == emptyColStat)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 109b1d9db60d2552bbc1f69212badf4ebf9e935f..8d411eb191cd93547173bef333df4daebef64d77 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       .toDF().createOrReplaceTempView("sizeTst")
     spark.catalog.cacheTable("sizeTst")
     assert(
-      spark.table("sizeTst").queryExecution.analyzed.stats(sqlConf).sizeInBytes >
+      spark.table("sizeTst").queryExecution.analyzed.stats.sizeInBytes >
         spark.conf.get(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
index becb3aa27040128c58a1a347e714a5d0a58445ea..caf03885e38736badfda8d7fb351dd1547ee76b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFsRelationSuite.scala
@@ -36,7 +36,7 @@ class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {
       })
       val totalSize = allFiles.map(_.length()).sum
       val df = spark.read.parquet(dir.toString)
-      assert(df.queryExecution.logical.stats(sqlConf).sizeInBytes === BigInt(totalSize))
+      assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize))
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index 24a7b7740fa5b36d38e2cb4fd1abaeef65d4a556..e8420eee7fe9d460874a2900ec1bb8ea46bb4839 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -216,15 +216,15 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
 
     // Before adding data, check output
     checkAnswer(sink.allData, Seq.empty)
-    assert(plan.stats(sqlConf).sizeInBytes === 0)
+    assert(plan.stats.sizeInBytes === 0)
 
     sink.addBatch(0, 1 to 3)
     plan.invalidateStatsCache()
-    assert(plan.stats(sqlConf).sizeInBytes === 12)
+    assert(plan.stats.sizeInBytes === 12)
 
     sink.addBatch(1, 4 to 6)
     plan.invalidateStatsCache()
-    assert(plan.stats(sqlConf).sizeInBytes === 24)
+    assert(plan.stats.sizeInBytes === 24)
   }
 
   ignore("stress test") {
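Aside: the expected sizes 12 and 24 above follow directly from MemoryPlan.computeStats in the memory.scala hunk earlier (sizePerRow * sink.allData.size), given that a single Int column has a default size of 4 bytes:

    val sizePerRow = 4            // IntegerType.defaultSize
    assert(sizePerRow * 3 == 12)  // after addBatch(0, 1 to 3): 3 buffered rows
    assert(sizePerRow * 6 == 24)  // after addBatch(1, 4 to 6): 6 buffered rows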
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index f9b3ff8405823b72cf711d6b40ac35585301e612..0cfe260e52152f609b0de2952e09ff1c88d2aa1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -21,7 +21,6 @@ import java.nio.charset.StandardCharsets
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits}
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * A collection of sample data used in SQL tests.
@@ -29,8 +28,6 @@ import org.apache.spark.sql.internal.SQLConf
 private[sql] trait SQLTestData { self =>
   protected def spark: SparkSession
 
-  protected def sqlConf: SQLConf = spark.sessionState.conf
-
   // Helper object to import SQL implicits without a concrete SQLContext
   private object internalImplicits extends SQLImplicits {
     protected override def _sqlContext: SQLContext = self.spark.sqlContext
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index ff5afc8e3ce05a6388bbed405579c075103f9fe7..808dc013f170beebbfe09ff7f05fff5fbf97759a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -154,7 +154,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         Some(partitionSchema))
 
       val logicalRelation = cached.getOrElse {
-        val sizeInBytes = relation.stats(sparkSession.sessionState.conf).sizeInBytes.toLong
+        val sizeInBytes = relation.stats.sizeInBytes.toLong
         val fileIndex = {
           val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes)
           if (lazyPruningEnabled) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 001bbc230ff1857ea1bd69da43a9319316e18800..279db9a397258e91094b4ed18aefdde30e92999b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -68,7 +68,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0")
       assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0")
 
-      val sizeInBytes = relation.stats(conf).sizeInBytes
+      val sizeInBytes = relation.stats.sizeInBytes
       assert(sizeInBytes === BigInt(file1.length() + file2.length()))
     }
   }
@@ -77,7 +77,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
 
   test("analyze Hive serde tables") {
     def queryTotalSize(tableName: String): BigInt =
-      spark.table(tableName).queryExecution.analyzed.stats(conf).sizeInBytes
+      spark.table(tableName).queryExecution.analyzed.stats.sizeInBytes
 
     // Non-partitioned table
     sql("CREATE TABLE analyzeTable (key STRING, value STRING)").collect()
@@ -659,7 +659,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
   test("estimates the size of a test Hive serde tables") {
     val df = sql("""SELECT * FROM src""")
     val sizes = df.queryExecution.analyzed.collect {
-      case relation: CatalogRelation => relation.stats(conf).sizeInBytes
+      case relation: CatalogRelation => relation.stats.sizeInBytes
     }
     assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
     assert(sizes(0).equals(BigInt(5812)),
@@ -679,7 +679,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
 
       // Assert src has a size smaller than the threshold.
       val sizes = df.queryExecution.analyzed.collect {
-        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats(conf).sizeInBytes
+        case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.stats.sizeInBytes
       }
       assert(sizes.size === 2 && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold
         && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
@@ -733,7 +733,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
 
       // Assert src has a size smaller than the threshold.
       val sizes = df.queryExecution.analyzed.collect {
-        case relation: CatalogRelation => relation.stats(conf).sizeInBytes
+        case relation: CatalogRelation => relation.stats.sizeInBytes
       }
       assert(sizes.size === 2 && sizes(1) <= spark.sessionState.conf.autoBroadcastJoinThreshold
         && sizes(0) <= spark.sessionState.conf.autoBroadcastJoinThreshold,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
index d91f25a4da013883397a9c9929f74e3aef6d5671..3a724aa14f2a97913fe5ba102da125a1de19b24a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneFileSourcePartitionsSuite.scala
@@ -86,7 +86,7 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
         case relation: LogicalRelation => relation
       }
       assert(relations.size === 1, s"Size wrong for:\n ${df.queryExecution}")
-      val size2 = relations(0).computeStats(conf).sizeInBytes
+      val size2 = relations(0).computeStats.sizeInBytes
       assert(size2 == relations(0).catalogTable.get.stats.get.sizeInBytes)
       assert(size2 < tableStats.get.sizeInBytes)
     }
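Aside: the MemorySinkSuite hunk earlier relies on stats being memoized and explicitly reset via invalidateStatsCache(). A minimal sketch of how a parameterless, cached stats can support that, assuming a simplified Statistics class; this illustrates the pattern only, not Spark's actual implementation:

    case class Statistics(sizeInBytes: BigInt)  // simplified stand-in

    trait PlanStats {
      protected def computeStats: Statistics        // leaf nodes override this
      private var statsCache: Option[Statistics] = None
      // Compute once, then serve the cached value until invalidated.
      def stats: Statistics = statsCache.getOrElse {
        val computed = computeStats
        statsCache = Some(computed)
        computed
      }
      // Drop the cached value, e.g. after a sink buffers new rows.
      def invalidateStatsCache(): Unit = { statsCache = None }
    }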