Skip to content
Snippets Groups Projects
Commit 6ce008ba authored by gatorsmile's avatar gatorsmile Committed by Cheng Lian
Browse files

[SPARK-13549][SQL] Refactor the Optimizer Rule CollapseProject

#### What changes were proposed in this pull request?

PR https://github.com/apache/spark/pull/10541 changed the rule `CollapseProject` to allow collapsing a `Project` into an `Aggregate`, but it left a to-do item to remove the resulting duplicate code. This PR completes that to-do item and adds a test case covering the change.

#### How was this patch tested?

Added a new test case.

liancheng Could you check if the code refactoring is fine? Thanks!

Author: gatorsmile <gatorsmile@gmail.com>

Closes #11427 from gatorsmile/collapseProjectRefactor.
parent cde086cb
No related branches found
No related tags found
No related merge requests found
......@@ -417,68 +417,57 @@ object ColumnPruning extends Rule[LogicalPlan] {
object CollapseProject extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p @ Project(projectList1, Project(projectList2, child)) =>
// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliasMap = AttributeMap(projectList2.collect {
case a: Alias => (a.toAttribute, a)
})
// We only collapse these two Projects if their overlapped expressions are all
// deterministic.
val hasNondeterministic = projectList1.exists(_.collect {
case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
}.exists(!_.deterministic))
if (hasNondeterministic) {
case p1 @ Project(_, p2: Project) =>
if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
p1
} else {
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
}
case p @ Project(_, agg: Aggregate) =>
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
p
} else {
// Substitute any attributes that are produced by the child projection, so that we safely
// eliminate it.
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
// TODO: Fix TransformBase to avoid the cast below.
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]
// collapse 2 projects may introduce unnecessary Aliases, trim them here.
val cleanedProjection = substitutedProjection.map(p =>
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
)
Project(cleanedProjection, child)
agg.copy(aggregateExpressions = buildCleanedProjectList(
p.projectList, agg.aggregateExpressions))
}
}
// TODO Eliminate duplicate code
// This clause is identical to the one above except that the inner operator is an `Aggregate`
// rather than a `Project`.
case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
// Create a map of Aliases to their values from the child projection.
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
val aliasMap = AttributeMap(projectList2.collect {
case a: Alias => (a.toAttribute, a)
})
/**
 * Indexes every `Alias` in `projectList` by the attribute it produces.
 * e.g., the list of 'SELECT a + b AS c, d ...' yields Map(c -> Alias(a + b, c)); entries that
 * are not an `Alias` (such as `d`) are left out.
 */
private def collectAliases(projectList: Seq[NamedExpression]): AttributeMap[Alias] = {
  val aliasPairs = projectList.collect { case alias: Alias => (alias.toAttribute, alias) }
  AttributeMap(aliasPairs)
}
// We only collapse these two Projects if their overlapped expressions are all
// deterministic.
val hasNondeterministic = projectList1.exists(_.collect {
case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
}.exists(!_.deterministic))
/**
 * Returns true when some expression in `upper` references an attribute that `lower` defines
 * through an `Alias` whose aliased child expression is non-deterministic.
 *
 * Callers only merge the two operators when this returns false, i.e. when every overlapped
 * expression is deterministic.
 */
private def haveCommonNonDeterministicOutput(
    upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
  // Attribute -> Alias lookup for the lower list,
  // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
  val lowerAliases = collectAliases(lower)
  upper.exists { expr =>
    // The aliased children of every lower alias that this upper expression refers to.
    val referencedChildren = expr.collect {
      case attr: Attribute if lowerAliases.contains(attr) => lowerAliases(attr).child
    }
    referencedChildren.exists(child => !child.deterministic)
  }
}
if (hasNondeterministic) {
p
} else {
// Substitute any attributes that are produced by the child projection, so that we safely
// eliminate it.
// e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
// TODO: Fix TransformBase to avoid the cast below.
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]
// collapse 2 projects may introduce unnecessary Aliases, trim them here.
val cleanedProjection = substitutedProjection.map(p =>
CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
)
agg.copy(aggregateExpressions = cleanedProjection)
}
/**
 * Builds the output list of the merged operator: each expression in `upper` is rewritten so
 * that attributes produced by an `Alias` in `lower` are replaced by that `Alias`, and the
 * result is passed through `CleanupAliases.trimNonTopLevelAliases`.
 *
 * e.g., 'SELECT c + 1 FROM (SELECT a + b AS c ...)' produces 'SELECT a + b + 1 ...'
 */
private def buildCleanedProjectList(
    upper: Seq[NamedExpression],
    lower: Seq[NamedExpression]): Seq[NamedExpression] = {
  // Attribute -> Alias lookup for the lower list,
  // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
  val lowerAliases = collectAliases(lower)
  upper.map { expr =>
    // Inline the lower definitions so that the lower operator can be safely eliminated.
    val inlined = expr.transform {
      case attr: Attribute => lowerAliases.getOrElse(attr, attr)
    }
    // Inlining may leave unnecessary Aliases nested inside the expression; trim them here.
    CleanupAliases.trimNonTopLevelAliases(inlined).asInstanceOf[NamedExpression]
  }
}
}
......
......@@ -29,7 +29,7 @@ class CollapseProjectSuite extends PlanTest {
// Rule executor used by the tests in this suite: a "Subqueries" batch running
// EliminateSubqueryAliases under FixedPoint(10), then a "CollapseProject" batch running
// CollapseProject under the Once strategy.
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", FixedPoint(10), EliminateSubqueryAliases) ::
// NOTE(review): the next two lines are textually identical; they look like the before/after
// sides of a whitespace-only diff line — confirm that only one of them belongs in the file.
Batch("CollapseProject", Once, CollapseProject) :: Nil
Batch("CollapseProject", Once, CollapseProject) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int)
......@@ -95,4 +95,28 @@ class CollapseProjectSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
// A Project stacked on an Aggregate is expected to be merged into the Aggregate's output list.
test("collapse project into aggregate") {
  val aggregated = testRelation.groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b)
  val query = aggregated.select('a_plus_1, ('b + 1).as('b_plus_1))

  val actualPlan = Optimize.execute(query.analyze)

  // Expected shape: a single Aggregate that directly computes both projected columns.
  val expectedPlan = testRelation
    .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1))
    .analyze

  comparePlans(actualPlan, expectedPlan)
}
// The Project references the Aggregate's Rand(10) output twice, so the optimized plan is
// expected to equal the analyzed input plan.
test("do not collapse common nondeterministic project and aggregate") {
  val aggregated = testRelation.groupBy('a)('a, Rand(10).as('rand))
  val query = aggregated.select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

  val actualPlan = Optimize.execute(query.analyze)
  val expectedPlan = query.analyze

  comparePlans(actualPlan, expectedPlan)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment