From c068d9084c94cdd1aeee2c6ad6f55148bc4527ce Mon Sep 17 00:00:00 2001
From: zsxwing <zsxwing@gmail.com>
Date: Sun, 5 Oct 2014 09:55:17 -0700
Subject: [PATCH] SPARK-1656: Fix potential resource leaks

JIRA: https://issues.apache.org/jira/browse/SPARK-1656

Author: zsxwing <zsxwing@gmail.com>

Closes #577 from zsxwing/SPARK-1656 and squashes the following commits:

c431095 [zsxwing] Add a comment and fix the code style
2de96e5 [zsxwing] Make sure file will be deleted if exception happens
28b90dc [zsxwing] Update to follow the code style
4521d6e [zsxwing] Merge branch 'master' into SPARK-1656
afc3383 [zsxwing] Update to follow the code style
071fdd1 [zsxwing] SPARK-1656: Fix potential resource leaks
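
The fix is the same in all three files: keep a direct reference to the
underlying FileOutputStream and close it in a finally block, so the file
descriptor is released even if the write (or, in HttpBroadcast, the
construction of the wrapping compressed/buffered stream) throws. A minimal
sketch of the idiom; writeToFile and its arguments are illustrative, not
code from this patch:

    import java.io.{File, FileOutputStream}

    def writeToFile(file: File, data: Array[Byte]): Unit = {
      val out = new FileOutputStream(file)
      try {
        out.write(data) // may throw IOException
      } finally {
        out.close() // runs whether or not write() threw
      }
    }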

(cherry picked from commit a7c73130f1b6b0b8b19a7b0a0de5c713b673cd7b)
Signed-off-by: Andrew Or <andrewor14@gmail.com>
---
 .../spark/broadcast/HttpBroadcast.scala       | 25 +++++++++++--------
 .../master/FileSystemPersistenceEngine.scala  | 14 ++++++++---
 .../org/apache/spark/storage/DiskStore.scala  | 16 +++++++++++-
 3 files changed, 40 insertions(+), 15 deletions(-)
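
Note on the DiskStore change below: in addition to closing the stream, it
deletes the partially written block file when serialization fails, so a
failed put cannot leave a corrupt block on disk. A sketch of that shape,
using an illustrative writeBlock helper that is not part of this patch:

    import java.io.{File, FileOutputStream}

    def writeBlock(file: File)(write: FileOutputStream => Unit): Unit = {
      val out = new FileOutputStream(file)
      try {
        try {
          write(out)
        } finally {
          // Close before any delete attempt: an open handle can
          // prevent deletion on some platforms (e.g. Windows).
          out.close()
        }
      } catch {
        case e: Throwable =>
          if (file.exists()) file.delete() // drop the partial file
          throw e // propagate the original failure
      }
    }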

diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 942dc7d7ea..4cd4f4f96f 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging {
 
   private def write(id: Long, value: Any) {
     val file = getFile(id)
-    val out: OutputStream = {
-      if (compress) {
-        compressionCodec.compressedOutputStream(new FileOutputStream(file))
-      } else {
-        new BufferedOutputStream(new FileOutputStream(file), bufferSize)
+    val fileOutputStream = new FileOutputStream(file)
+    try {
+      val out: OutputStream = {
+        if (compress) {
+          compressionCodec.compressedOutputStream(fileOutputStream)
+        } else {
+          new BufferedOutputStream(fileOutputStream, bufferSize)
+        }
       }
+      val ser = SparkEnv.get.serializer.newInstance()
+      val serOut = ser.serializeStream(out)
+      serOut.writeObject(value)
+      serOut.close()
+      files += file
+    } finally {
+      fileOutputStream.close()
     }
-    val ser = SparkEnv.get.serializer.newInstance()
-    val serOut = ser.serializeStream(out)
-    serOut.writeObject(value)
-    serOut.close()
-    files += file
   }
 
   private def read[T: ClassTag](id: Long): T = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index aa85aa060d..08a99bbe68 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -83,15 +83,21 @@ private[spark] class FileSystemPersistenceEngine(
     val serialized = serializer.toBinary(value)
 
     val out = new FileOutputStream(file)
-    out.write(serialized)
-    out.close()
+    try {
+      out.write(serialized)
+    } finally {
+      out.close()
+    }
   }
 
   def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
     val fileData = new Array[Byte](file.length().asInstanceOf[Int])
     val dis = new DataInputStream(new FileInputStream(file))
-    dis.readFully(fileData)
-    dis.close()
+    try {
+      dis.readFully(fileData)
+    } finally {
+      dis.close()
+    }
 
     val clazz = m.runtimeClass.asInstanceOf[Class[T]]
     val serializer = serialization.serializerFor(clazz)
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index c83261dd91..295c706708 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
     val startTime = System.currentTimeMillis
     val file = diskManager.getFile(blockId)
     val outputStream = new FileOutputStream(file)
-    blockManager.dataSerializeStream(blockId, outputStream, values)
+    try {
+      try {
+        blockManager.dataSerializeStream(blockId, outputStream, values)
+      } finally {
+        // Close outputStream here because it should be closed before the file is deleted.
+        outputStream.close()
+      }
+    } catch {
+      case e: Throwable =>
+        if (file.exists()) {
+          file.delete()
+        }
+        throw e
+    }
+
     val length = file.length
 
     val timeTaken = System.currentTimeMillis - startTime
-- 
GitLab