Skip to content
Snippets Groups Projects
Commit 4ba63b19 authored by jiangxingbo's avatar jiangxingbo Committed by Herman van Hovell
Browse files

[SPARK-17142][SQL] Complex query triggers binding error in HashAggregateExec

## What changes were proposed in this pull request?

In the `ReorderAssociativeOperator` rule, we extract the foldable expressions from chains of Add/Multiply arithmetic and replace them with a single evaluated literal. For example, `(a + 1) + (b + 2)` is optimized to `(a + b + 3)` by this rule.
For an aggregate operator, output expressions should be derived from groupingExpressions, but the current implementation of the `ReorderAssociativeOperator` rule may break this promise. An example:
```
SELECT
  ((t1.a + 1) + (t2.a + 2)) AS out_col
FROM
  testdata2 AS t1
INNER JOIN
  testdata2 AS t2
ON
  (t1.a = t2.a)
GROUP BY (t1.a + 1), (t2.a + 2)
```
`((t1.a + 1) + (t2.a + 2))` is optimized to `(t1.a + t2.a + 3)`, which cannot be derived from `ExpressionSet((t1.a + 1), (t2.a + 2))`.
To fix this, the `ReorderAssociativeOperator` rule now collects `Aggregate.groupingExpressions` into a grouping ExpressionSet and leaves any expression contained in that set untouched while reordering.

## How was this patch tested?

Add new test case in `ReorderAssociativeOperatorSuite`.

Author: jiangxingbo <jiangxb1987@gmail.com>

Closes #14917 from jiangxb1987/rao.
parent 3f6a2bb3
No related branches found
No related tags found
No related merge requests found
......@@ -57,20 +57,37 @@ object ConstantFolding extends Rule[LogicalPlan] {
* Reorder associative integral-type operators and fold all constants into one.
*/
object ReorderAssociativeOperator extends Rule[LogicalPlan] {
private def flattenAdd(e: Expression): Seq[Expression] = e match {
case Add(l, r) => flattenAdd(l) ++ flattenAdd(r)
/**
 * Flattens a tree of nested `Add` nodes into the sequence of its operands, left to right.
 *
 * An `Add` that is contained in `groupSet` is not taken apart; it is returned as a single
 * operand. Any expression that is not an `Add` is likewise returned as-is.
 *
 * @param expression the expression to flatten
 * @param groupSet expressions that must be kept intact
 * @return the operands in left-to-right order
 */
private def flattenAdd(
    expression: Expression,
    groupSet: ExpressionSet): Seq[Expression] = expression match {
  case sum: Add if !groupSet.contains(sum) =>
    // Recurse into the left operand first, then the right, preserving operand order.
    Seq(sum.left, sum.right).flatMap(flattenAdd(_, groupSet))
  case leaf => Seq(leaf)
}
private def flattenMultiply(e: Expression): Seq[Expression] = e match {
case Multiply(l, r) => flattenMultiply(l) ++ flattenMultiply(r)
/**
 * Flattens a tree of nested `Multiply` nodes into the sequence of its factors, left to right.
 *
 * A `Multiply` that is contained in `groupSet` is not taken apart; it is returned as a single
 * factor. Any expression that is not a `Multiply` is likewise returned as-is.
 *
 * @param expression the expression to flatten
 * @param groupSet expressions that must be kept intact
 * @return the factors in left-to-right order
 */
private def flattenMultiply(
    expression: Expression,
    groupSet: ExpressionSet): Seq[Expression] = expression match {
  case product: Multiply if !groupSet.contains(product) =>
    // Recurse into the left factor first, then the right, preserving factor order.
    Seq(product.left, product.right).flatMap(flattenMultiply(_, groupSet))
  case leaf => Seq(leaf)
}
/**
 * Returns the grouping expressions of `plan` as an `ExpressionSet` when `plan` is an
 * `Aggregate`, and an empty `ExpressionSet` for every other kind of plan.
 */
private def collectGroupingExpressions(plan: LogicalPlan): ExpressionSet = plan match {
  case aggregate: Aggregate => ExpressionSet(aggregate.groupingExpressions)
  case _ => ExpressionSet(Seq())
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case q: LogicalPlan =>
// When the plan is an Aggregate operator, we have to leave intact any expression that also
// appears in the grouping expressions; otherwise the optimized expression could no longer
// be derived from the grouping expressions.
val groupingExpressionSet = collectGroupingExpressions(q)
q transformExpressionsDown {
case a: Add if a.deterministic && a.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenAdd(a).partition(_.foldable)
val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Add(x, y))
val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType)
......@@ -79,7 +96,7 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] {
a
}
case m: Multiply if m.deterministic && m.dataType.isInstanceOf[IntegralType] =>
val (foldables, others) = flattenMultiply(m).partition(_.foldable)
val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable)
if (foldables.size > 1) {
val foldableExpr = foldables.reduce((x, y) => Multiply(x, y))
val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType)
......
......@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
......@@ -60,4 +60,18 @@ class ReorderAssociativeOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
// Grouping by (t1.a + 1) and (t2.a + 1) while selecting their sum: the test expects the
// optimizer to leave the analyzed plan unchanged.
test("nested expression with aggregate operator") {
  val leftKey = "t1.a".attr + 1
  val rightKey = "t2.a".attr + 1
  val joined = testRelation.as("t1")
    .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
  val originalQuery = joined.groupBy(leftKey, rightKey)((leftKey + rightKey).as("col"))

  val analyzed = originalQuery.analyze
  val optimized = Optimize.execute(analyzed)
  // The expected plan is the analyzed plan itself.
  val correctAnswer = analyzed
  comparePlans(optimized, correctAnswer)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment