diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index fb975ee5e7296078afa4d39d9f355a72a518df01..4e8fc892f3eeae48cadcfea9af5d076bb057a5b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -85,8 +85,9 @@ class CheckAnalysis {
 
             cleaned.foreach(checkValidAggregateExpression)
 
-          case o if o.children.nonEmpty && o.missingInput.nonEmpty =>
-            val missingAttributes = o.missingInput.map(_.prettyString).mkString(",")
+          case o if o.children.nonEmpty &&
+            !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
+            val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
             val input = o.inputSet.map(_.prettyString).mkString(",")
 
             failAnalysis(s"resolved attributes $missingAttributes missing from $input")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index a94b2d2095d1209e41ea0b0bf420859ac65110ce..384fe53a683621b8119ae601dda6a6185aa7c3e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -191,8 +191,6 @@ case class Expand(
     val sizeInBytes = child.statistics.sizeInBytes * projections.length
     Statistics(sizeInBytes = sizeInBytes)
   }
-
-  override def missingInput = super.missingInput.filter(_.name != VirtualColumn.groupingIdName)
 }
 
 trait GroupingAnalytics extends UnaryNode {