/**
 * Push [[Filter]] operators through [[Aggregate]] operators. Conjuncts of the predicate that
 * reference only attributes in the grouping attribute set of the [[Aggregate]] filter whole
 * groups, so they can be evaluated before the aggregation and are pushed beneath it; the
 * remaining conjuncts stay above.
 */
object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case filter @ Filter(condition,
        aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
      // Hoisted out of the partition closure: the grouping attribute set is loop-invariant.
      val groupingAttributes = AttributeSet(groupingExpressions)
      // Non-deterministic conjuncts (e.g. ones involving rand()) must NOT be pushed below the
      // aggregate: they would be evaluated against different rows, a different number of times.
      // In particular, a non-deterministic conjunct with no attribute references would otherwise
      // slip through the subsetOf test, since the empty set is a subset of any attribute set.
      // Using `span` (rather than filtering) also keeps every conjunct AFTER the first
      // non-deterministic one in place, so predicate evaluation is never reordered across a
      // non-deterministic expression.
      val (candidates, containingNonDeterministic) =
        splitConjunctivePredicates(condition).span(_.deterministic)
      val (pushDown, rest) = candidates.partition { conjunct =>
        conjunct.references.subsetOf(groupingAttributes)
      }
      val stayUp = rest ++ containingNonDeterministic
      if (pushDown.nonEmpty) {
        val pushDownPredicate = pushDown.reduce(And)
        // Filter before aggregating, then re-apply whatever could not be pushed.
        val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild))
        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
      } else {
        filter
      }
  }
}
test("aggregate: push down filter when filter on group by expression") {
  // Filtering on the grouping column 'a selects whole groups, so the Filter can be
  // evaluated before the aggregation.
  val query = testRelation
    .groupBy('a)('a, Count('b) as 'c)
    .select('a, 'c)
    .where('a === 2)

  val expected = testRelation
    .where('a === 2)
    .groupBy('a)('a, Count('b) as 'c)
    .analyze

  comparePlans(Optimize.execute(query.analyze), expected)
}

test("aggregate: don't push down filter when filter not on group by expression") {
  // 'c is an aggregate result, not a grouping attribute, so the Filter must stay
  // above the Aggregate and the plan is unchanged.
  val query = testRelation
    .select('a, 'b)
    .groupBy('a)('a, Count('b) as 'c)
    .where('c === 2L)
  val analyzed = query.analyze

  comparePlans(Optimize.execute(analyzed), analyzed)
}

test("aggregate: push down filters partially which are subset of group by expressions") {
  // Only the conjunct on 'a references grouping attributes; the conjunct on the
  // aggregate result 'c stays above the Aggregate.
  val query = testRelation
    .select('a, 'b)
    .groupBy('a)('a, Count('b) as 'c)
    .where('c === 2L && 'a === 3)

  val expected = testRelation
    .select('a, 'b)
    .where('a === 3)
    .groupBy('a)('a, Count('b) as 'c)
    .where('c === 2L)
    .analyze

  comparePlans(Optimize.execute(query.analyze), expected)
}