diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6557c7005d1e144bff31a0a5bd194595964904ca..0139b9e87ce8448adf408f0a2356673db361288c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -46,6 +46,7 @@ object DefaultOptimizer extends Optimizer {
       PushPredicateThroughJoin,
       PushPredicateThroughProject,
       PushPredicateThroughGenerate,
+      PushPredicateThroughAggregate,
       ColumnPruning,
       // Operator combine
       ProjectCollapsing,
@@ -674,6 +675,29 @@ object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelp
   }
 }
 
+/**
+ * Push [[Filter]] operators through [[Aggregate]] operators. Parts of the predicate that reference
+ * attributes which are subset of group by attribute set of [[Aggregate]] will be pushed beneath,
+ * and the rest should remain above.
+ */
+object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case filter @ Filter(condition,
+        aggregate @ Aggregate(groupingExpressions, aggregateExpressions, grandChild)) =>
+      val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
+        conjunct => conjunct.references subsetOf AttributeSet(groupingExpressions)
+      }
+      if (pushDown.nonEmpty) {
+        val pushDownPredicate = pushDown.reduce(And)
+        val withPushdown = aggregate.copy(child = Filter(pushDownPredicate, grandChild))
+        stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
+      } else {
+        filter
+      }
+  }
+}
+
 /**
  * Pushes down [[Filter]] operators where the `condition` can be
  * evaluated using only the attributes of the left or right side of a join.  Other
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 0f1fde2fb0f67707c561bb99a64d63d52454853c..ed810a12808f0e4984eb8a1eb417ed996932961c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -40,6 +40,7 @@ class FilterPushdownSuite extends PlanTest {
         BooleanSimplification,
         PushPredicateThroughJoin,
         PushPredicateThroughGenerate,
+        PushPredicateThroughAggregate,
         ColumnPruning,
         ProjectCollapsing) :: Nil
   }
@@ -652,4 +653,48 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer.analyze)
   }
+
+  test("aggregate: push down filter when filter on group by expression") {
+    val originalQuery = testRelation
+                        .groupBy('a)('a, Count('b) as 'c)
+                        .select('a, 'c)
+                        .where('a === 2)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+                        .where('a === 2)
+                        .groupBy('a)('a, Count('b) as 'c)
+                        .analyze
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("aggregate: don't push down filter when filter not on group by expression") {
+    val originalQuery = testRelation
+                        .select('a, 'b)
+                        .groupBy('a)('a, Count('b) as 'c)
+                        .where('c === 2L)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, originalQuery.analyze)
+  }
+
+  test("aggregate: push down filters partially which are subset of group by expressions") {
+    val originalQuery = testRelation
+                        .select('a, 'b)
+                        .groupBy('a)('a, Count('b) as 'c)
+                        .where('c === 2L && 'a === 3)
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer = testRelation
+                        .select('a, 'b)
+                        .where('a === 3)
+                        .groupBy('a)('a, Count('b) as 'c)
+                        .where('c === 2L)
+                        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }