diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b5fa372643bd42f7e0e59be1557abfa74c590bfa..268d7f21e6f85ecb9ebcdba7081725b5c05653f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -298,12 +298,10 @@ class Analyzer(
           }.asInstanceOf[NamedExpression]
         }
 
-        val child = Project(x.child.output ++ groupByAliases, x.child)
-
         Aggregate(
           groupByAttributes :+ VirtualColumn.groupingIdAttribute,
           aggregations,
-          Expand(x.bitmasks, groupByAttributes, gid, child))
+          Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 411594c95166c3d016f43f4b7fb87b0c622d034a..3bc246a32dec88063522cc605a3d32346058fb3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -449,21 +449,21 @@ private[sql] object Expand {
    * Extract attribute set according to the grouping id.
    *
    * @param bitmask bitmask to represent the selected of the attribute sequence
-   * @param exprs the attributes in sequence
+   * @param attrs the attributes in sequence
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
-  private def buildNonSelectExprSet(
+  private def buildNonSelectAttrSet(
       bitmask: Int,
-      exprs: Seq[Expression]): ArrayBuffer[Expression] = {
-    val set = new ArrayBuffer[Expression](2)
+      attrs: Seq[Attribute]): AttributeSet = {
+    val nonSelect = new ArrayBuffer[Attribute]()
 
-    var bit = exprs.length - 1
+    var bit = attrs.length - 1
     while (bit >= 0) {
-      if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
+      if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1)
       bit -= 1
     }
 
-    set
+    AttributeSet(nonSelect)
   }
 
   /**
@@ -471,13 +471,15 @@ private[sql] object Expand {
    * multiple output rows for a input row.
    *
    * @param bitmasks The bitmask set represents the grouping sets
-   * @param groupByExprs The grouping by expressions
+   * @param groupByAliases The aliased original group by expressions
+   * @param groupByAttrs The attributes of aliased group by expressions
    * @param gid Attribute of the grouping id
    * @param child Child operator
    */
   def apply(
     bitmasks: Seq[Int],
-    groupByExprs: Seq[Expression],
+    groupByAliases: Seq[Alias],
+    groupByAttrs: Seq[Attribute],
     gid: Attribute,
     child: LogicalPlan): Expand = {
     // Create an array of Projections for the child projection, and replace the projections'
@@ -485,27 +487,21 @@ private[sql] object Expand {
     // are not set for this grouping set (according to the bit mask).
     val projections = bitmasks.map { bitmask =>
       // get the non selected grouping attributes according to the bit mask
-      val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
+      val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)
 
-      (child.output :+ gid).map(expr => expr transformDown {
-        // TODO this causes a problem when a column is used both for grouping and aggregation.
-        case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
+      child.output ++ groupByAttrs.map { attr =>
+        if (nonSelectedGroupAttrSet.contains(attr)) {
           // if the input attribute in the Invalid Grouping Expression set of for this group
           // replace it with constant null
-          Literal.create(null, expr.dataType)
-        case x if x == gid =>
-          // replace the groupingId with concrete value (the bit mask)
-          Literal.create(bitmask, IntegerType)
-      })
-    }
-    val output = child.output.map { attr =>
-      if (groupByExprs.exists(_.semanticEquals(attr))) {
-        attr.withNullability(true)
-      } else {
-        attr
-      }
+          Literal.create(null, attr.dataType)
+        } else {
+          attr
+        }
+      // groupingId is the last output, here we use the bit mask as the concrete value for it.
+      } :+ Literal.create(bitmask, IntegerType)
     }
-    Expand(projections, output :+ gid, child)
+    val output = child.output ++ groupByAttrs :+ gid
+    Expand(projections, output, Project(child.output ++ groupByAliases, child))
   }
 }