diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index daa63d47e6aed02d72a11cb1b5d7946c65ff88bc..05fa04c44d4f5086de00924cd6b2bee8c4608e39 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -61,7 +61,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
   @VisibleForTesting
-  static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
 
   private final BlockManager blockManager;
   private final IndexShuffleBlockResolver shuffleBlockResolver;
@@ -74,6 +74,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final TaskContext taskContext;
   private final SparkConf sparkConf;
   private final boolean transferToEnabled;
+  private final int initialSortBufferSize;
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
@@ -122,6 +123,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize",
+                                                  DEFAULT_INITIAL_SORT_BUFFER_SIZE);
     open();
   }
 
@@ -187,7 +190,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       memoryManager,
       blockManager,
       taskContext,
-      INITIAL_SORT_BUFFER_SIZE,
+      initialSortBufferSize,
       partitioner.numPartitions(),
       sparkConf,
       writeMetrics);
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 7dd61f85abefd40a64597c6587c545a5164deb57..daeb4675ea5f53ed35286ac0ba26e38f46bb18fb 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -413,10 +413,10 @@ public class UnsafeShuffleWriterSuite {
   }
 
   private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
+    memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
-    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE + 1; i++) {
+    for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
       dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     writer.write(dataToWrite.iterator());
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index ad76bf5a0a81cb864f37eca703f9147e90405399..0b177ad4112eadab064fc423234f1fb9c117bbc6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -38,6 +38,7 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
 public final class UnsafeExternalRowSorter {
 
+  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
   /**
    * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
    * of records). This is only intended to be used in tests.
@@ -85,7 +86,8 @@ public final class UnsafeExternalRowSorter {
       taskContext,
       new RowComparator(ordering, schema.length()),
       prefixComparator,
-      /* initialSize */ 4096,
+      sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
+                             DEFAULT_INITIAL_SORT_BUFFER_SIZE),
       pageSizeBytes,
       canUseRadixSort
     );
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 99fe51db68aebfa5b4d8dd6e4ac874754660516d..b1cc52336321cb9f450f58f7cc2877741cbb1138 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 
 import com.google.common.annotations.VisibleForTesting;
 
+import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.SerializerManager;
@@ -86,7 +87,8 @@ public final class UnsafeKVExternalSorter {
         taskContext,
         recordComparator,
         prefixComparator,
-        /* initialSize */ 4096,
+        SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
+                                     UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
         pageSizeBytes,
         canUseRadixSort);
     } else {
@@ -131,7 +133,8 @@ public final class UnsafeKVExternalSorter {
         taskContext,
         new KVComparator(ordering, keySchema.length()),
         prefixComparator,
-        /* initialSize */ 4096,
+        SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
+                                     UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
         pageSizeBytes,
         inMemSorter);