diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
index 922c37a10efdd7a5f78e99db544b06ddcd31bfdb..e79eef03258976600c40716d318022ae6a216eb4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
@@ -48,11 +48,27 @@ import com.google.common.base.Preconditions;
  * use this functionality in both a Guava 11 environment and a Guava >14 environment.
  */
 public final class LimitedInputStream extends FilterInputStream {
+  private final boolean closeWrappedStream;
   private long left;
   private long mark = -1;
 
   public LimitedInputStream(InputStream in, long limit) {
+    this(in, limit, true);
+  }
+
+  /**
+   * Create a LimitedInputStream that will read {@code limit} bytes from {@code in}.
+   * <p>
+   * If {@code closeWrappedStream} is true, closing this stream will also close {@code in}.
+   * Otherwise, {@code in} is left open so that its remaining content can still be read.
+   *
+   * @param in an {@link InputStream} to read from
+   * @param limit the maximum number of bytes to read
+   * @param closeWrappedStream whether to close {@code in} when {@link #close} is called
+   */
+  public LimitedInputStream(InputStream in, long limit, boolean closeWrappedStream) {
     super(in);
+    this.closeWrappedStream = closeWrappedStream;
     Preconditions.checkNotNull(in);
     Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
     left = limit;
@@ -102,4 +118,11 @@ public final class LimitedInputStream extends FilterInputStream {
     left -= skipped;
     return skipped;
   }
+
+  @Override
+  public void close() throws IOException {
+    if (closeWrappedStream) {
+      super.close();
+    }
+  }
 }
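
A minimal sketch (not part of the patch) of why the new `closeWrappedStream` flag matters: with `false`, closing the wrapper leaves the underlying stream open, so a caller can carve several length-delimited segments out of one stream and close each wrapper independently. The input bytes here are illustrative only.

```scala
import java.io.ByteArrayInputStream
import org.apache.spark.network.util.LimitedInputStream

val underlying = new ByteArrayInputStream("abcdef".getBytes("UTF-8"))

// First 3-byte segment; closing the wrapper must NOT close `underlying`.
val first = new LimitedInputStream(underlying, 3L, false)
try {
  val buf = new Array[Byte](3)
  first.read(buf)            // reads "abc"
} finally {
  first.close()              // no-op on `underlying`
}

// `underlying` is still open, so the next segment can be read.
val second = new LimitedInputStream(underlying, 3L, false)
try {
  val buf = new Array[Byte](3)
  second.read(buf)           // reads "def"
} finally {
  second.close()
}

underlying.close()           // the caller closes the wrapped stream itself
```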
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 05fa04c44d4f5086de00924cd6b2bee8c4608e39..08fb887bbd095ae9952f2d3e6ab93df287d5d284 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -349,12 +349,19 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream =
-              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
-            if (compressionCodec != null) {
-              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            InputStream partitionInputStream = null;
+            boolean innerThrewException = true;
+            try {
+              partitionInputStream =
+                  new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+              if (compressionCodec != null) {
+                partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+              }
+              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+              innerThrewException = false;
+            } finally {
+              Closeables.close(partitionInputStream, innerThrewException);
             }
-            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
           }
         }
         mergedFileOutputStream.flush();
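
The pattern above relies on Guava's `Closeables.close(closeable, swallowIOException)`: when the copy has already thrown, an `IOException` raised by `close()` is logged and swallowed rather than masking the original failure. A self-contained sketch of the same guard (the helper name is hypothetical, not part of the patch):

```scala
import java.io.{InputStream, OutputStream}
import com.google.common.io.{ByteStreams, Closeables}

// Hypothetical helper illustrating the guard used in the merge loop above:
// close on both success and failure, but only let close() failures
// propagate when the body itself succeeded.
def copyClosingSafely(open: () => InputStream, sink: OutputStream): Unit = {
  var in: InputStream = null          // Closeables.close tolerates null
  var threwException = true
  try {
    in = open()
    ByteStreams.copy(in, sink)
    threwException = false
  } finally {
    Closeables.close(in, threwException)
  }
}
```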
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 632b0ae9c2c37e12084cb0b7b0260b675a0757e8..e8d6d587b4824173b7d18da1751aa06269a9cffc 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -232,7 +232,11 @@ private object TorrentBroadcast extends Logging {
     val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
     val ser = serializer.newInstance()
     val serOut = ser.serializeStream(out)
-    serOut.writeObject[T](obj).close()
+    Utils.tryWithSafeFinally {
+      serOut.writeObject[T](obj)
+    } {
+      serOut.close()
+    }
     cbbos.toChunkedByteBuffer.getChunks()
   }
 
@@ -246,8 +250,11 @@ private object TorrentBroadcast extends Logging {
     val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
     val ser = serializer.newInstance()
     val serIn = ser.deserializeStream(in)
-    val obj = serIn.readObject[T]()
-    serIn.close()
+    val obj = Utils.tryWithSafeFinally {
+      serIn.readObject[T]()
+    } {
+      serIn.close()
+    }
     obj
   }
 
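`Utils.tryWithSafeFinally` is Spark's internal (`private[spark]`) analogue of try-with-resources: it returns the body's value and, if both the body and the finally block throw, attaches the finally block's exception as suppressed instead of letting it mask the root cause. Roughly, simplified from the real implementation (which also logs the suppressed exception):

```scala
// Simplified sketch of Utils.tryWithSafeFinally's contract.
def tryWithSafeFinallySketch[T](block: => T)(finallyBlock: => Unit): T = {
  var original: Throwable = null
  try {
    block
  } catch {
    case t: Throwable =>
      original = t
      throw t
  } finally {
    try {
      finallyBlock
    } catch {
      case t: Throwable if original != null && original != t =>
        // The body already failed: record the close() failure as
        // suppressed so the root cause still propagates. If the body
        // succeeded, the guard is false and the finally exception
        // propagates normally.
        original.addSuppressed(t)
    }
  }
}
```
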
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index d17a7894fd8a8d288106d534244b321480241535..f0ed41f6903f4ad74005eb206298aa7d4341a856 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.Utils
 
 /**
  * Custom serializer used for generic Avro records. If the user registers the schemas
@@ -72,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
   def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
     val bos = new ByteArrayOutputStream()
     val out = codec.compressedOutputStream(bos)
-    out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
-    out.close()
+    Utils.tryWithSafeFinally {
+      out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
+    } {
+      out.close()
+    }
     bos.toByteArray
   })
 
@@ -86,7 +90,12 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
       schemaBytes.array(),
       schemaBytes.arrayOffset() + schemaBytes.position(),
       schemaBytes.remaining())
-    val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
+    val in = codec.compressedInputStream(bis)
+    val bytes = Utils.tryWithSafeFinally {
+      IOUtils.toByteArray(in)
+    } {
+      in.close()
+    }
     new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
   })
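
Taken together, `compress` and `decompress` are inverses. A hedged roundtrip sketch using the same close-safely pattern, assuming code running inside Spark (both `Utils` and the `CompressionCodec` factory are `private[spark]`) and a default codec rather than the serializer's configured one:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets

import org.apache.commons.io.IOUtils
import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.Utils

val codec = CompressionCodec.createCodec(new SparkConf())
val schemaJson = """{"type":"string"}"""

// Compress, closing the codec stream exactly once even on failure.
val bos = new ByteArrayOutputStream()
val out = codec.compressedOutputStream(bos)
Utils.tryWithSafeFinally {
  out.write(schemaJson.getBytes(StandardCharsets.UTF_8))
} {
  out.close()
}

// Decompress the same bytes back.
val in = codec.compressedInputStream(new ByteArrayInputStream(bos.toByteArray))
val roundTripped = Utils.tryWithSafeFinally {
  IOUtils.toByteArray(in)
} {
  in.close()
}
assert(new String(roundTripped, StandardCharsets.UTF_8) == schemaJson)
```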