diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index a759ceb96ec1e7804894c05f4aa7056001513a3b..0d0448feb5b06d3b9c94fa085efd3424d803f132 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -260,10 +260,7 @@ final class ShuffleBlockFetcherIterator(
     fetchRequests ++= Utils.randomize(remoteRequests)
 
     // Send out initial requests for blocks, up to our maxBytesInFlight
-    while (fetchRequests.nonEmpty &&
-      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-      sendRequest(fetchRequests.dequeue())
-    }
+    fetchUpToMaxBytes()
 
     val numFetches = remoteRequests.size - fetchRequests.size
     logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
@@ -296,10 +293,7 @@ final class ShuffleBlockFetcherIterator(
       case _ =>
     }
     // Send fetch requests up to maxBytesInFlight
-    while (fetchRequests.nonEmpty &&
-      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-      sendRequest(fetchRequests.dequeue())
-    }
+    fetchUpToMaxBytes()
 
     result match {
       case FailureFetchResult(blockId, address, e) =>
@@ -315,6 +309,14 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
+  private def fetchUpToMaxBytes(): Unit = {
+    // Send fetch requests up to maxBytesInFlight
+    while (fetchRequests.nonEmpty &&
+      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+      sendRequest(fetchRequests.dequeue())
+    }
+  }
+
   private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
     blockId match {
       case ShuffleBlockId(shufId, mapId, reduceId) =>