diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
index 215a8517e8608180cbea3b5180b99e9a91140772..d686a951467cf3d599818b500d4c2e74dd93f70e 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java
@@ -21,15 +21,15 @@ import java.io.IOException;
 import java.nio.channels.WritableByteChannel;
 
 import com.google.common.base.Preconditions;
-import com.google.common.primitives.Ints;
 import io.netty.buffer.ByteBuf;
 import io.netty.channel.FileRegion;
 import io.netty.util.AbstractReferenceCounted;
 import io.netty.util.ReferenceCountUtil;
 
 /**
- * A wrapper message that holds two separate pieces (a header and a body) to avoid
- * copying the body's content.
+ * A wrapper message that holds two separate pieces (a header and a body).
+ *
+ * The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
  */
 class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
 
@@ -63,32 +63,36 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
     return totalBytesTransferred;
   }
 
+  /**
+   * This code is more complicated than you would think because we might require multiple
+   * transferTo invocations in order to transfer a single MessageWithHeader to avoid busy waiting.
+   *
+   * The contract is that the caller will ensure position is properly set to the total number
+   * of bytes transferred so far (i.e. value returned by transfered()).
+   */
   @Override
-  public long transferTo(WritableByteChannel target, long position) throws IOException {
+  public long transferTo(final WritableByteChannel target, final long position) throws IOException {
     Preconditions.checkArgument(position == totalBytesTransferred, "Invalid position.");
-    long written = 0;
-
-    if (position < headerLength) {
-      written += copyByteBuf(header, target);
+    // Bytes written for header in this call.
+    long writtenHeader = 0;
+    if (header.readableBytes() > 0) {
+      writtenHeader = copyByteBuf(header, target);
+      totalBytesTransferred += writtenHeader;
       if (header.readableBytes() > 0) {
-        totalBytesTransferred += written;
-        return written;
+        return writtenHeader;
       }
     }
 
+    // Bytes written for body in this call.
+    long writtenBody = 0;
     if (body instanceof FileRegion) {
-      // Adjust the position. If the write is happening as part of the same call where the header
-      // (or some part of it) is written, `position` will be less than the header size, so we want
-      // to start from position 0 in the FileRegion object. Otherwise, we start from the position
-      // requested by the caller.
-      long bodyPos = position > headerLength ? position - headerLength : 0;
-      written += ((FileRegion)body).transferTo(target, bodyPos);
+      writtenBody = ((FileRegion) body).transferTo(target, totalBytesTransferred - headerLength);
     } else if (body instanceof ByteBuf) {
-      written += copyByteBuf((ByteBuf) body, target);
+      writtenBody = copyByteBuf((ByteBuf) body, target);
     }
+    totalBytesTransferred += writtenBody;
 
-    totalBytesTransferred += written;
-    return written;
+    return writtenHeader + writtenBody;
   }
 
   @Override
@@ -102,5 +106,4 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
     buf.skipBytes(written);
     return written;
   }
-
 }