diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index fee7010e8e033863a70f0247311fcd11707a63b7..66e99ded24886030e7d62c5c5d45fc2ce61de753 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         // If an aggregation needs a shuffle and supports partial aggregation, a map-side partial
         // aggregation and a shuffle are added as children.
         val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
-        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+        (mergeAgg, createShuffleExchange(
+          requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
       case _ =>
         // Ensure that the operator's children satisfy their output distribution requirements:
         val childrenWithDist = operator.children.zip(requiredChildDistributions)
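
The one-line change above routes the map-side partial aggregate through
ensureDistributionAndOrdering, so a partial SortAggregate is planned with the
Sort that its required child ordering demands before the shuffle is added. A
minimal sketch for observing the fixed plan from a spark-shell, assuming a
SparkSession named `spark` and that the planner chooses sort-based aggregation
(as it does for MAX over a string column, which cannot use the mutable
hash-aggregate buffer):

  // `value` is a string, so the planner falls back from HashAggregate
  // to SortAggregate for MAX(value).
  val df = spark.range(0, 1000)
    .selectExpr("id % 2 AS key", "CAST(id AS STRING) AS value")
  df.createOrReplaceTempView("t")
  // With this patch the partial SortAggregate sits on top of its own Sort
  // below the Exchange; before, that Sort could be missing.
  spark.sql("SELECT MAX(value) FROM t GROUP BY key").explain()
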
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 07efc72bf6296cbf66d0a071752f131d384362ba..b0aa3378e5f66088bab2b888a1ef4ba485dbf720 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,12 +18,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
       s"The plan of query $query does not have partial aggregations.")
   }
 
+  test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
+    withTempView("testSortBasedPartialAggregation") {
+      val schema = StructType(
+        StructField("key", IntegerType, true) :: StructField("value", StringType, true) :: Nil)
+      val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
+      spark.createDataFrame(rowRDD, schema)
+        .createOrReplaceTempView("testSortBasedPartialAggregation")
+
+      // This test assumes the query below uses sort-based aggregation
+      val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
+        .queryExecution.executedPlan
+      // Collect all operators whose node names contain "Sort" (both Sort and SortAggregate)
+      val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
+      val aggOps = extractedOps.collect { case n: SortAggregateExec => n }
+      assert(extractedOps.size == 4 && aggOps.size == 2,
+        s"The plan $planned does not have the expected sort-based partial aggregation pairs.")
+    }
+  }
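+
+  // For reference, a sketch of the plan shape the assertion above expects
+  // (node names as in Spark 2.x EXPLAIN output; sort flags and partition
+  // counts may differ):
+  //
+  //   SortAggregate(key=[key], functions=[max(value)])                  <- final
+  //   +- Sort [key ASC], false, 0
+  //      +- Exchange hashpartitioning(key, 200)
+  //         +- SortAggregate(key=[key], functions=[partial_max(value)]) <- partial
+  //            +- Sort [key ASC], false, 0
+  //               +- Scan ...
+  //
+  // Two SortAggregate nodes plus two Sort nodes yield the four "Sort"-named
+  // operators (and the two aggregates) that the counts above check.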
+
   test("non-partial aggregation for aggregates") {
     withTempView("testNonPartialAggregation") {
       val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)