From b5a59a0fe2ea703ce2712561e7b9044f772660a2 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Wed, 2 Mar 2016 14:35:44 -0800
Subject: [PATCH] [SPARK-13601] call failure callbacks before writer.close()

## What changes were proposed in this pull request?

In order to tell the OutputStream whether or not the task has failed, we should call the failure callbacks BEFORE calling writer.close().
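
Not part of this patch: the standalone sketch below (a hypothetical `CallbackOrderingSketch`, not Spark's `TaskContextImpl`) mimics the semantics of the new `Utils.tryWithSafeFinallyAndFailureCallbacks` helper so the intended ordering is easy to see — registered failure callbacks run before the `finally` block (where `writer.close()` goes), and an exception thrown by the `finally` block is attached as suppressed instead of masking the original write failure.

```scala
import scala.collection.mutable.ArrayBuffer

object CallbackOrderingSketch {
  // Stand-in for TaskContext failure listeners (illustrative only).
  private val failureCallbacks = ArrayBuffer.empty[Throwable => Unit]

  def addTaskFailureListener(f: Throwable => Unit): Unit = failureCallbacks += f

  // Mimics Utils.tryWithSafeFinallyAndFailureCallbacks: run the failure
  // callbacks before the finally block, and never let the finally block's
  // exception mask the original one.
  def tryWithFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        original = t
        failureCallbacks.reverse.foreach(cb => cb(t)) // callbacks run BEFORE finallyBlock
        throw t
    } finally {
      try {
        finallyBlock // e.g. writer.close()
      } catch {
        case t: Throwable if original != null =>
          original.addSuppressed(t) // keep the original, more meaningful failure
          throw original
      }
    }
  }

  def main(args: Array[String]): Unit = {
    addTaskFailureListener(e => println(s"failure callback saw: ${e.getMessage}"))
    try {
      tryWithFailureCallbacks { throw new RuntimeException("failed to write") } {
        println("close() runs after the failure callback")
      }
    } catch {
      case e: Throwable => println(s"rethrown: ${e.getMessage}")
    }
  }
}
```

Running the sketch prints the callback message before the close message and then rethrows the original write failure, which is exactly the observable ordering the new tests assert.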

## How was this patch tested?

Added new unit tests.

Author: Davies Liu <davies@databricks.com>

Closes #11450 from davies/callback.
---
 .../org/apache/spark/TaskContextImpl.scala    | 10 +-
 .../apache/spark/rdd/PairRDDFunctions.scala   |  4 +-
 .../scala/org/apache/spark/util/Utils.scala   | 39 +++++++-
 .../spark/rdd/PairRDDFunctionsSuite.scala     | 91 +++++++++++++++++-
 .../datasources/WriterContainer.scala         | 95 ++++++++++---------
 .../CommitFailureTestRelationSuite.scala      | 57 +++++++++++
 .../sql/sources/SimpleTextRelation.scala      | 28 +++++-
 7 files changed, 271 insertions(+), 53 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 65f6f741f7..7e96040bc4 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
   // Whether the task has completed.
   @volatile private var completed: Boolean = false
 
+  // Whether the task has failed.
+  @volatile private var failed: Boolean = false
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
     this
   }
 
-  /** Marks the task as completed and triggers the failure listeners. */
+  /** Marks the task as failed and triggers the failure listeners. */
   private[spark] def markTaskFailed(error: Throwable): Unit = {
+    // failure callbacks should only be called once
+    if (failed) return
+    failed = true
     val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
+    // Process failure callbacks in the reverse order of registration
     onFailureCallbacks.reverse.foreach { listener =>
       try {
         listener.onTaskFailure(this, error)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e00b9f6cfd..91460dc406 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
       require(writer != null, "Unable to obtain RecordWriter")
       var recordsWritten = 0L
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val pair = iter.next()
           writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       writer.open()
       var recordsWritten = 0L
 
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val record = iter.next()
           writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6103a10ccc..cfe247c668 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
    * exception from the original `out.write` call.
    */
   def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
-    // It would be nice to find a method on Try that did this
     var originalThrowable: Throwable = null
     try {
       block
@@ -1267,6 +1266,44 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Execute a block of code, calling the failure callbacks before the finally block if any
+   * exception happens. But if an exception happens in the finally block, do not suppress the
+   * original exception.
+   *
+   * This is primarily an issue with `finally { out.close() }` blocks, where
+   * close needs to be called to clean up `out`, but if an exception happened
+   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+   * fail as well. This would then suppress the original/likely more meaningful
+   * exception from the original `out.write` call.
+   */
+  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+    var originalThrowable: Throwable = null
+    try {
+      block
+    } catch {
+      case t: Throwable =>
+        // Purposefully not using NonFatal, because we do not want our
+        // finallyBlock to suppress even fatal exceptions
+        originalThrowable = t
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+        throw originalThrowable
+    } finally {
+      try {
+        finallyBlock
+      } catch {
+        case t: Throwable =>
+          if (originalThrowable != null) {
+            originalThrowable.addSuppressed(t)
+            logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
+            throw originalThrowable
+          } else {
+            throw t
+          }
+      }
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def sparkInternalExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the internal Spark API's
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 16e2d2e636..7d51538d92 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.rdd
 
+import java.io.IOException
+
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
   RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
 import org.apache.hadoop.util.Progressable
 
-import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.Partitioner
 import org.apache.spark.util.Utils
 
 class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
     assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
   }
 
+  test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+
+    FakeWriterWithCallback.calledBy = ""
+    FakeWriterWithCallback.exception = null
+    val e = intercept[SparkException] {
+      pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
+    }
+    assert(e.getMessage contains "failed to write")
+
+    assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+    assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+    assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+  }
+
+  test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+    val conf = new JobConf()
+
+    FakeWriterWithCallback.calledBy = ""
+    FakeWriterWithCallback.exception = null
+    val e = intercept[SparkException] {
+      pairs.saveAsHadoopFile(
+        "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
+    }
+    assert(e.getMessage contains "failed to write")
+
+    assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+    assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+    assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+  }
+
   test("lookup") {
     val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))
 
@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
   }
 }
 
+object FakeWriterWithCallback {
+  var calledBy: String = ""
+  var exception: Throwable = _
+
+  def onFailure(ctx: TaskContext, e: Throwable): Unit = {
+    calledBy += "callback,"
+    exception = e
+  }
+}
+
+class FakeWriterWithCallback extends FakeWriter {
+
+  override def close(p1: Reporter): Unit = {
+    FakeWriterWithCallback.calledBy += "close"
+  }
+
+  override def write(p1: Integer, p2: Integer): Unit = {
+    FakeWriterWithCallback.calledBy += "write,"
+    TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+      FakeWriterWithCallback.onFailure(t, e)
+    }
+    throw new IOException("failed to write")
+  }
+}
+
+class FakeFormatWithCallback() extends FakeOutputFormat {
+  override def getRecordWriter(
+    ignored: FileSystem,
+    job: JobConf, name: String,
+    progress: Progressable): RecordWriter[Integer, Integer] = {
+    new FakeWriterWithCallback()
+  }
+}
+
+class NewFakeWriterWithCallback extends NewFakeWriter {
+  override def close(p1: NewTaskAttempContext): Unit = {
+    FakeWriterWithCallback.calledBy += "close"
+  }
+
+  override def write(p1: Integer, p2: Integer): Unit = {
+    FakeWriterWithCallback.calledBy += "write,"
+    TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+      FakeWriterWithCallback.onFailure(t, e)
+    }
+    throw new IOException("failed to write")
+  }
+}
+
+class NewFakeFormatWithCallback() extends NewFakeFormat {
+  override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+    new NewFakeWriterWithCallback()
+  }
+}
+
 class ConfigTestFormat() extends NewFakeFormat() with Configurable {
 
   var setConfCalled = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c3db2a0af4..097e9c912b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
     executorSideSetup(taskContext)
     val configuration = taskAttemptContext.getConfiguration
     configuration.set("spark.sql.sources.output.path", outputPath)
-    val writer = newOutputWriter(getWorkPath)
+    var writer = newOutputWriter(getWorkPath)
     writer.initConverter(dataSchema)
 
-    var writerClosed = false
-
     // If anything below fails, we should abort the task.
     try {
       while (iterator.hasNext) {
@@ -263,16 +261,17 @@ private[sql] class DefaultWriterContainer(
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // Call the failure callbacks first, so we have a chance to clean up the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
 
     def commitTask(): Unit = {
       try {
-        assert(writer != null, "OutputWriter instance should have been initialized")
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
+          writer = null
         }
         super.commitTask()
       } catch {
@@ -285,9 +284,8 @@ private[sql] class DefaultWriterContainer(
 
     def abortTask(): Unit = {
       try {
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
         }
       } finally {
         super.abortTask()
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
     val getPartitionString =
       UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
 
-    // If anything below fails, we should abort the task.
-    try {
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")
 
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+    val sortedIterator = sorter.sortedIterator()
 
-      val sortedIterator = sorter.sortedIterator()
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
       var currentKey: UnsafeRow = null
-      var currentWriter: OutputWriter = null
-      try {
-        while (sortedIterator.next()) {
-          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-          if (currentKey != nextKey) {
-            if (currentWriter != null) {
-              currentWriter.close()
-            }
-            currentKey = nextKey.copy()
-            logDebug(s"Writing partition: $currentKey")
-
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
+      while (sortedIterator.next()) {
+        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+        if (currentKey != nextKey) {
+          if (currentWriter != null) {
+            currentWriter.close()
+            currentWriter = null
           }
+          currentKey = nextKey.copy()
+          logDebug(s"Writing partition: $currentKey")
 
-          currentWriter.writeInternal(sortedIterator.getValue)
+          currentWriter = newOutputWriter(currentKey, getPartitionString)
         }
-      } finally {
-        if (currentWriter != null) { currentWriter.close() }
+        currentWriter.writeInternal(sortedIterator.getValue)
+      }
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
       }
 
       commitTask()
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // Call the failure callbacks first, so we have a chance to clean up the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index 64c61a5092..64c27da475 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -30,6 +31,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
   val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
 
   test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+    SimpleTextRelation.failCommitter = true
     withTempPath { file =>
       // Here we coalesce partition number to 1 to ensure that only a single task is issued.  This
       // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
@@ -43,4 +45,59 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
       assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
     }
   }
+
+  test("call failure callbacks before close writer - default") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val divideByZero = udf((x: Int) => { x / (x - 1)})
+      val df = sqlContext.range(0, 10).select(divideByZero(col("id")))
+
+      SimpleTextRelation.callbackCalled = false
+      intercept[SparkException] {
+        df.write.format(dataSourceName).save(file.getCanonicalPath)
+      }
+      assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+
+  test("failure callback of writer should not be called if failed before writing") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val divideByZero = udf((x: Int) => { x / (x - 1)})
+      val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), divideByZero(col("id")))
+
+      SimpleTextRelation.callbackCalled = false
+      intercept[SparkException] {
+        df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+      }
+      assert(!SimpleTextRelation.callbackCalled,
+        "the callback of writer should not be called if job failed before writing")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+
+  test("call failure callbacks before close writer - partitioned") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), col("id"))
+
+      SimpleTextRelation.callbackCalled = false
+      SimpleTextRelation.failWriter = true
+      intercept[SparkException] {
+        df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+      }
+      assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9fc437bf88..9cdf1fc585 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{sources, Row, SQLContext}
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
@@ -199,6 +200,15 @@ object SimpleTextRelation {
 
   // Used to test filter push-down
   var pushedFilters: Set[Filter] = Set.empty
+
+  // Used to test failed committer
+  var failCommitter = false
+
+  // Used to test failed writer
+  var failWriter = false
+
+  // Used to test failure callback
+  var callbackCalled = false
 }
 
 /**
@@ -229,9 +239,25 @@ class CommitFailureTestRelation(
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
       new SimpleTextOutputWriter(path, context) {
+        var failed = false
+        TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+          failed = true
+          SimpleTextRelation.callbackCalled = true
+        }
+
+        override def write(row: Row): Unit = {
+          if (SimpleTextRelation.failWriter) {
+            sys.error("Intentional task writer failure for testing purpose.")
+
+          }
+          super.write(row)
+        }
+
         override def close(): Unit = {
+          if (SimpleTextRelation.failCommitter) {
+            sys.error("Intentional task commitment failure for testing purpose.")
+          }
           super.close()
-          sys.error("Intentional task commitment failure for testing purpose.")
         }
       }
     }
-- 
GitLab