Skip to content
Snippets Groups Projects
Commit ae253e5a authored by gatorsmile's avatar gatorsmile Committed by Wenchen Fan
Browse files

[SPARK-21273][SQL][FOLLOW-UP] Propagate logical plan stats using visitor pattern and mixin

## What changes were proposed in this pull request?
This PR adds back stats propagation for `Window` and removes the stats calculation for the leaf node `Range`, which is already covered by https://github.com/rxin/spark/blob/9c32d2507d3f4f269e17e841a4a4e4920b35a5e9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala#L56

## How was this patch tested?
Added two test cases.

Author: gatorsmile <gatorsmile@gmail.com>

Closes #18677 from gatorsmile/visitStats.
parent 81c99a5b
No related branches found
No related tags found
No related merge requests found
......@@ -35,13 +35,13 @@ trait LogicalPlanVisitor[T] {
case p: LocalLimit => visitLocalLimit(p)
case p: Pivot => visitPivot(p)
case p: Project => visitProject(p)
case p: Range => visitRange(p)
case p: Repartition => visitRepartition(p)
case p: RepartitionByExpression => visitRepartitionByExpr(p)
case p: ResolvedHint => visitHint(p)
case p: Sample => visitSample(p)
case p: ScriptTransformation => visitScriptTransform(p)
case p: Union => visitUnion(p)
case p: Window => visitWindow(p)
case p: LogicalPlan => default(p)
}
......@@ -73,8 +73,6 @@ trait LogicalPlanVisitor[T] {
// Hook invoked by the visitor dispatch for a Project node.
def visitProject(p: Project): T
// Hook invoked by the visitor dispatch for a Range node.
def visitRange(p: Range): T
// Hook invoked by the visitor dispatch for a Repartition node.
def visitRepartition(p: Repartition): T
// Hook invoked by the visitor dispatch for a RepartitionByExpression node.
def visitRepartitionByExpr(p: RepartitionByExpression): T
......@@ -84,4 +82,6 @@ trait LogicalPlanVisitor[T] {
// Hook invoked by the visitor dispatch for a ScriptTransformation node.
def visitScriptTransform(p: ScriptTransformation): T
// Hook invoked by the visitor dispatch for a Union node.
def visitUnion(p: Union): T
// Hook invoked by the visitor dispatch for a Window node.
def visitWindow(p: Window): T
}
......@@ -65,11 +65,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
ProjectEstimation.estimate(p).getOrElse(fallback(p))
}
/**
 * Statistics for a [[logical.Range]] node.
 *
 * Only `sizeInBytes` is populated: the default size of `LongType` multiplied by
 * the range's `numElements`.
 */
override def visitRange(p: logical.Range): Statistics =
  Statistics(sizeInBytes = LongType.defaultSize * p.numElements)
// No dedicated estimation for repartitioning nodes in this visitor: both delegate to `fallback`.
override def visitRepartition(p: Repartition): Statistics = fallback(p)
override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p)
......@@ -79,4 +74,6 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
// No dedicated estimation for these nodes in this visitor: each delegates to `fallback`.
override def visitScriptTransform(p: ScriptTransformation): Statistics = fallback(p)
override def visitUnion(p: Union): Statistics = fallback(p)
override def visitWindow(p: Window): Statistics = fallback(p)
}
......@@ -136,10 +136,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
// Project is handled by the shared unary-node estimation (`visitUnaryNode`).
override def visitProject(p: Project): Statistics = visitUnaryNode(p)
/** Statistics for a [[logical.Range]] node: delegates to the node's own `computeStats()`. */
override def visitRange(p: logical.Range): Statistics = p.computeStats()
// Repartitioning nodes get no special handling in this visitor: both delegate to `default`.
override def visitRepartition(p: Repartition): Statistics = default(p)
override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = default(p)
......@@ -160,4 +156,6 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] {
/**
 * Statistics for a [[Union]] node.
 *
 * Only `sizeInBytes` is populated: the sum of `stats.sizeInBytes` over all children.
 */
override def visitUnion(p: Union): Statistics = {
  val totalSizeInBytes = p.children.map(_.stats.sizeInBytes).sum
  Statistics(sizeInBytes = totalSizeInBytes)
}
// Window is handled by the shared unary-node estimation (`visitUnaryNode`).
override def visitWindow(p: Window): Statistics = visitUnaryNode(p)
}
......@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.statsEstimation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
......@@ -54,6 +56,24 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
)
}
test("range") {
  val rangePlan = Range(1, 5, 1, None)
  // Expected size is 4 * 8 bytes — presumably 4 elements of 8 bytes each (Long); confirm
  // against Range's stats computation. The same value is expected with CBO on and off.
  val expected = Statistics(sizeInBytes = 4 * 8)
  checkStats(
    rangePlan,
    expectedStatsCboOn = expected,
    expectedStatsCboOff = expected)
}
test("windows") {
  // NOTE(review): the alias is named 'sum_attr although the aggregate is `min`; the name
  // does not affect the statistics checked below.
  val windowPlan = plan.window(
    Seq(min(attribute).as('sum_attr)),
    Seq(attribute),
    Nil)
  // Scales the child plan's size by (4 + 4 + 8) / (4 + 8) — presumably the ratio of the
  // window's output row width to the child's row width; confirm against the unary-node
  // size estimation. The same value is expected with CBO on and off.
  val expected = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8))
  checkStats(
    windowPlan,
    expectedStatsCboOn = expected,
    expectedStatsCboOff = expected)
}
test("limit estimation: limit < child's rowCount") {
val localLimit = LocalLimit(Literal(2), plan)
val globalLimit = GlobalLimit(Literal(2), plan)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment