Commit 6e632012 authored by Dongjoon Hyun, committed by Michael Armbrust

[SPARK-14830][SQL] Add RemoveRepetitionFromGroupExpressions optimizer.

## What changes were proposed in this pull request?

This PR optimizes grouping expressions by removing repeated (semantically equal) expressions from the GROUP BY list. A new optimizer rule, `RemoveRepetitionFromGroupExpressions`, is added.

**Before**
```scala
scala> sql("select a+1 from values 1,2 T(a) group by a+1, 1+a, A+1, 1+A").explain()
== Physical Plan ==
WholeStageCodegen
:  +- TungstenAggregate(key=[(a#0 + 1)#6,(1 + a#0)#7,(A#0 + 1)#8,(1 + A#0)#9], functions=[], output=[(a + 1)#5])
:     +- INPUT
+- Exchange hashpartitioning((a#0 + 1)#6, (1 + a#0)#7, (A#0 + 1)#8, (1 + A#0)#9, 200), None
   +- WholeStageCodegen
      :  +- TungstenAggregate(key=[(a#0 + 1) AS (a#0 + 1)#6,(1 + a#0) AS (1 + a#0)#7,(A#0 + 1) AS (A#0 + 1)#8,(1 + A#0) AS (1 + A#0)#9], functions=[], output=[(a#0 + 1)#6,(1 + a#0)#7,(A#0 + 1)#8,(1 + A#0)#9])
      :     +- INPUT
      +- LocalTableScan [a#0], [[1],[2]]
```

**After**
```scala
scala> sql("select a+1 from values 1,2 T(a) group by a+1, 1+a, A+1, 1+A").explain()
== Physical Plan ==
WholeStageCodegen
:  +- TungstenAggregate(key=[(a#0 + 1)#6], functions=[], output=[(a + 1)#5])
:     +- INPUT
+- Exchange hashpartitioning((a#0 + 1)#6, 200), None
   +- WholeStageCodegen
      :  +- TungstenAggregate(key=[(a#0 + 1) AS (a#0 + 1)#6], functions=[], output=[(a#0 + 1)#6])
      :     +- INPUT
      +- LocalTableScan [a#0], [[1],[2]]
```
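
The deduplication rests on Catalyst's `ExpressionSet`, which keys expressions by their canonicalized form, so commutative variants such as `a + 1` and `1 + a` count as the same grouping key. Below is a minimal sketch of that behavior, assuming a Spark Catalyst dependency on the classpath (constructor details may differ slightly across Spark versions):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

// `a` stands in for the resolved attribute a#0 from the plans above.
val a = AttributeReference("a", IntegerType)()

// (a + 1) and (1 + a) canonicalize to the same expression, so the set keeps only one.
val grouping = Seq(Add(a, Literal(1)), Add(Literal(1), a))
val deduped = ExpressionSet(grouping).toSeq  // a single (a + 1) grouping key remains
```

This is also why the new test below runs the analyzer before the optimizer: `1 + 'A` only becomes semantically equal to `'a + 1` once the attributes have been resolved case-insensitively.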

## How was this patch tested?

Pass the Jenkins tests (with a new test case).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12590 from dongjoon-hyun/SPARK-14830.
Parent commit: a35a67a8
```diff
@@ -68,7 +68,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
       ReplaceExceptWithAntiJoin,
       ReplaceDistinctWithAggregate) ::
     Batch("Aggregate", fixedPoint,
-      RemoveLiteralFromGroupExpressions) ::
+      RemoveLiteralFromGroupExpressions,
+      RemoveRepetitionFromGroupExpressions) ::
     Batch("Operator Optimizations", fixedPoint,
       // Operator push down
       SetOperationPushDown,
@@ -1439,6 +1440,18 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Removes repetition from group expressions in [[Aggregate]], as they have no effect to the result
+ * but only makes the grouping key bigger.
+ */
+object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a @ Aggregate(grouping, _, _) =>
+      val newGrouping = ExpressionSet(grouping).toSeq
+      a.copy(groupingExpressions = newGrouping)
+  }
+}
+
 /**
  * Computes the current date and time to make sure we return the same result in a single query.
  */
```
```diff
@@ -17,6 +17,9 @@
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Literal
@@ -25,10 +28,14 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
 class AggregateOptimizeSuite extends PlanTest {
+  val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false)
+  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+  val analyzer = new Analyzer(catalog, conf)
+
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Aggregate", FixedPoint(100),
-      RemoveLiteralFromGroupExpressions) :: Nil
+      RemoveLiteralFromGroupExpressions,
+      RemoveRepetitionFromGroupExpressions) :: Nil
   }
 
   test("remove literals in grouping expression") {
@@ -42,4 +49,15 @@ class AggregateOptimizeSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("remove repetition in grouping expression") {
+    val input = LocalRelation('a.int, 'b.int, 'c.int)
+    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
+    val optimized = Optimize.execute(analyzer.execute(query))
+
+    val correctAnswer = analyzer.execute(input.groupBy('a + 1, 'b + 2)(sum('c)))
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
```