diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index c33d1e33f030f3e7b1558fd65083706818d58f38..338faaadb33d485724db701d917e66a5f7805c2a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -43,6 +43,7 @@ import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.util.Utils;
+import org.apache.spark.internal.config.package$;
 
 /**
  * An external sorter that is specialized for sort-based shuffle.
@@ -82,6 +83,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   /** The buffer size to use when writing spills using DiskBlockObjectWriter */
   private final int fileBufferSizeBytes;
 
+  /** The buffer size to use when writing the sorted records to an on-disk file */
+  private final int diskWriteBufferSize;
+
   /**
    * Memory pages that hold the records being sorted. The pages in this list are freed when
    * spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -116,13 +120,14 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     this.taskContext = taskContext;
     this.numPartitions = numPartitions;
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
-    this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.fileBufferSizeBytes = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.numElementsForSpillThreshold =
       conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024);
     this.writeMetrics = writeMetrics;
     this.inMemSorter = new ShuffleInMemorySorter(
       this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true));
     this.peakMemoryUsedBytes = getMemoryUsage();
+    this.diskWriteBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
   }
 
   /**
@@ -155,7 +160,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
     // data through a byte array. This array does not need to be large enough to hold a single
     // record;
-    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+    final byte[] writeBuffer = new byte[diskWriteBufferSize];
 
     // Because this output will be read during shuffle, its compression codec must be controlled by
     // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
@@ -195,7 +200,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
       long recordReadPosition = recordOffsetInPage + 4; // skip over record length
       while (dataRemaining > 0) {
-        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+        final int toTransfer = Math.min(diskWriteBufferSize, dataRemaining);
         Platform.copyMemory(
           recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
         writer.write(writeBuffer, 0, toTransfer);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 34c179990214f6922d5aa2211d5ec493094092f0..1b578491b81d7a05c92c550fedc40c5d8ffcd5be 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -55,6 +55,7 @@ import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.Utils;
+import org.apache.spark.internal.config.package$;
 
 @Private
 public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -65,6 +66,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @VisibleForTesting
   static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
+  static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
 
   private final BlockManager blockManager;
   private final IndexShuffleBlockResolver shuffleBlockResolver;
@@ -78,6 +80,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final SparkConf sparkConf;
   private final boolean transferToEnabled;
   private final int initialSortBufferSize;
+  private final int inputBufferSizeInBytes;
+  private final int outputBufferSizeInBytes;
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
@@ -140,6 +144,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
     this.initialSortBufferSize = sparkConf.getInt("spark.shuffle.sort.initialBufferSize",
                                                   DEFAULT_INITIAL_SORT_BUFFER_SIZE);
+    this.inputBufferSizeInBytes =
+      (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
+    this.outputBufferSizeInBytes =
+      (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
     open();
   }
 
@@ -209,7 +217,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       partitioner.numPartitions(),
       sparkConf,
       writeMetrics);
-    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+    serBuffer = new MyByteArrayOutputStream(DEFAULT_INITIAL_SER_BUFFER_SIZE);
     serOutputStream = serializer.serializeStream(serBuffer);
   }
 
@@ -360,12 +368,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
     final OutputStream bos = new BufferedOutputStream(
             new FileOutputStream(outputFile),
-            (int) sparkConf.getSizeAsKb("spark.shuffle.unsafe.file.output.buffer", "32k") * 1024);
+            outputBufferSizeInBytes);
     // Use a counting output stream to avoid having to close the underlying file and ask
     // the file system for its size after each partition is written.
     final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos);
-    final int inputBufferSizeInBytes =
-      (int) sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
 
     boolean threwException = true;
     try {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 164b9d70b79d711ad2b13a0f41d87d595529ffcd..f9b5493755443c4034eee49fe08750b7fbc3f7b1 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -20,9 +20,10 @@ package org.apache.spark.util.collection.unsafe.sort;
 import java.io.File;
 import java.io.IOException;
 
-import org.apache.spark.serializer.SerializerManager;
 import scala.Tuple2;
 
+import org.apache.spark.SparkConf;
+import org.apache.spark.serializer.SerializerManager;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.storage.BlockId;
@@ -30,6 +31,7 @@ import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.Platform;
+import org.apache.spark.internal.config.package$;
 
 /**
  * Spills a list of sorted records to disk. Spill files have the following format:
@@ -38,12 +40,16 @@ import org.apache.spark.unsafe.Platform;
  */
 public final class UnsafeSorterSpillWriter {
 
-  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+  private final SparkConf conf = new SparkConf();
+
+  /** The buffer size to use when writing the sorted records to an on-disk file */
+  private final int diskWriteBufferSize =
+    (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
 
   // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
   // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
   // data through a byte array.
-  private byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+  private byte[] writeBuffer = new byte[diskWriteBufferSize];
 
   private final File file;
   private final BlockId blockId;
@@ -114,7 +120,7 @@ public final class UnsafeSorterSpillWriter {
     writeIntToBuffer(recordLength, 0);
     writeLongToBuffer(keyPrefix, 4);
     int dataRemaining = recordLength;
-    int freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE - 4 - 8; // space used by prefix + len
+    int freeSpaceInWriteBuffer = diskWriteBufferSize - 4 - 8; // space used by prefix + len
     long recordReadPosition = baseOffset;
     while (dataRemaining > 0) {
       final int toTransfer = Math.min(freeSpaceInWriteBuffer, dataRemaining);
@@ -122,15 +128,15 @@ public final class UnsafeSorterSpillWriter {
         baseObject,
         recordReadPosition,
         writeBuffer,
-        Platform.BYTE_ARRAY_OFFSET + (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer),
+        Platform.BYTE_ARRAY_OFFSET + (diskWriteBufferSize - freeSpaceInWriteBuffer),
         toTransfer);
-      writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer) + toTransfer);
+      writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer) + toTransfer);
       recordReadPosition += toTransfer;
       dataRemaining -= toTransfer;
