diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index be13cbc51a9d3f2f636e9e48d57f414aa3043b6a..644358493e2ebc616002ec04ae682350dce93e64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 
 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
 
+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )
 
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }
+
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)
 
       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)
 
     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
       } else {
         new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
    * multiple directories (partitions) or files (bucketing).
    */
   private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _
 
-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }
 
     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
      */
     private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)
 
-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }
 
       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
 
-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)
 
-      val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
 
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")
 
           recordsInFile = 0
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-            recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+            recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
           recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }
 
-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()