diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 236476900a5197862dbf9857b9a0e3d2e90d739e..8595762988b4b1497c704ac153258932eb8864b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -296,10 +296,13 @@ class Analyzer(
 
         val nonNullBitmask = x.bitmasks.reduce(_ & _)
 
-        val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+        val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
           a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
         }
 
+        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
+        val groupingAttrs = expand.output.drop(x.child.output.length)
+
         val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
           // collect all the found AggregateExpression, so we can check an expression is part of
           // any AggregateExpression or not.
@@ -321,15 +324,12 @@ class Analyzer(
               if (index == -1) {
                 e
               } else {
-                groupByAttributes(index)
+                groupingAttrs(index)
               }
           }.asInstanceOf[NamedExpression]
         }
 
-        Aggregate(
-          groupByAttributes :+ gid,
-          aggregations,
-          Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+        Aggregate(groupingAttrs, aggregations, expand)
 
       case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
         val groupingExprs = findGroupingExprs(child)
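With this change the Aggregate no longer groups by the pre-expansion aliases: it reuses the grouping attributes that Expand itself outputs, i.e. everything after the child's columns. A minimal sketch of the resolved plan shape for GROUP BY a, b WITH CUBE; attribute names, exprIds (#n) and bitmask values are illustrative, not taken from this patch:

    Aggregate [a#10, b#11, gid#12], [a#10, b#11, sum(c#3)]
    +- Expand [[a#1, b#2, c#3, a#5,  b#6,  0],
               [a#1, b#2, c#3, a#5,  null, 1],
               [a#1, b#2, c#3, null, b#6,  2],
               [a#1, b#2, c#3, null, null, 3]],
              output = [a#1, b#2, c#3, a#10, b#11, gid#12]
       +- Project [a#1, b#2, c#3, a#1 AS a#5, b#2 AS b#6]
          +- LocalRelation [a#1, b#2, c#3]

Note that the projections refer to the alias attributes a#5 and b#6 while Expand.output declares fresh copies a#10 and b#11; that mismatch is deliberate, and is introduced by the basicOperators.scala change below.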
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ecc2d773e7753650547f55c4cbcd2401186e5614..e6d554565d442924625aca30dc69b6e26c805f97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
     case filter @ Filter(_, f: Filter) => filter
     // should not push predicates through sample, or will generate different results.
     case filter @ Filter(_, s: Sample) => filter
-    // TODO: push predicates through expand
-    case filter @ Filter(_, e: Expand) => filter
 
     case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
       pushDownPredicate(filter, u.child) { predicate =>
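With the explicit Expand case removed, a Filter above Expand now falls through to the generic UnaryNode case that follows, which pushes a conjunct only if the child can evaluate it. A sketch of that per-predicate check, assuming the body that this hunk truncates follows the rule's usual pattern:

    pushDownPredicate(filter, u.child) { predicate =>
      // push a conjunct only when the child produces every attribute it references
      predicate.references.subsetOf(u.child.outputSet)
    }

Predicates on the pass-through columns of Expand can therefore be pushed below it, while predicates on the null-injected grouping columns reference the fresh attributes in Expand.output, which the child does not produce, and correctly stay put.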
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d4fc9e4da944aebd97539e1fc0b9fb46fda5800b..a445ce694750a57d5ee16d1481e0bd61ac80a28d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -516,7 +516,10 @@ private[sql] object Expand {
       // groupingId is the last output, here we use the bit mask as the concrete value for it.
       } :+ Literal.create(bitmask, IntegerType)
     }
-    val output = child.output ++ groupByAttrs :+ gid
+
+    // `groupByAttrs` have a different meaning in `Expand.output`: there each of them may hold
+    // either the original grouping expression or null, so we create new instances of them here.
+    val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
     Expand(projections, output, Project(child.output ++ groupByAliases, child))
   }
 }
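newInstance keeps an attribute's name, type and nullability but assigns a fresh exprId, which is what stops the grouping columns in Expand.output from being confused with the columns its child produces. A minimal illustration, assuming only the standard Catalyst expression classes:

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()  // e.g. a#1
    val a2 = a.newInstance()                        // same name and type, fresh exprId, e.g. a#2
    assert(a2.name == a.name && a2.exprId != a.exprId)

That fresh exprId is exactly what the Optimizer change above relies on: a filter on an expanded grouping column does not reference anything in the child's outputSet, so it cannot be pushed past the Expand.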
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index df7529d83f7c829955f16d733844fb1723dc487f..9174b4e649a6eae7d3a44915f172082529f50850 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("expand") {
+    val agg = testRelation
+      .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
+      .analyze
+      .asInstanceOf[Aggregate]
+
+    val a = agg.output(0)
+    val b = agg.output(1)
+
+    val query = agg.where(a > 1 && b > 2)
+    val optimized = Optimize.execute(query)
+    val correctAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
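The new test leans on fixtures FilterPushdownSuite already defines; for context, testRelation there is a three-column LocalRelation along these lines (illustrative, the suite's own definition is authoritative):

    val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

The expected plan is worth reading closely: the filter on the grouping columns is pushed below the Aggregate but stops directly above the Expand, because after this patch those columns carry fresh exprIds that Expand's child does not produce.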
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e54358e657690c1aa15465880d6111f7db1a4c0e..2d44813f0eac55968c9e2ac840b5c374e929a469 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
 
   private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
     assert(a.child == e && e.child == p)
-    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
-      sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+    a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
+      e.output.drop(p.child.output.length),
+      a.groupingExpressions.map(_.asInstanceOf[Attribute]))
   }
 
   private def groupingSetToSQL(
@@ -303,25 +304,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
 
     val numOriginalOutput = project.child.output.length
     // Assumption: Aggregate's groupingExpressions is composed of
-    // 1) the attributes of aliased group by expressions
+    // 1) the grouping attributes
     // 2) gid, which is always the last one
     val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
     // Assumption: Project's projectList is composed of
     // 1) the original output (Project's child.output),
     // 2) the aliased group by expressions.
+    val expandedAttributes = project.output.drop(numOriginalOutput)
     val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
     val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
 
     // a map from group by attributes to the original group by expressions.
     val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+    // a map from expanded attributes to the original group by expressions.
+    val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
 
     val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
       // Assumption: expand.projections is composed of
       // 1) the original output (Project's child.output),
-      // 2) group by attributes(or null literal)
+      // 2) expanded attributes (or null literals)
       // 3) gid, which is always the last one in each project in Expand
       project.drop(numOriginalOutput).dropRight(1).collect {
-        case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+        case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
       }
     }
     val groupingSetSQL = "GROUPING SETS(" +
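Taken together, SQLBuilder now has to track the same logical grouping column under two exprIds: the alias attribute that appears inside Expand.projections, and the fresh copy in Expand.output that the Aggregate groups by. A schematic of the attribute flow, with illustrative exprIds:

    Aggregate  groupingExpressions: a#10, b#11, gid#12  <- groupByAttrMap keys (gid dropped)
      Expand   output:              ..., a#10, b#11, gid#12
               projections contain: a#5 or null, b#6 or null, <bitmask>  <- expandedAttrMap keys
        Project projectList:        ..., a#1 AS a#5, b#2 AS b#6

For a query like SELECT a, b, sum(c) FROM t GROUP BY a, b WITH CUBE, the method should then reconstruct something like GROUP BY a, b GROUPING SETS((a, b), (a), (b), ()) (illustrative; the exact string depends on the remainder of groupingSetToSQL, which this hunk truncates).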