diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 52f609bc158ca7a31d6225f9c79612935af9278f..2901d8f2efddf8b80be2b9b21b4ca212cfaecebd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer {
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
-      RemoveDispensable,
+      RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
       SimplifyCaseConversionExpressions) ::
@@ -660,14 +660,14 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
     case filter @ Filter(condition, g: Generate) =>
       // Predicates that reference attributes produced by the `Generate` operator cannot
       // be pushed below the operator.
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
-        conjunct => conjunct.references subsetOf g.child.outputSet
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+        cond.references subsetOf g.child.outputSet
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
-        val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
+        val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
           g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
-        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+        if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
       } else {
         filter
       }
@@ -675,34 +675,34 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
 }
 
 /**
- * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference
- * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath,
- * and the rest should remain above.
+ * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
+ * non-aggregate attributes (typically literals or grouping expressions).
  */
 object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case filter @ Filter(condition,
-        aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
-
-      def hasAggregate(expression: Expression): Boolean = expression match {
-        case agg: AggregateExpression => true
-        case other => expression.children.exists(hasAggregate)
-      }
-      // Create a map of Alias for expressions that does not have AggregateExpression
-      val aliasMap = AttributeMap(aggregateExpressions.collect {
-        case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child)
+    case filter @ Filter(condition, aggregate: Aggregate) =>
+      // Find all the aliased expressions in the aggregate list that don't include any actual
+      // AggregateExpression, and create a map from the alias to the expression
+      val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
+        case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+          (a.toAttribute, a.child)
       })
 
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
-        val replaced = replaceAlias(conjunct, aliasMap)
-        replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
+      // For each filter, expand the alias and check if the filter can be evaluated using
+      // attributes produced by the aggregate operator's child operator.
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+        val replaced = replaceAlias(cond, aliasMap)
+        replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic
       }
+
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
         val replaced = replaceAlias(pushDownPredicate, aliasMap)
-        val withPushdown = aggregate.copy(child = Filter(replaced, grandChild))
-        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+        val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
+        // If there is no more filter to stay up, just eliminate the filter.
+        // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
+        if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
       } else {
         filter
       }
@@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
  * evaluated using only the attributes of the left or right side of a join. Other
  * [[Filter]] conditions are moved into the `condition` of the [[Join]].
  *
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * And also pushes down the join filter, where the `condition` can be evaluated using only the
 * attributes of the left or right side of sub query when applicable.
  *
  * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
@@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] {
 /**
  * Removes nodes that are not necessary.
  */
-object RemoveDispensable extends Rule[LogicalPlan] {
+object RemoveDispensableExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case UnaryPositive(child) => child
     case PromotePrecision(child) => child
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0128c220baaca2c20545dbc01244375e68c398ae..fba4c5ca77d64c7597566755d9778b53b8bbc315 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -734,7 +734,7 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("aggregate: don't push down filters which is nondeterministic") {
+  test("aggregate: don't push down filters that are nondeterministic") {
     val originalQuery = testRelation
       .select('a, 'b)
       .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))
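
A note for readers tracing the two pushdown hunks above: both PushPredicateThroughGenerate and PushPredicateThroughAggregate rest on the same move, which is to split the filter condition on its top-level ANDs and then partition the conjuncts by whether they can be evaluated from the child operator's output. The following is a minimal, self-contained Scala sketch of that test as the aggregate rule uses it; the Expr ADT and the helper names (Attr, Lit, Gt, Conj, splitConjuncts, expandAliases, partitionForPushdown) are hypothetical stand-ins for Catalyst's Expression machinery, not the real API.

object PushdownSketch extends App {
  // Toy expression tree (hypothetical; Catalyst's Expression hierarchy is far richer).
  sealed trait Expr { def references: Set[String] }
  case class Attr(name: String) extends Expr { def references: Set[String] = Set(name) }
  case class Lit(value: Int) extends Expr { def references: Set[String] = Set.empty }
  case class Gt(left: Expr, right: Expr) extends Expr {
    def references: Set[String] = left.references ++ right.references
  }
  // Logical AND of two predicates.
  case class Conj(left: Expr, right: Expr) extends Expr {
    def references: Set[String] = left.references ++ right.references
  }

  // Split a predicate on its top-level ANDs, like splitConjunctivePredicates.
  def splitConjuncts(e: Expr): Seq[Expr] = e match {
    case Conj(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
    case other      => Seq(other)
  }

  // Expand aliases introduced by the aggregate list (e.g. `a AS aa`) so the
  // conjunct is expressed over the aggregate child's attributes, like replaceAlias.
  def expandAliases(e: Expr, aliasMap: Map[String, Expr]): Expr = e match {
    case a @ Attr(n) => aliasMap.getOrElse(n, a)
    case Gt(l, r)    => Gt(expandAliases(l, aliasMap), expandAliases(r, aliasMap))
    case Conj(l, r)  => Conj(expandAliases(l, aliasMap), expandAliases(r, aliasMap))
    case other       => other
  }

  // A conjunct may be pushed beneath the aggregate iff, after alias expansion,
  // it references only attributes produced by the aggregate's child.
  def partitionForPushdown(
      condition: Expr,
      aliasMap: Map[String, Expr],
      childOutput: Set[String]): (Seq[Expr], Seq[Expr]) =
    splitConjuncts(condition).partition { cond =>
      expandAliases(cond, aliasMap).references.subsetOf(childOutput)
    }

  // Example: SELECT a AS aa, count(b) AS c ... GROUP BY a HAVING aa > 1 AND c > 2.
  // `aa` aliases the grouping column `a`, so its conjunct can be pushed down;
  // `c` wraps an aggregate function, is absent from the alias map, and must stay up.
  val aliasMap = Map("aa" -> Attr("a"))
  val condition = Conj(Gt(Attr("aa"), Lit(1)), Gt(Attr("c"), Lit(2)))
  val (pushDown, stayUp) = partitionForPushdown(condition, aliasMap, Set("a", "b"))
  println(s"pushDown = $pushDown, stayUp = $stayUp")
}

The real rule additionally requires the alias-expanded conjunct to be deterministic before pushing it, which is exactly what the renamed test at the bottom of the diff exercises: a filter on 'a + Rand(10) as 'aa must stay above the Aggregate even though it references only the grouping column.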
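The RemoveDispensable to RemoveDispensableExpressions rename touches a rule whose body is a one-liner per case: a transformAllExpressions pass that strips wrapper nodes with no runtime effect (UnaryPositive, PromotePrecision). As a rough model of that kind of rewrite, here is a self-contained sketch, assuming a toy Node ADT and a hand-rolled post-order transformUp in place of Catalyst's TreeNode machinery; none of these names are the real API.

object RemoveDispensableSketch extends App {
  sealed trait Node { def children: Seq[Node]; def withChildren(cs: Seq[Node]): Node }
  case class Leaf(name: String) extends Node {
    def children: Seq[Node] = Nil
    def withChildren(cs: Seq[Node]): Node = this
  }
  // A wrapper with no runtime effect, analogous to UnaryPositive: +x is just x.
  case class UnaryPlus(child: Node) extends Node {
    def children: Seq[Node] = Seq(child)
    def withChildren(cs: Seq[Node]): Node = UnaryPlus(cs.head)
  }
  case class Add(left: Node, right: Node) extends Node {
    def children: Seq[Node] = Seq(left, right)
    def withChildren(cs: Seq[Node]): Node = Add(cs(0), cs(1))
  }

  // Post-order traversal: rewrite the children first, then offer the rewritten
  // node to the rule; nodes the rule does not match pass through unchanged.
  def transformUp(node: Node)(rule: PartialFunction[Node, Node]): Node = {
    val rewritten = node.withChildren(node.children.map(transformUp(_)(rule)))
    rule.applyOrElse(rewritten, (n: Node) => n)
  }

  // The dispensable-node rule: drop the no-op wrapper, keep its child.
  val removeDispensable: PartialFunction[Node, Node] = {
    case UnaryPlus(child) => child
  }

  // +(a + +(b)) collapses to a + b; rewriting children first means nested
  // wrappers disappear in a single pass.
  val expr = UnaryPlus(Add(Leaf("a"), UnaryPlus(Leaf("b"))))
  assert(transformUp(expr)(removeDispensable) == Add(Leaf("a"), Leaf("b")))
  println(transformUp(expr)(removeDispensable))
}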