-      freeSpaceInWriteBuffer = DISK_WRITE_BUFFER_SIZE;
+      freeSpaceInWriteBuffer = diskWriteBufferSize;
     }
-    if (freeSpaceInWriteBuffer < DISK_WRITE_BUFFER_SIZE) {
-      writer.write(writeBuffer, 0, (DISK_WRITE_BUFFER_SIZE - freeSpaceInWriteBuffer));
+    if (freeSpaceInWriteBuffer < diskWriteBufferSize) {
+      writer.write(writeBuffer, 0, (diskWriteBufferSize - freeSpaceInWriteBuffer));
     }
     writer.recordWritten();
   }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 8dee0d970c4c6ed439eb0bd78b865341b42114b8..a629810bf093a3f6bcc7e2f237e7546a723a146b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -336,4 +336,31 @@ package object config {
         "spark.")
       .booleanConf
       .createWithDefault(false)
+
+  private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
+    ConfigBuilder("spark.shuffle.file.buffer")
+      .doc("Size of the in-memory buffer for each shuffle file output stream. " +
+        "These buffers reduce the number of disk seeks and system calls made " +
+        "in creating intermediate shuffle files.")
+      .bytesConf(ByteUnit.KiB)
+      .checkValue(v => v > 0 && v <= Int.MaxValue / 1024,
+        s"The file buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.")
+      .createWithDefaultString("32k")
+
+  private[spark] val SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE =
+    ConfigBuilder("spark.shuffle.unsafe.file.output.buffer")
+      .doc("The file system for this buffer size after each partition " +
+        "is written in unsafe shuffle writer.")
+      .bytesConf(ByteUnit.KiB)
+      .checkValue(v => v > 0 && v <= Int.MaxValue / 1024,
+        s"The buffer size must be greater than 0 and less than ${Int.MaxValue / 1024}.")
+      .createWithDefaultString("32k")
+
+  private[spark] val SHUFFLE_DISK_WRITE_BUFFER_SIZE =
+    ConfigBuilder("spark.shuffle.spill.diskWriteBufferSize")
+      .doc("The buffer size to use when writing the sorted records to an on-disk file.")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(v => v > 0 && v <= Int.MaxValue,
+        s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.")
+      .createWithDefault(1024 * 1024)
 }