Skip to content
Snippets Groups Projects
Commit 0af94e77 authored by Herman van Hovell's avatar Herman van Hovell Committed by gatorsmile
Browse files

[SPARK-18300][SQL] Do not apply foldable propagation with expand as a child.


## What changes were proposed in this pull request?
The `FoldablePropagation` optimizer rule, pulls foldable values out from under an `Expand`. This breaks the `Expand` in two ways:

- It rewrites the output attributes of the `Expand`. We explicitly define output attributes for `Expand`, these are (unfortunately) considered as part of the expressions of the `Expand` and can be rewritten.
- Expand can actually change the column (it will typically re-use the attributes or the underlying plan). This means that we cannot safely propagate the expressions from under an `Expand`.

This PR fixes this and (hopefully) other issues by explicitly whitelisting allowed operators.

## How was this patch tested?
Added tests to `FoldablePropagationSuite` and to `SQLQueryTestSuite`.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15857 from hvanhovell/SPARK-18300.

(cherry picked from commit f14ae490)
Signed-off-by: default avatargatorsmile <gatorsmile@gmail.com>
parent 0762c0ce
No related branches found
No related tags found
No related merge requests found
...@@ -428,43 +428,49 @@ object FoldablePropagation extends Rule[LogicalPlan] { ...@@ -428,43 +428,49 @@ object FoldablePropagation extends Rule[LogicalPlan] {
} }
case _ => Nil case _ => Nil
}) })
val replaceFoldable: PartialFunction[Expression, Expression] = {
case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
}
if (foldableMap.isEmpty) { if (foldableMap.isEmpty) {
plan plan
} else { } else {
var stop = false var stop = false
CleanupAliases(plan.transformUp { CleanupAliases(plan.transformUp {
case u: Union => // A leaf node should not stop the folding process (note that we are traversing up the
stop = true // tree, starting at the leaf nodes); so we are allowing it.
u case l: LeafNode =>
case c: Command => l
stop = true
c // Whitelist of all nodes we are allowed to apply this rule to.
// For outer join, although its output attributes are derived from its children, they are case p @ (_: Project | _: Filter | _: SubqueryAlias | _: Aggregate | _: Window |
// actually different attributes: the output of outer join is not always picked from its _: Sample | _: GlobalLimit | _: LocalLimit | _: Generate | _: Distinct |
// children, but can also be null. _: AppendColumns | _: AppendColumnsWithObject | _: BroadcastHint |
_: RedistributeData | _: Repartition | _: Sort | _: TypedFilter) if !stop =>
p.transformExpressions(replaceFoldable)
// Allow inner joins. We do not allow outer join, although its output attributes are
// derived from its children, they are actually different attributes: the output of outer
// join is not always picked from its children, but can also be null.
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
// of outer join. // of outer join.
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) => case j @ Join(_, _, Inner, _) =>
j.transformExpressions(replaceFoldable)
// We can fold the projections an expand holds. However expand changes the output columns
// and often reuses the underlying attributes; so we cannot assume that a column is still
// foldable after the expand has been applied.
// TODO(hvanhovell): Expand should use new attributes as the output attributes.
case expand: Expand if !stop =>
val newExpand = expand.copy(projections = expand.projections.map { projection =>
projection.map(_.transform(replaceFoldable))
})
stop = true stop = true
j newExpand
// These 3 operators take attributes as constructor parameters, and these attributes case other =>
// can't be replaced by alias.
case m: MapGroups =>
stop = true
m
case f: FlatMapGroupsInR =>
stop = true
f
case c: CoGroup =>
stop = true stop = true
c other
case p: LogicalPlan if !stop => p.transformExpressions {
case a: AttributeReference if foldableMap.contains(a) =>
foldableMap(a)
}
}) })
} }
} }
......
...@@ -116,16 +116,35 @@ class FoldablePropagationSuite extends PlanTest { ...@@ -116,16 +116,35 @@ class FoldablePropagationSuite extends PlanTest {
test("Propagate in subqueries of Union queries") { test("Propagate in subqueries of Union queries") {
val query = Union( val query = Union(
Seq( Seq(
testRelation.select(Literal(1).as('x), 'a).select('x + 'a), testRelation.select(Literal(1).as('x), 'a).select('x, 'x + 'a),
testRelation.select(Literal(2).as('x), 'a).select('x + 'a))) testRelation.select(Literal(2).as('x), 'a).select('x, 'x + 'a)))
.select('x) .select('x)
val optimized = Optimize.execute(query.analyze) val optimized = Optimize.execute(query.analyze)
val correctAnswer = Union( val correctAnswer = Union(
Seq( Seq(
testRelation.select(Literal(1).as('x), 'a).select((Literal(1).as('x) + 'a).as("(x + a)")), testRelation.select(Literal(1).as('x), 'a)
testRelation.select(Literal(2).as('x), 'a).select((Literal(2).as('x) + 'a).as("(x + a)")))) .select(Literal(1).as('x), (Literal(1).as('x) + 'a).as("(x + a)")),
testRelation.select(Literal(2).as('x), 'a)
.select(Literal(2).as('x), (Literal(2).as('x) + 'a).as("(x + a)"))))
.select('x).analyze .select('x).analyze
comparePlans(optimized, correctAnswer)
}
test("Propagate in expand") {
val c1 = Literal(1).as('a)
val c2 = Literal(2).as('b)
val a1 = c1.toAttribute.withNullability(true)
val a2 = c2.toAttribute.withNullability(true)
val expand = Expand(
Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
Seq(a1, a2),
OneRowRelation.select(c1, c2))
val query = expand.where(a1.isNotNull).select(a1, a2).analyze
val optimized = Optimize.execute(query)
val correctExpand = expand.copy(projections = Seq(
Seq(Literal(null), c2),
Seq(c1, Literal(null))))
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
comparePlans(optimized, correctAnswer) comparePlans(optimized, correctAnswer)
} }
} }
...@@ -32,3 +32,6 @@ SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1; ...@@ -32,3 +32,6 @@ SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1;
-- Aggregate with nulls. -- Aggregate with nulls.
SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a) SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a)
FROM testData; FROM testData;
-- Aggregate with foldable input and multiple distinct groups.
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
-- Automatically generated by SQLQueryTestSuite -- Automatically generated by SQLQueryTestSuite
-- Number of queries: 14 -- Number of queries: 15
-- !query 0 -- !query 0
...@@ -131,3 +131,11 @@ FROM testData ...@@ -131,3 +131,11 @@ FROM testData
struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,var_samp(CAST(a AS DOUBLE)):double,stddev_samp(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint> struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,var_samp(CAST(a AS DOUBLE)):double,stddev_samp(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
-- !query 13 output -- !query 13 output
-0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7 -0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7
-- !query 14
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a
-- !query 14 schema
struct<count(DISTINCT b):bigint,count(DISTINCT b, c):bigint>
-- !query 14 output
1 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment