diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 7400a01918c5223a2bea0436a2a315bd2a6a8f2e..987cd7434b459cb15f01478ef23e5cd17f18c046 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
  *    - Join with one or two empty children (including Intersect/Except).
  * 2. Unary-node Logical Plans
  *    - Project/Filter/Sample/Join/Limit/Repartition with all empty children.
- *    - Aggregate with all empty children and without AggregateFunction expressions like COUNT.
+ *    - Aggregate with all empty children and at least one grouping expression.
  *    - Generate(Explode) with all empty children. Others like Hive UDTF may return results.
  */
 object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
@@ -39,10 +38,6 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
     case _ => false
   }
 
-  private def containsAggregateExpression(e: Expression): Boolean = {
-    e.collectFirst { case _: AggregateFunction => () }.isDefined
-  }
-
   private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty)
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
@@ -68,8 +63,13 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
       case _: LocalLimit => empty(p)
       case _: Repartition => empty(p)
       case _: RepartitionByExpression => empty(p)
-      // AggregateExpressions like COUNT(*) return their results like 0.
-      case Aggregate(_, ae, _) if !ae.exists(containsAggregateExpression) => empty(p)
+      // An aggregate with non-empty grouping expressions will return one output row per group
+      // when the input to the aggregate is not empty. If the input to the aggregate is empty then
+      // all groups will be empty and thus the output will be empty.
+      //
+      // If the grouping expressions are empty, however, then the aggregate will always produce a
+      // single output row and thus we cannot propagate the EmptyRelation.
+      case Aggregate(ge, _, _) if ge.nonEmpty => empty(p)
       // Generators like Hive-style UDTF may return their records within `close`.
       case Generate(_: Explode, _, _, _, _, _) => empty(p)
       case _ => p
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index c261a6091d476638b218244cd2315c0396da561a..38dff4733f7140124481783fde2a7cd26d04b83b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -142,7 +142,7 @@ class PropagateEmptyRelationSuite extends PlanTest {
     comparePlans(optimized, correctAnswer.analyze)
   }
 
-  test("propagate empty relation through Aggregate without aggregate function") {
+  test("propagate empty relation through Aggregate with grouping expressions") {
     val query = testRelation1
       .where(false)
       .groupBy('a)('a, ('a + 1).as('x))
@@ -153,13 +153,13 @@ class PropagateEmptyRelationSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("don't propagate empty relation through Aggregate with aggregate function") {
+  test("don't propagate empty relation through Aggregate without grouping expressions") {
     val query = testRelation1
       .where(false)
-      .groupBy('a)(count('a))
+      .groupBy()()
 
     val optimized = Optimize.execute(query.analyze)
-    val correctAnswer = LocalRelation('a.int).groupBy('a)(count('a)).analyze
+    val correctAnswer = LocalRelation('a.int).groupBy()().analyze
 
     comparePlans(optimized, correctAnswer)
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index a7994f3beaff3c59166c3ff47864c9b2e4eee4f1..1e1384549a410897bc1d9340692b98c125ccf7e3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -53,3 +53,10 @@ set spark.sql.groupByAliases=false;
 
 -- Check analysis exceptions
 SELECT a AS k, COUNT(b) FROM testData GROUP BY k;
+
+-- Aggregate with empty input and non-empty GroupBy expressions.
+SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a;
+
+-- Aggregate with empty input and empty GroupBy expressions.
+SELECT COUNT(1) FROM testData WHERE false;
+SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 6bf9dff883c1e3f0d1dd2abf47dd845b096bb0cc..42e82308ee1f058bfeccdb0c905e604f2fd120b8 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 22
+-- Number of queries: 25
 
 
 -- !query 0
@@ -203,3 +203,27 @@ struct<>
 -- !query 21 output
 org.apache.spark.sql.AnalysisException
 cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47
+
+
+-- !query 22
+SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a
+-- !query 22 schema
+struct<a:int,count(1):bigint>
+-- !query 22 output
+
+
+
+-- !query 23
+SELECT COUNT(1) FROM testData WHERE false
+-- !query 23 schema
+struct<count(1):bigint>
+-- !query 23 output
+0
+
+
+-- !query 24
+SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t
+-- !query 24 schema
+struct<1:int>
+-- !query 24 output
+1