diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9495cd2835f97dbf82f1384875d0ab0e9355a5bb..0457a66af8e8947cdd71ad1eba9d8cf672bd9b23 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -293,6 +293,15 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val BUFFER_WRITE_CHUNK_SIZE =
+    ConfigBuilder("spark.buffer.write.chunkSize")
+      .internal()
+      .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.")
+      .bytesConf(ByteUnit.BYTE)
+      .checkValue(_ <= Int.MaxValue, "The chunk size when writing out the bytes of" +
+        " ChunkedByteBuffer must not be larger than Int.MaxValue.")
+      .createWithDefault(64 * 1024 * 1024)
+
   private[spark] val CHECKPOINT_COMPRESS =
     ConfigBuilder("spark.checkpoint.compress")
       .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index f48bfd5c25f77de9ade3d48f2f3f8f906aa59673..c28570fb24560102b30bcf65931b04b88b828749 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -24,6 +24,8 @@ import java.nio.channels.WritableByteChannel
 import com.google.common.primitives.UnsignedBytes
 import io.netty.buffer.{ByteBuf, Unpooled}
 
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
 import org.apache.spark.network.util.ByteArrayWritableChannel
 import org.apache.spark.storage.StorageUtils
 
@@ -40,6 +42,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
   require(chunks != null, "chunks must not be null")
   require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
 
+  // Chunk size in bytes to write at a time; falls back to the default when no SparkEnv exists (e.g. in tests).
+  private val bufferWriteChunkSize =
+    Option(SparkEnv.get).map(_.conf.get(config.BUFFER_WRITE_CHUNK_SIZE))
+      .getOrElse(config.BUFFER_WRITE_CHUNK_SIZE.defaultValue.get).toInt
+
   private[this] var disposed: Boolean = false
 
   /**
@@ -56,7 +63,16 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
   */
  def writeFully(channel: WritableByteChannel): Unit = {
    for (bytes <- getChunks()) {
-      while (bytes.remaining > 0) {
+      val originalLimit = bytes.limit()
+      while (bytes.remaining() > 0) {
+        // If `bytes` is an on-heap ByteBuffer, the Java NIO API copies it into a temporary
+        // direct ByteBuffer that is cached per thread and grows to the largest write it has
+        // seen, so write in bounded chunks to keep that cached buffer small.
+        val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
+        bytes.limit(bytes.position() + ioSize)
         channel.write(bytes)
+        // Restore the original limit so the next iteration can see the remaining bytes;
+        // otherwise the loop would stop after writing the first `ioSize` bytes.
+        bytes.limit(originalLimit)
       }
     }
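
For reference, here is a self-contained sketch of the bounded-write pattern the patch applies,
outside of Spark. The helper name and the `chunkSize` parameter are mine, standing in for
spark.buffer.write.chunkSize; this is a sketch of the idea, not the patched method itself.

```scala
import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

// Write `buf` to `channel` at most `chunkSize` bytes at a time by narrowing the
// buffer's limit, so NIO's cached per-thread direct buffer never needs to grow
// past `chunkSize` when `buf` is an on-heap buffer.
def writeInChunks(buf: ByteBuffer, channel: WritableByteChannel, chunkSize: Int): Unit = {
  val originalLimit = buf.limit()
  while (buf.remaining() > 0) {
    buf.limit(buf.position() + math.min(buf.remaining(), chunkSize))
    channel.write(buf) // position advances by however many bytes were written
    // Restore the full limit so the next pass sees the bytes beyond this chunk.
    buf.limit(originalLimit)
  }
}
```

Note that the limit must be restored inside the loop: without it, `remaining()` is measured
against the narrowed limit, and the loop would terminate after the first chunk.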