diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 0f43e7bb88733b712af8f98ebe41b90170f2c57f..d6a39ecf53b86c70a54dbe3484dc7b163ca991a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -119,14 +119,16 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       .filter(_.isDistinct)
       .groupBy(_.aggregateFunction.children.toSet)
 
-    // Aggregation strategy can handle the query with single distinct
-    if (distinctAggGroups.size > 1) {
+    // Check if the aggregates contains functions that do not support partial aggregation.
+    val existsNonPartial = aggExpressions.exists(!_.aggregateFunction.supportsPartial)
+
+    // Aggregation strategy can handle queries with a single distinct group and partial aggregates.
+    if (distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && existsNonPartial)) {
       // Create the attributes for the grouping id and the group by clause.
-      val gid =
-        new AttributeReference("gid", IntegerType, false)(isGenerated = true)
+      val gid = AttributeReference("gid", IntegerType, nullable = false)(isGenerated = true)
       val groupByMap = a.groupingExpressions.collect {
         case ne: NamedExpression => ne -> ne.toAttribute
-        case e => e -> new AttributeReference(e.sql, e.dataType, e.nullable)()
+        case e => e -> AttributeReference(e.sql, e.dataType, e.nullable)()
       }
       val groupByAttrs = groupByMap.map(_._2)
 
@@ -135,9 +137,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       def patchAggregateFunctionChildren(
           af: AggregateFunction)(
           attrs: Expression => Expression): AggregateFunction = {
-        af.withNewChildren(af.children.map {
-          case afc => attrs(afc)
-        }).asInstanceOf[AggregateFunction]
+        af.withNewChildren(af.children.map(attrs)).asInstanceOf[AggregateFunction]
       }
 
       // Setup unique distinct aggregate children.
@@ -265,5 +265,5 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     // NamedExpression. This is done to prevent collisions between distinct and regular aggregate
     // children, in this case attribute reuse causes the input of the regular aggregate to bound to
     // the (nulled out) input of the distinct aggregate.
-    e -> new AttributeReference(e.sql, e.dataType, true)()
+    e -> AttributeReference(e.sql, e.dataType, nullable = true)()
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..0b973c3b659cf3d74b02253a276673820ada2e18
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class RewriteDistinctAggregatesSuite extends PlanTest {
+  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
+  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+  val analyzer = new Analyzer(catalog, conf)
+
+  val nullInt = Literal(null, IntegerType)
+  val nullString = Literal(null, StringType)
+  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)
+
+  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
+    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
+    case _ => fail(s"Plan is not rewritten:\n$rewrite")
+  }
+
+  test("single distinct group") {
+    val input = testRelation
+      .groupBy('a)(countDistinct('e))
+      .analyze
+    val rewrite = RewriteDistinctAggregates(input)
+    comparePlans(input, rewrite)
+  }
+
+  test("single distinct group with partial aggregates") {
+    val input = testRelation
+      .groupBy('a, 'd)(
+        countDistinct('e, 'c).as('agg1),
+        max('b).as('agg2))
+      .analyze
+    val rewrite = RewriteDistinctAggregates(input)
+    comparePlans(input, rewrite)
+  }
+
+  test("single distinct group with non-partial aggregates") {
+    val input = testRelation
+      .groupBy('a, 'd)(
+        countDistinct('e, 'c).as('agg1),
+        CollectSet('b).toAggregateExpression().as('agg2))
+      .analyze
+    checkRewrite(RewriteDistinctAggregates(input))
+  }
+
+  test("multiple distinct groups") {
+    val input = testRelation
+      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
+      .analyze
+    checkRewrite(RewriteDistinctAggregates(input))
+  }
+
+  test("multiple distinct groups with partial aggregates") {
+    val input = testRelation
+      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
+      .analyze
+    checkRewrite(RewriteDistinctAggregates(input))
+  }
+
+  test("multiple distinct groups with non-partial aggregates") {
+    val input = testRelation
+      .groupBy('a)(
+        countDistinct('b, 'c),
+        countDistinct('d),
+        CollectSet('b).toAggregateExpression())
+      .analyze
+    checkRewrite(RewriteDistinctAggregates(input))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 427390a90f1e6bce125fc3ec299dca4797103669..0e172bee4f66143c48c9cab32b025e04ee718724 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -493,4 +493,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)),
         Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5))))
   }
+
+  test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") {
+    val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d"))
+      .toDF("x", "y", "z")
+    checkAnswer(
+      df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))),
+      Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
+  }
 }