diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 52f609bc158ca7a31d6225f9c79612935af9278f..2901d8f2efddf8b80be2b9b21b4ca212cfaecebd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -59,7 +59,7 @@ object DefaultOptimizer extends Optimizer {
       ConstantFolding,
       LikeSimplification,
       BooleanSimplification,
-      RemoveDispensable,
+      RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
       SimplifyCaseConversionExpressions) ::
@@ -660,14 +660,14 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
     case filter @ Filter(condition, g: Generate) =>
       // Predicates that reference attributes produced by the `Generate` operator cannot
       // be pushed below the operator.
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
-        conjunct => conjunct.references subsetOf g.child.outputSet
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+        cond.references subsetOf g.child.outputSet
       }
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
-        val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
+        val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
           g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
-        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+        if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
       } else {
         filter
       }
@@ -675,34 +675,34 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
 }
 
 /**
- * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference
- * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath,
- * and the rest should remain above.
+ * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
+ * non-aggregate attributes (typically literals or grouping expressions).
  */
 object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case filter @ Filter(condition,
-        aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
-
-      def hasAggregate(expression: Expression): Boolean = expression match {
-        case agg: AggregateExpression => true
-        case other => expression.children.exists(hasAggregate)
-      }
-      // Create a map of Alias for expressions that does not have AggregateExpression
-      val aliasMap = AttributeMap(aggregateExpressions.collect {
-        case a: Alias if !hasAggregate(a.child) => (a.toAttribute, a.child)
+    case filter @ Filter(condition, aggregate: Aggregate) =>
+      // Find all the aliased expressions in the aggregate list that don't include any actual
+      // AggregateExpression, and create a map from the alias to the expression
+      val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
+        case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
+          (a.toAttribute, a.child)
       })
 
-      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { conjunct =>
-        val replaced = replaceAlias(conjunct, aliasMap)
-        replaced.references.subsetOf(grandChild.outputSet) && replaced.deterministic
+      // For each filter, expand the alias and check if the filter can be evaluated using
+      // attributes produced by the aggregate operator's child operator.
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+        val replaced = replaceAlias(cond, aliasMap)
+        replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic
       }
+
       if (pushDown.nonEmpty) {
         val pushDownPredicate = pushDown.reduce(And)
         val replaced = replaceAlias(pushDownPredicate, aliasMap)
-        val withPushdown = aggregate.copy(child = Filter(replaced, grandChild))
-        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+        val newAggregate = aggregate.copy(child = Filter(replaced, aggregate.child))
+        // If there is no more filter to stay up, just eliminate the filter.
+        // Otherwise, create "Filter(stayUp) <- Aggregate <- Filter(pushDownPredicate)".
+        if (stayUp.isEmpty) newAggregate else Filter(stayUp.reduce(And), newAggregate)
       } else {
         filter
       }
@@ -714,7 +714,7 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
  * evaluated using only the attributes of the left or right side of a join.  Other
  * [[Filter]] conditions are moved into the `condition` of the [[Join]].
  *
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * And also pushes down the join filter, where the `condition` can be evaluated using only the
  * attributes of the left or right side of sub query when applicable.
  *
  * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
@@ -821,7 +821,7 @@ object SimplifyCasts extends Rule[LogicalPlan] {
 /**
  * Removes nodes that are not necessary.
  */
-object RemoveDispensable extends Rule[LogicalPlan] {
+object RemoveDispensableExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
     case UnaryPositive(child) => child
     case PromotePrecision(child) => child
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0128c220baaca2c20545dbc01244375e68c398ae..fba4c5ca77d64c7597566755d9778b53b8bbc315 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -734,7 +734,7 @@ class FilterPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("aggregate: don't push down filters which is nondeterministic") {
+  test("aggregate: don't push down filters that are nondeterministic") {
     val originalQuery = testRelation
       .select('a, 'b)
       .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd"))