diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8dbec408002f11fd72a1ad458a334696f698f8c9..dd68d60d3e8396db1efccfe05171e2b15c989e37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -217,11 +217,9 @@ class Analyzer(
      *  Group Count: N + 1 (N is the number of group expressions)
      *
      *  We need to get all of its subsets for the rule described above, the subset is
-     *  represented as the bit masks.
+     *  represented as sequence of expressions.
      */
-    def bitmasks(r: Rollup): Seq[Int] = {
-      Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1)
-    }
+    def rollupExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.inits.toSeq
 
     /*
      *  GROUP BY a, b, c WITH CUBE
@@ -230,10 +228,14 @@ class Analyzer(
      *  Group Count: 2 ^ N (N is the number of group expressions)
      *
      *  We need to get all of its subsets for a given GROUPBY expression, the subsets are
-     *  represented as the bit masks.
+     *  represented as sequence of expressions.
      */
-    def bitmasks(c: Cube): Seq[Int] = {
-      Seq.tabulate(1 << c.groupByExprs.length)(i => i)
+    def cubeExprs(exprs: Seq[Expression]): Seq[Seq[Expression]] = exprs.toList match {
+      case x :: xs =>
+        val initial = cubeExprs(xs)
+        initial.map(x +: _) ++ initial
+      case Nil =>
+        Seq(Seq.empty)
     }
 
     private def hasGroupingAttribute(expr: Expression): Boolean = {
@@ -256,17 +258,17 @@ class Analyzer(
       expr transform {
         case e: GroupingID =>
           if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
-            gid
+            Alias(gid, toPrettySQL(e))()
           } else {
             throw new AnalysisException(
               s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
                 s"grouping columns (${groupByExprs.mkString(",")})")
           }
-        case Grouping(col: Expression) =>
+        case e @ Grouping(col: Expression) =>
           val idx = groupByExprs.indexOf(col)
           if (idx >= 0) {
-            Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
-              Literal(1)), ByteType)
+            Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
+              Literal(1)), ByteType), toPrettySQL(e))()
           } else {
             throw new AnalysisException(s"Column of grouping ($col) can't be found " +
               s"in grouping columns ${groupByExprs.mkString(",")}")
@@ -274,85 +276,107 @@ class Analyzer(
       }
     }
 
-    // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-      case a if !a.childrenResolved => a // be sure all of the children are resolved.
-      case p if p.expressions.exists(hasGroupingAttribute) =>
-        failAnalysis(
-          s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
-
-      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
-        GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
-      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
-        GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
+    /*
+     * Create new alias for all group by expressions for `Expand` operator.
+     */
+    private def constructGroupByAlias(groupByExprs: Seq[Expression]): Seq[Alias] = {
+      groupByExprs.map {
+        case e: NamedExpression => Alias(e, e.name)()
+        case other => Alias(other, other.toString)()
+      }
+    }
 
-      // Ensure all the expressions have been resolved.
-      case x: GroupingSets if x.expressions.forall(_.resolved) =>
-        val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-
-        // Expand works by setting grouping expressions to null as determined by the bitmasks. To
-        // prevent these null values from being used in an aggregate instead of the original value
-        // we need to create new aliases for all group by expressions that will only be used for
-        // the intended purpose.
-        val groupByAliases: Seq[Alias] = x.groupByExprs.map {
-          case e: NamedExpression => Alias(e, e.name)()
-          case other => Alias(other, other.toString)()
+    /*
+     * Construct [[Expand]] operator with grouping sets.
+     */
+    private def constructExpand(
+        selectedGroupByExprs: Seq[Seq[Expression]],
+        child: LogicalPlan,
+        groupByAliases: Seq[Alias],
+        gid: Attribute): LogicalPlan = {
+      // Change the nullability of group by aliases if necessary. For example, if we have
+      // GROUPING SETS ((a,b), a), we do not need to change the nullability of a, but we
+      // should change the nullabilty of b to be TRUE.
+      // TODO: For Cube/Rollup just set nullability to be `true`.
+      val expandedAttributes = groupByAliases.map { alias =>
+        if (selectedGroupByExprs.exists(!_.contains(alias.child))) {
+          alias.toAttribute.withNullability(true)
+        } else {
+          alias.toAttribute
         }
+      }
 
-        // The rightmost bit in the bitmasks corresponds to the last expression in groupByAliases
-        // with 0 indicating this expression is in the grouping set. The following line of code
-        // calculates the bitmask representing the expressions that absent in at least one grouping
-        // set (indicated by 1).
-        val nullBitmask = x.bitmasks.reduce(_ | _)
-
-        val attrLength = groupByAliases.length
-        val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
-          a.toAttribute.withNullability(((nullBitmask >> (attrLength - idx - 1)) & 1) == 1)
+      val groupingSetsAttributes = selectedGroupByExprs.map { groupingSetExprs =>
+        groupingSetExprs.map { expr =>
+          val alias = groupByAliases.find(_.child.semanticEquals(expr)).getOrElse(
+            failAnalysis(s"$expr doesn't show up in the GROUP BY list $groupByAliases"))
+          // Map alias to expanded attribute.
+          expandedAttributes.find(_.semanticEquals(alias.toAttribute)).getOrElse(
+            alias.toAttribute)
         }
+      }
 
-        val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
-        val groupingAttrs = expand.output.drop(x.child.output.length)
+      Expand(groupingSetsAttributes, groupByAliases, expandedAttributes, gid, child)
+    }
 
-        val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
-          // collect all the found AggregateExpression, so we can check an expression is part of
-          // any AggregateExpression or not.
-          val aggsBuffer = ArrayBuffer[Expression]()
-          // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
-          def isPartOfAggregation(e: Expression): Boolean = {
-            aggsBuffer.exists(a => a.find(_ eq e).isDefined)
+    /*
+     * Construct new aggregate expressions by replacing grouping functions.
+     */
+    private def constructAggregateExprs(
+        groupByExprs: Seq[Expression],
+        aggregations: Seq[NamedExpression],
+        groupByAliases: Seq[Alias],
+        groupingAttrs: Seq[Expression],
+        gid: Attribute): Seq[NamedExpression] = aggregations.map {
+      // collect all the found AggregateExpression, so we can check an expression is part of
+      // any AggregateExpression or not.
+      val aggsBuffer = ArrayBuffer[Expression]()
+      // Returns whether the expression belongs to any expressions in `aggsBuffer` or not.
+      def isPartOfAggregation(e: Expression): Boolean = {
+        aggsBuffer.exists(a => a.find(_ eq e).isDefined)
+      }
+      replaceGroupingFunc(_, groupByExprs, gid).transformDown {
+        // AggregateExpression should be computed on the unmodified value of its argument
+        // expressions, so we should not replace any references to grouping expression
+        // inside it.
+        case e: AggregateExpression =>
+          aggsBuffer += e
+          e
+        case e if isPartOfAggregation(e) => e
+        case e =>
+          // Replace expression by expand output attribute.
+          val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
+          if (index == -1) {
+            e
+          } else {
+            groupingAttrs(index)
           }
-          replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
-            // AggregateExpression should be computed on the unmodified value of its argument
-            // expressions, so we should not replace any references to grouping expression
-            // inside it.
-            case e: AggregateExpression =>
-              aggsBuffer += e
-              e
-            case e if isPartOfAggregation(e) => e
-            case e =>
-              val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
-              if (index == -1) {
-                e
-              } else {
-                groupingAttrs(index)
-              }
-          }.asInstanceOf[NamedExpression]
-        }
+      }.asInstanceOf[NamedExpression]
+    }
 
-        Aggregate(groupingAttrs, aggregations, expand)
+    /*
+     * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
+     */
+    private def constructAggregate(
+        selectedGroupByExprs: Seq[Seq[Expression]],
+        groupByExprs: Seq[Expression],
+        aggregationExprs: Seq[NamedExpression],
+        child: LogicalPlan): LogicalPlan = {
+      val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
 
-      case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
-        val groupingExprs = findGroupingExprs(child)
-        // The unresolved grouping id will be resolved by ResolveMissingReferences
-        val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
-        f.copy(condition = newCond)
+      // Expand works by setting grouping expressions to null as determined by the
+      // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
+      // instead of the original value we need to create new aliases for all group by expressions
+      // that will only be used for the intended purpose.
+      val groupByAliases = constructGroupByAlias(groupByExprs)
 
-      case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
-        val groupingExprs = findGroupingExprs(child)
-        val gid = VirtualColumn.groupingIdAttribute
-        // The unresolved grouping id will be resolved by ResolveMissingReferences
-        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
-        s.copy(order = newOrder)
+      val expand = constructExpand(selectedGroupByExprs, child, groupByAliases, gid)
+      val groupingAttrs = expand.output.drop(child.output.length)
+
+      val aggregations = constructAggregateExprs(
+        groupByExprs, aggregationExprs, groupByAliases, groupingAttrs, gid)
+
+      Aggregate(groupingAttrs, aggregations, expand)
     }
 
     private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
@@ -369,6 +393,41 @@ class Analyzer(
         failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
       }
     }
