diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 0f43e7bb88733b712af8f98ebe41b90170f2c57f..d6a39ecf53b86c70a54dbe3484dc7b163ca991a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -119,14 +119,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - // Aggregation strategy can handle the query with single distinct - if (distinctAggGroups.size > 1) { + // Check if the aggregates contains functions that do not support partial aggregation. + val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial) + + // Aggregation strategy can handle queries with a single distinct group and partial aggregates. + if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) { // Create the attributes for the grouping id and the group by clause. - val gid = - new AttributeReference("gid", IntegerType, false)(isGenerated = true) + val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true) val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)() + case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -135,9 +137,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { def patchAggregateFunctionChildren( af: AggregateFunction)( attrs: Expression => Expression): AggregateFunction = { - af.withNewChildren(af.children.map { - case afc => attrs(afc) - }).asInstanceOf[AggregateFunction] + af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction] } // Setup unique distinct aggregate children. @@ -265,5 +265,5 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.sql, e.dataType, true)() + e -> AttributeReference(e.sql, e.dataType, nullable = true)() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0b973c3b659cf3d74b02253a276673820ada2e18 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.types.{IntegerType, StringType} + +class RewriteDistinctAggregatesSuite extends PlanTest { + val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + val nullInt = Literal(null, IntegerType) + val nullString = Literal(null, StringType) + val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) + + private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { + case Aggregate(_, _, Aggregate(_, _, _: Expand)) => + case _ => fail(s"Plan is not rewritten:\n$rewrite") + } + + test("single distinct group") { + val input = testRelation + .groupBy('a)(countDistinct('e)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + max('b).as('agg2)) + .analyze + val rewrite = RewriteDistinctAggregates(input) + comparePlans(input, rewrite) + } + + test("single distinct group with non-partial aggregates") { + val input = testRelation + .groupBy('a, 'd)( + countDistinct('e, 'c).as('agg1), + CollectSet('b).toAggregateExpression().as('agg2)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with partial aggregates") { + val input = testRelation + .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } + + test("multiple distinct groups with non-partial aggregates") { + val input = testRelation + .groupBy('a)( + countDistinct('b, 'c), + countDistinct('d), + CollectSet('b).toAggregateExpression()) + .analyze + checkRewrite(RewriteDistinctAggregates(input)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 427390a90f1e6bce125fc3ec299dca4797103669..0e172bee4f66143c48c9cab32b025e04ee718724 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -493,4 +493,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)), Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5)))) } + + test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { + val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")) + .toDF("x", "y", "z") + checkAnswer( + df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), + Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) + } }