diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 7400a01918c5223a2bea0436a2a315bd2a6a8f2e..987cd7434b459cb15f01478ef23e5cd17f18c046 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -30,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._ * - Join with one or two empty children (including Intersect/Except). * 2. Unary-node Logical Plans * - Project/Filter/Sample/Join/Limit/Repartition with all empty children. - * - Aggregate with all empty children and without AggregateFunction expressions like COUNT. + * - Aggregate with all empty children and at least one grouping expression. * - Generate(Explode) with all empty children. Others like Hive UDTF may return results. */ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { @@ -39,10 +38,6 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { case _ => false } - private def containsAggregateExpression(e: Expression): Boolean = { - e.collectFirst { case _: AggregateFunction => () }.isDefined - } - private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty) def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -68,8 +63,13 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper { case _: LocalLimit => empty(p) case _: Repartition => empty(p) case _: RepartitionByExpression => empty(p) - // AggregateExpressions like COUNT(*) return their results like 0. - case Aggregate(_, ae, _) if !ae.exists(containsAggregateExpression) => empty(p) + // An aggregate with non-empty group expression will return one output row per group when the + // input to the aggregate is not empty. If the input to the aggregate is empty then all groups + // will be empty and thus the output will be empty. + // + // If the grouping expressions are empty, however, then the aggregate will always produce a + // single output row and thus we cannot propagate the EmptyRelation. + case Aggregate(ge, _, _) if ge.nonEmpty => empty(p) // Generators like Hive-style UDTF may return their records within `close`. case Generate(_: Explode, _, _, _, _, _) => empty(p) case _ => p diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 908dde7a6698826f9930cea7995ba58150b8cec7..2285be16938d6263e1a28aa67d4570090c92d2d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -142,7 +142,7 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer.analyze) } - test("propagate empty relation through Aggregate without aggregate function") { + test("propagate empty relation through Aggregate with grouping expressions") { val query = testRelation1 .where(false) .groupBy('a)('a, ('a + 1).as('x)) @@ -153,13 +153,13 @@ class PropagateEmptyRelationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("don't propagate empty relation through Aggregate with aggregate function") { + test("don't propagate empty relation through Aggregate without grouping expressions") { val query = testRelation1 .where(false) - .groupBy('a)(count('a)) + .groupBy()() val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int).groupBy('a)(count('a)).analyze + val correctAnswer = LocalRelation('a.int).groupBy()().analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed43153004a65f45f0e72767c9b1558e35657..f4f5a043d4781ddbe6c716e253595126c5cb69c7 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,10 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aggregate with empty input and non-empty GroupBy expressions. +SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a; + +-- Aggregate with empty input and empty GroupBy expressions. +SELECT COUNT(1) FROM testData WHERE false; +SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0eb2dade287736e35ed138d3d3619..f64dd0007846ac219f830de2ac83207162e52f55 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 18 -- !query 0 @@ -139,3 +139,27 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct<count(DISTINCT b):bigint,count(DISTINCT b, c):bigint> -- !query 14 output 1 1 + + +-- !query 15 +SELECT a, COUNT(1) FROM testData WHERE false GROUP BY a +-- !query 15 schema +struct<a:int,count(1):bigint> +-- !query 15 output + + + +-- !query 16 +SELECT COUNT(1) FROM testData WHERE false +-- !query 16 schema +struct<count(1):bigint> +-- !query 16 output +0 + + +-- !query 17 +SELECT 1 FROM (SELECT COUNT(1) FROM testData WHERE false) t +-- !query 17 schema +struct<1:int> +-- !query 17 output +1