+
+    // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+      case a if !a.childrenResolved => a // be sure all of the children are resolved.
+      case p if p.expressions.exists(hasGroupingAttribute) =>
+        failAnalysis(
+          s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
+
+      // Ensure group by expressions and aggregate expressions have been resolved.
+      case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child)
+        if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+        constructAggregate(cubeExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
+      case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child)
+        if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+        constructAggregate(rollupExprs(groupByExprs), groupByExprs, aggregateExpressions, child)
+      // Ensure all the expressions have been resolved.
+      case x: GroupingSets if x.expressions.forall(_.resolved) =>
+        constructAggregate(x.selectedGroupByExprs, x.groupByExprs, x.aggregations, x.child)
+
+      // We should make sure all expressions in condition have been resolved.
+      case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
+        val groupingExprs = findGroupingExprs(child)
+        // The unresolved grouping id will be resolved by ResolveMissingReferences
+        val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
+        f.copy(condition = newCond)
+
+      // We should make sure all [[SortOrder]]s have been resolved.
+      case s @ Sort(order, _, child)
+        if order.exists(hasGroupingFunction) && order.forall(_.resolved) =>
+        val groupingExprs = findGroupingExprs(child)
+        val gid = VirtualColumn.groupingIdAttribute
+        // The unresolved grouping id will be resolved by ResolveMissingReferences
+        val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
+        s.copy(order = newOrder)
+    }
   }
 
   object ResolvePivot extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4b151c81d8f8b1bfd0369a999ca790fee15778a7..2c4db0d2c34257c3775a2c9e535cec064a842149 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -492,33 +492,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       ctx: AggregationContext,
       selectExpressions: Seq[NamedExpression],
       query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
-    import ctx._
-    val groupByExpressions = expressionList(groupingExpressions)
+    val groupByExpressions = expressionList(ctx.groupingExpressions)
 
