diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 078cc3d5b4f0cfff05bdd6194a81fc557df9dd87..0dcf0307e113f159e6a8dba6a5a73d9f347d35fe 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -237,9 +237,11 @@ private[spark] object Utils extends Logging {
     if (bb.hasArray) {
       out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
     } else {
+      val originalPosition = bb.position()
       val bbval = new Array[Byte](bb.remaining())
       bb.get(bbval)
       out.write(bbval)
+      bb.position(originalPosition)
     }
   }
 
@@ -250,9 +252,11 @@ private[spark] object Utils extends Logging {
     if (bb.hasArray) {
       out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
     } else {
+      val originalPosition = bb.position()
       val bbval = new Array[Byte](bb.remaining())
       bb.get(bbval)
       out.write(bbval)
+      bb.position(originalPosition)
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index fb7b91222b499dde7fae207dac3b81380e981311..442a603cae7915e211ed3b992350f55159a88636 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.util
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream, PrintStream}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, DataOutputStream, File,
+  FileOutputStream, PrintStream}
 import java.lang.{Double => JDouble, Float => JFloat}
 import java.net.{BindException, ServerSocket, URI}
 import java.nio.{ByteBuffer, ByteOrder}
@@ -389,6 +390,28 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     assert(Utils.deserializeLongValue(bbuf.array) === testval)
   }
 
+  test("writeByteBuffer should not change ByteBuffer position") {
+    // Test a buffer with an underlying array, for both writeByteBuffer methods.
+    val testBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
+    assert(testBuffer.hasArray)
+    val bytesOut = new ByteBufferOutputStream(4096)
+    Utils.writeByteBuffer(testBuffer, bytesOut)
+    assert(testBuffer.position() === 0)
+
+    val dataOut = new DataOutputStream(bytesOut)
+    Utils.writeByteBuffer(testBuffer, dataOut: DataOutput)
+    assert(testBuffer.position() === 0)
+
+    // Test a buffer without an underlying array, for both writeByteBuffer methods.
+    val testDirectBuffer = ByteBuffer.allocateDirect(8)
+    assert(!testDirectBuffer.hasArray())
+    Utils.writeByteBuffer(testDirectBuffer, bytesOut)
+    assert(testDirectBuffer.position() === 0)
+
+    Utils.writeByteBuffer(testDirectBuffer, dataOut: DataOutput)
+    assert(testDirectBuffer.position() === 0)
+  }
+
   test("get iterator size") {
     val empty = Seq[Int]()
     assert(Utils.getIteratorSize(empty.toIterator) === 0L)