-    if (GROUPING != null) {
+    if (ctx.GROUPING != null) {
       // GROUP BY .... GROUPING SETS (...)
-      val expressionMap = groupByExpressions.zipWithIndex.toMap
-      val numExpressions = expressionMap.size
-      val mask = (1 << numExpressions) - 1
-      val masks = ctx.groupingSet.asScala.map {
-        _.expression.asScala.foldLeft(mask) {
-          case (bitmap, eCtx) =>
-            // Find the index of the expression.
-            val e = typedVisit[Expression](eCtx)
-            val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
-              throw new ParseException(
-                s"$e doesn't show up in the GROUP BY list", ctx))
-            // 0 means that the column at the given index is a grouping column, 1 means it is not,
-            // so we unset the bit in bitmap.
-            bitmap & ~(1 << (numExpressions - 1 - index))
-        }
-      }
-      GroupingSets(masks, groupByExpressions, query, selectExpressions)
+      val selectedGroupByExprs =
+        ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)))
+      GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions)
     } else {
       // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
-      val mappedGroupByExpressions = if (CUBE != null) {
+      val mappedGroupByExpressions = if (ctx.CUBE != null) {
         Seq(Cube(groupByExpressions))
-      } else if (ROLLUP != null) {
+      } else if (ctx.ROLLUP != null) {
         Seq(Rollup(groupByExpressions))
       } else {
         groupByExpressions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 65ceab2ce27b154177a574904e5968f44c0e1ad0..dcae7b026f58c75ef690b0ee69efccebabc6bc83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import scala.collection.mutable.ArrayBuffer
-
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes
@@ -523,51 +522,56 @@ case class Window(
 
 object Expand {
   /**
-   * Extract attribute set according to the grouping id.
+   * Build bit mask from attributes of selected grouping set. A bit in the bitmask is corresponding
+   * to an attribute in group by attributes sequence, the selected attribute has corresponding bit
+   * set to 0 and otherwise set to 1. For example, if we have GroupBy attributes (a, b, c, d), the
+   * bitmask 5(whose binary form is 0101) represents grouping set (a, c).
    *
-   * @param bitmask bitmask to represent the selected of the attribute sequence
-   * @param attrs the attributes in sequence
-   * @return the attributes of non selected specified via bitmask (with the bit set to 1)
+   * @param groupingSetAttrs The attributes of selected grouping set
+   * @param attrMap Mapping group by attributes to its index in attributes sequence
+   * @return The bitmask which represents the selected attributes out of group by attributes.
    */
-  private def buildNonSelectAttrSet(
-      bitmask: Int,
-      attrs: Seq[Attribute]): AttributeSet = {
-    val nonSelect = new ArrayBuffer[Attribute]()
-
-    var bit = attrs.length - 1
-    while (bit >= 0) {
-      if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1)
-      bit -= 1
-    }
-
-    AttributeSet(nonSelect)
+  private def buildBitmask(
+    groupingSetAttrs: Seq[Attribute],
+    attrMap: Map[Attribute, Int]): Int = {
+    val numAttributes = attrMap.size
+    val mask = (1 << numAttributes) - 1
+    // Calculate the attrbute masks of selected grouping set. For example, if we have GroupBy
+    // attributes (a, b, c, d), grouping set (a, c) will produce the following sequence:
+    // (15, 7, 13), whose binary form is (1111, 0111, 1101)
+    val masks = (mask +: groupingSetAttrs.map(attrMap).map(index =>
+      // 0 means that the column at the given index is a grouping column, 1 means it is not,
+      // so we unset the bit in bitmap.
+      ~(1 << (numAttributes - 1 - index))
+    ))
+    // Reduce masks to generate an bitmask for the selected grouping set.
+    masks.reduce(_ & _)
   }
 
   /**
    * Apply the all of the GroupExpressions to every input row, hence we will get
    * multiple output rows for an input row.
    *
-   * @param bitmasks The bitmask set represents the grouping sets
+   * @param groupingSetsAttrs The attributes of grouping sets
    * @param groupByAliases The aliased original group by expressions
    * @param groupByAttrs The attributes of aliased group by expressions
    * @param gid Attribute of the grouping id
    * @param child Child operator
    */
   def apply(
-    bitmasks: Seq[Int],
+    groupingSetsAttrs: Seq[Seq[Attribute]],
     groupByAliases: Seq[Alias],
     groupByAttrs: Seq[Attribute],
     gid: Attribute,
     child: LogicalPlan): Expand = {
+    val attrMap = groupByAttrs.zipWithIndex.toMap
+
     // Create an array of Projections for the child projection, and replace the projections'
     // expressions which equal GroupBy expressions with Literal(null), if those expressions
-    // are not set for this grouping set (according to the bit mask).
-    val projections = bitmasks.map { bitmask =>
-      // get the non selected grouping attributes according to the bit mask
-      val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)
-
+    // are not set for this grouping set.
+    val projections = groupingSetsAttrs.map { groupingSetAttrs =>
       child.output ++ groupByAttrs.map { attr =>
-        if (nonSelectedGroupAttrSet.contains(attr)) {
+        if (!groupingSetAttrs.contains(attr)) {
           // if the input attribute in the Invalid Grouping Expression set of for this group
           // replace it with constant null
           Literal.create(null, attr.dataType)
@@ -575,7 +579,7 @@ object Expand {
           attr
         }
       // groupingId is the last output, here we use the bit mask as the concrete value for it.
-      } :+ Literal.create(bitmask, IntegerType)
+      } :+ Literal.create(buildBitmask(groupingSetAttrs, attrMap), IntegerType)
     }
 
     // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original
@@ -616,16 +620,15 @@ case class Expand(
  *
  * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
  *
- * @param bitmasks     A list of bitmasks, each of the bitmask indicates the selected
- *                     GroupBy expressions
- * @param groupByExprs The Group By expressions candidates, take effective only if the
- *                     associated bit in the bitmask set to 1.
+ * @param selectedGroupByExprs A sequence of selected GroupBy expressions, all exprs should
+ *                     exists in groupByExprs.
+ * @param groupByExprs The Group By expressions candidates.
  * @param child        Child operator
  * @param aggregations The Aggregation expressions, those non selected group by expressions
  *                     will be considered as constant null if it appears in the expressions
  */
 case class GroupingSets(
-    bitmasks: Seq[Int],
+    selectedGroupByExprs: Seq[Seq[Expression]],
     groupByExprs: Seq[Expression],
     child: LogicalPlan,
     aggregations: Seq[NamedExpression]) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2a0205bdc90fe043c21a11c33452a384d807d390
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+
+class ResolveGroupingAnalyticsSuite extends AnalysisTest {
+
+  lazy val a = 'a.int
+  lazy val b = 'b.string
+  lazy val c = 'c.string
+  lazy val unresolved_a = UnresolvedAttribute("a")
+  lazy val unresolved_b = UnresolvedAttribute("b")
+  lazy val unresolved_c = UnresolvedAttribute("c")
+  lazy val gid = 'spark_grouping_id.int.withNullability(false)
+  lazy val hive_gid = 'grouping__id.int.withNullability(false)
+  lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType)
+  lazy val nulInt = Literal(null, IntegerType)
+  lazy val nulStr = Literal(null, StringType)
+  lazy val r1 = LocalRelation(a, b, c)
+
+  test("rollupExprs") {
+    val testRollup = (exprs: Seq[Expression], rollup: Seq[Seq[Expression]]) => {
+      val result = SimpleAnalyzer.ResolveGroupingAnalytics.rollupExprs(exprs)
+      assert(result.sortBy(_.hashCode) == rollup.sortBy(_.hashCode))
+    }
+
+    testRollup(Seq(a, b, c), Seq(Seq(), Seq(a), Seq(a, b), Seq(a, b, c)))
+    testRollup(Seq(c, b, a), Seq(Seq(), Seq(c), Seq(c, b), Seq(c, b, a)))
+    testRollup(Seq(a), Seq(Seq(), Seq(a)))
+    testRollup(Seq(), Seq(Seq()))
+  }
+
+  test("cubeExprs") {
+    val testCube = (exprs: Seq[Expression], cube: Seq[Seq[Expression]]) => {
+      val result = SimpleAnalyzer.ResolveGroupingAnalytics.cubeExprs(exprs)
+      assert(result.sortBy(_.hashCode) == cube.sortBy(_.hashCode))
+    }
+
+    testCube(Seq(a, b, c),
+      Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(a, b), Seq(a, c), Seq(b, c), Seq(a, b, c)))
+    testCube(Seq(c, b, a),
+      Seq(Seq(), Seq(a), Seq(b), Seq(c), Seq(c, b), Seq(c, a), Seq(b, a), Seq(c, b, a)))
+    testCube(Seq(a), Seq(Seq(), Seq(a)))
+    testCube(Seq(), Seq(Seq()))
+  }
+
+  test("grouping sets") {
+    val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+      Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
+    val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+      Expand(
+        Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    val originalPlan2 = GroupingSets(Seq(), Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
+    val expected2 = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+      Expand(
+        Seq(),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan2, expected2)
+
+    val originalPlan3 = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b),
+      Seq(unresolved_c)), Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))))
+    assertAnalysisError(originalPlan3, Seq("doesn't show up in the GROUP BY list"))
+  }
+
+  test("cube") {
+    val originalPlan = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)
+    val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+          Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    val originalPlan2 = Aggregate(Seq(Cube(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
+    val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
+      Expand(
+        Seq(Seq(a, b, c, 0)),
+        Seq(a, b, c, gid),
+        Project(Seq(a, b, c), r1)))
+    checkAnalysis(originalPlan2, expected2)
+  }
+
+  test("rollup") {
+    val originalPlan = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c))), r1)
+    val expected = Aggregate(Seq(a, b, gid), Seq(a, b, count(c).as("count(c)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    val originalPlan2 = Aggregate(Seq(Rollup(Seq())), Seq(UnresolvedAlias(count(unresolved_c))), r1)
+    val expected2 = Aggregate(Seq(gid), Seq(count(c).as("count(c)")),
+      Expand(
+        Seq(Seq(a, b, c, 0)),
+        Seq(a, b, c, gid),
+        Project(Seq(a, b, c), r1)))
+    checkAnalysis(originalPlan2, expected2)
+  }
+
+  test("grouping function") {
+    // GrouingSets
+    val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+      Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(Grouping(unresolved_a))))
+    val expected = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+      Expand(
+        Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    // Cube
+    val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(Grouping(unresolved_a))), r1)
+    val expected2 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+          Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan2, expected2)
+
+    // Rollup
+    val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(Grouping(unresolved_a))), r1)
+    val expected3 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), grouping_a.as("grouping(a)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan3, expected3)
+  }
+
+  test("grouping_id") {
+    // GrouingSets
+    val originalPlan = GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+      Seq(unresolved_a, unresolved_b), r1,
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))))
+    val expected = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+      Expand(
+        Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan, expected)
+
+    // Cube
+    val originalPlan2 = Aggregate(Seq(Cube(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1)
+    val expected2 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1),
+          Seq(a, b, c, nulInt, b, 2), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan2, expected2)
+
+    // Rollup
+    val originalPlan3 = Aggregate(Seq(Rollup(Seq(unresolved_a, unresolved_b))),
+      Seq(unresolved_a, unresolved_b, UnresolvedAlias(count(unresolved_c)),
+        UnresolvedAlias(GroupingID(Seq(unresolved_a, unresolved_b)))), r1)
+    val expected3 = Aggregate(Seq(a, b, gid),
+      Seq(a, b, count(c).as("count(c)"), gid.as("grouping_id(a, b)")),
+      Expand(
+        Seq(Seq(a, b, c, a, b, 0), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, nulInt, nulStr, 3)),
+        Seq(a, b, c, a, b, gid),
+        Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))
+    checkAnalysis(originalPlan3, expected3)
+  }
+
+  test("filter with grouping function") {
+    // Filter with Grouping function
+    val originalPlan = Filter(Grouping(unresolved_a) === 0,
+      GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+        Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+    val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType) === 0,
+      Aggregate(Seq(a, b, gid),
+        Seq(a, b, gid),
+        Expand(
+          Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+          Seq(a, b, c, a, b, gid),
+          Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+    checkAnalysis(originalPlan, expected)
+
+    val originalPlan2 = Filter(Grouping(unresolved_a) === 0,
+      Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1))
+    assertAnalysisError(originalPlan2,
+      Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+
+    // Filter with GroupingID
+    val originalPlan3 = Filter(GroupingID(Seq(unresolved_a, unresolved_b)) === 1,
+      GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+        Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+    val expected3 = Project(Seq(a, b), Filter(gid === 1,
+      Aggregate(Seq(a, b, gid),
+        Seq(a, b, gid),
+        Expand(
+          Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+          Seq(a, b, c, a, b, gid),
+          Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+    checkAnalysis(originalPlan3, expected3)
+
+    val originalPlan4 = Filter(GroupingID(Seq(unresolved_a)) === 1,
+      Aggregate(Seq(unresolved_a), Seq(UnresolvedAlias(count(unresolved_b))), r1))
+    assertAnalysisError(originalPlan4,
+      Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+  }
+
+  test("sort with grouping function") {
+    // Sort with Grouping function
+    val originalPlan = Sort(
+      Seq(SortOrder(Grouping(unresolved_a), Ascending)), true,
+      GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+        Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+    val expected = Project(Seq(a, b), Sort(
+      Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true,
+      Aggregate(Seq(a, b, gid),
+        Seq(a, b, grouping_a.as("aggOrder")),
+        Expand(
+          Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+          Seq(a, b, c, a, b, gid),
+          Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+    checkAnalysis(originalPlan, expected)
+
+    val originalPlan2 = Sort(Seq(SortOrder(Grouping(unresolved_a), Ascending)), true,
+      Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1))
+    assertAnalysisError(originalPlan2,
+      Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+
+    // Sort with GroupingID
+    val originalPlan3 = Sort(
+      Seq(SortOrder(GroupingID(Seq(unresolved_a, unresolved_b)), Ascending)), true,
+      GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)),
+        Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b)))
+    val expected3 = Project(Seq(a, b), Sort(
+      Seq(SortOrder('aggOrder.int.withNullability(false), Ascending)), true,
+      Aggregate(Seq(a, b, gid),
+        Seq(a, b, gid.as("aggOrder")),
+        Expand(
+          Seq(Seq(a, b, c, nulInt, nulStr, 3), Seq(a, b, c, a, nulStr, 1), Seq(a, b, c, a, b, 0)),
+          Seq(a, b, c, a, b, gid),
+          Project(Seq(a, b, c, a.as("a"), b.as("b")), r1)))))
+    checkAnalysis(originalPlan3, expected3)
+
+    val originalPlan4 = Sort(
+      Seq(SortOrder(GroupingID(Seq(unresolved_a)), Ascending)), true,
+      Aggregate(Seq(unresolved_a), Seq(unresolved_a, UnresolvedAlias(count(unresolved_b))), r1))
+    assertAnalysisError(originalPlan4,
+      Seq("grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup"))
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7400f3430e99c7c6d6986f1016b3e95d65467380..5f0f6ee479c69750a6e87e1830c46a19f524364e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -233,9 +233,8 @@ class PlanParserSuite extends PlanTest {
 
     // Grouping Sets
     assertEqual(s"$sql grouping sets((a, b), (a), ())",
-      GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
-    intercept(s"$sql grouping sets((a, b), (c), ())",
-      "c doesn't show up in the GROUP BY list")
+      GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"),
+        Seq('a, 'b, 'sum.function('c).as("c"))))
   }
 
   test("limit") {