Commit b5a59a0f authored by Davies Liu, committed by Davies Liu

[SPARK-13601] call failure callbacks before writer.close()

## What changes were proposed in this pull request?

In order to tell the OutputStream whether the task has failed or not, we should call the failure callbacks BEFORE calling writer.close().

## How was this patch tested?

Added new unit tests.

Author: Davies Liu <davies@databricks.com>

Closes #11450 from davies/callback.
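To illustrate the intended pattern (a sketch, not code from this patch): a record writer can register a task failure listener when it is created, so that by the time `close()` runs it already knows whether the task failed. `FailureAwareWriter` and its cleanup logic are hypothetical; `TaskContext.addTaskFailureListener` is the API exercised by the new tests.

```scala
import org.apache.spark.TaskContext

// Hypothetical writer that wants to know, inside close(), whether the task failed.
// Assumes it is constructed inside a running task, so TaskContext.get() is non-null.
class FailureAwareWriter {
  @volatile private var taskFailed = false

  // With this change, the listener fires BEFORE close() is called on the failure path.
  TaskContext.get().addTaskFailureListener { (_: TaskContext, _: Throwable) =>
    taskFailed = true
  }

  def write(record: String): Unit = {
    // ... write the record to the underlying output stream ...
  }

  def close(): Unit = {
    if (taskFailed) {
      // discard or clean up the partial output instead of finalizing it
    } else {
      // flush and finalize the output
    }
  }
}
```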
parent 9e01fe2e
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
// Whether the task has completed.
@volatile private var completed: Boolean = false
// Whether the task has failed.
@volatile private var failed: Boolean = false
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
this
}
/** Marks the task as completed and triggers the failure listeners. */
/** Marks the task as failed and triggers the failure listeners. */
private[spark] def markTaskFailed(error: Throwable): Unit = {
// failure callbacks should only be called once
if (failed) return
failed = true
val errorMsgs = new ArrayBuffer[String](2)
// Process complete callbacks in the reverse order of registration
// Process failure callbacks in the reverse order of registration
onFailureCallbacks.reverse.foreach { listener =>
try {
listener.onTaskFailure(this, error)
......
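A rough illustration of the semantics added above (assumes an active `SparkContext` named `sc` and local mode, so the failed task is not retried elsewhere): failure listeners registered inside a task run at most once per attempt, in reverse order of registration, when the task body throws.

```scala
import org.apache.spark.{SparkException, TaskContext}

// Sketch only: a single-partition job that registers two failure listeners and then fails.
// Per the change above, the listeners run once, in reverse registration order (B, then A).
try {
  sc.parallelize(Seq(1), 1).foreach { _ =>
    val ctx = TaskContext.get()
    ctx.addTaskFailureListener { (_: TaskContext, _: Throwable) => println("listener A") }
    ctx.addTaskFailureListener { (_: TaskContext, _: Throwable) => println("listener B") }
    throw new RuntimeException("boom")
  }
} catch {
  case _: SparkException => // the task failure eventually surfaces on the driver
}
```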
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
require(writer != null, "Unable to obtain RecordWriter")
var recordsWritten = 0L
Utils.tryWithSafeFinally {
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.open()
var recordsWritten = 0L
Utils.tryWithSafeFinally {
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
......
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
* exception from the original `out.write` call.
*/
def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
// It would be nice to find a method on Try that did this
var originalThrowable: Throwable = null
try {
block
@@ -1267,6 +1266,44 @@
}
}
/**
* Execute a block of code, calling the failure callbacks before the finally block if any
* exception occurs. If an exception also happens in the finally block, it does not suppress
* the original exception.
*
* This is primarily an issue with `finally { out.close() }` blocks, where
* close needs to be called to clean up `out`, but if an exception happened
* in `out.write`, it's likely `out` may be corrupted and `out.close` will
* fail as well. This would then suppress the original/likely more meaningful
* exception from the original `out.write` call.
*/
def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
var originalThrowable: Throwable = null
try {
block
} catch {
case t: Throwable =>
// Purposefully not using NonFatal, because even fatal exceptions
// we don't want to have our finallyBlock suppress
originalThrowable = t
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
throw originalThrowable
} finally {
try {
finallyBlock
} catch {
case t: Throwable =>
if (originalThrowable != null) {
originalThrowable.addSuppressed(t)
logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
throw originalThrowable
} else {
throw t
}
}
}
}
/** Default filtering function for finding call sites using `getCallSite`. */
private def sparkInternalExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the internal Spark API's
......
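A rough usage sketch for the new helper (the `writeAll` wrapper and its parameters are placeholders, and `Utils` is `private[spark]`, so this only compiles inside Spark itself): if writing throws, the task's failure callbacks run first, and only then does the finally block close the writer.

```scala
import java.io.Closeable

import org.apache.spark.util.Utils

object WriteWithCallbacks {
  // Sketch: write every record, letting tryWithSafeFinallyAndFailureCallbacks run the
  // task failure callbacks (if `write` throws) before the finally block closes the writer.
  def writeAll[T](records: Iterator[T], writer: Closeable, write: T => Unit): Unit = {
    Utils.tryWithSafeFinallyAndFailureCallbacks {
      records.foreach(write)
    } {
      writer.close()
    }
  }
}
```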
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
import java.io.IOException
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.util.Random
@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
import org.apache.hadoop.util.Progressable
import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
import org.apache.spark._
import org.apache.spark.Partitioner
import org.apache.spark.util.Utils
class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
}
test("failure callbacks should be called before calling writer.close() in saveNewAPIHadoopFile") {
val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
FakeWriterWithCallback.calledBy = ""
FakeWriterWithCallback.exception = null
val e = intercept[SparkException] {
pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
}
assert(e.getMessage contains "failed to write")
assert(FakeWriterWithCallback.calledBy === "write,callback,close")
assert(FakeWriterWithCallback.exception != null, "exception should be captured")
assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
}
test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
val conf = new JobConf()
FakeWriterWithCallback.calledBy = ""
FakeWriterWithCallback.exception = null
val e = intercept[SparkException] {
pairs.saveAsHadoopFile(
"ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
}
assert(e.getMessage contains "failed to write")
assert(FakeWriterWithCallback.calledBy === "write,callback,close")
assert(FakeWriterWithCallback.exception != null, "exception should be captured")
assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
}
test("lookup") {
val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))
@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
}
}
object FakeWriterWithCallback {
var calledBy: String = ""
var exception: Throwable = _
def onFailure(ctx: TaskContext, e: Throwable): Unit = {
calledBy += "callback,"
exception = e
}
}
class FakeWriterWithCallback extends FakeWriter {
override def close(p1: Reporter): Unit = {
FakeWriterWithCallback.calledBy += "close"
}
override def write(p1: Integer, p2: Integer): Unit = {
FakeWriterWithCallback.calledBy += "write,"
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
FakeWriterWithCallback.onFailure(t, e)
}
throw new IOException("failed to write")
}
}
class FakeFormatWithCallback() extends FakeOutputFormat {
override def getRecordWriter(
ignored: FileSystem,
job: JobConf, name: String,
progress: Progressable): RecordWriter[Integer, Integer] = {
new FakeWriterWithCallback()
}
}
class NewFakeWriterWithCallback extends NewFakeWriter {
override def close(p1: NewTaskAttempContext): Unit = {
FakeWriterWithCallback.calledBy += "close"
}
override def write(p1: Integer, p2: Integer): Unit = {
FakeWriterWithCallback.calledBy += "write,"
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
FakeWriterWithCallback.onFailure(t, e)
}
throw new IOException("failed to write")
}
}
class NewFakeFormatWithCallback() extends NewFakeFormat {
override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
new NewFakeWriterWithCallback()
}
}
class ConfigTestFormat() extends NewFakeFormat() with Configurable {
var setConfCalled = false
......
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
executorSideSetup(taskContext)
val configuration = taskAttemptContext.getConfiguration
configuration.set("spark.sql.sources.output.path", outputPath)
val writer = newOutputWriter(getWorkPath)
var writer = newOutputWriter(getWorkPath)
writer.initConverter(dataSchema)
var writerClosed = false
// If anything below fails, we should abort the task.
try {
while (iterator.hasNext) {
@@ -263,16 +261,17 @@ private[sql] class DefaultWriterContainer(
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
// Call failure callbacks first, so we have a chance to clean up the writer.
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
def commitTask(): Unit = {
try {
assert(writer != null, "OutputWriter instance should have been initialized")
if (!writerClosed) {
if (writer != null) {
writer.close()
writerClosed = true
writer = null
}
super.commitTask()
} catch {
@@ -285,9 +284,8 @@ private[sql] class DefaultWriterContainer(
def abortTask(): Unit = {
try {
if (!writerClosed) {
if (writer != null) {
writer.close()
writerClosed = true
}
} finally {
super.abortTask()
@@ -393,57 +391,62 @@ private[sql] class DynamicPartitionWriterContainer(
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
// If anything below fails, we should abort the task.
try {
// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
TaskContext.get().taskMemoryManager().pageSizeBytes)
while (iterator.hasNext) {
val currentRow = iterator.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}
logInfo(s"Sorting complete. Writing out partition files one at a time.")
logInfo(s"Sorting complete. Writing out partition files one at a time.")
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}
val sortedIterator = sorter.sortedIterator()
val sortedIterator = sorter.sortedIterator()
// If anything below fails, we should abort the task.
var currentWriter: OutputWriter = null
try {
var currentKey: UnsafeRow = null
var currentWriter: OutputWriter = null
try {
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")
currentWriter = newOutputWriter(currentKey, getPartitionString)
while (sortedIterator.next()) {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")
currentWriter.writeInternal(sortedIterator.getValue)
currentWriter = newOutputWriter(currentKey, getPartitionString)
}
} finally {
if (currentWriter != null) { currentWriter.close() }
currentWriter.writeInternal(sortedIterator.getValue)
}
if (currentWriter != null) {
currentWriter.close()
currentWriter = null
}
commitTask()
} catch {
case cause: Throwable =>
logError("Aborting task.", cause)
// Call failure callbacks first, so we have a chance to clean up the writer.
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
if (currentWriter != null) {
currentWriter.close()
}
abortTask()
throw new SparkException("Task failed while writing rows.", cause)
}
......
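The `writer = null` and `currentWriter = null` assignments above replace the old `writerClosed` flag with a close-once pattern: the reference is nulled right after a successful close, so whichever cleanup path runs later (commit, abort, or the failure callback) can attempt a close without closing twice. A minimal sketch of the pattern, with an illustrative wrapper class:

```scala
import java.io.Closeable

// Close-once pattern: whichever cleanup path runs last (commit, abort, error handling)
// can call release() safely, because the reference is nulled after the first close.
final class CloseOnce(private var writer: Closeable) {
  def release(): Unit = {
    if (writer != null) {
      writer.close()
      writer = null
    }
  }
}
```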
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
@@ -30,6 +31,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
SimpleTextRelation.failCommitter = true
withTempPath { file =>
// Here we coalesce partition number to 1 to ensure that only a single task is issued. This
// prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
@@ -43,4 +45,59 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
}
}
test("call failure callbacks before close writer - default") {
SimpleTextRelation.failCommitter = false
withTempPath { file =>
// fail the job in the middle of writing
val divideByZero = udf((x: Int) => { x / (x - 1)})
val df = sqlContext.range(0, 10).select(divideByZero(col("id")))
SimpleTextRelation.callbackCalled = false
intercept[SparkException] {
df.write.format(dataSourceName).save(file.getCanonicalPath)
}
assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
}
}
test("failure callback of writer should not be called if failed before writing") {
SimpleTextRelation.failCommitter = false
withTempPath { file =>
// fail the job in the middle of writing
val divideByZero = udf((x: Int) => { x / (x - 1)})
val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), divideByZero(col("id")))
SimpleTextRelation.callbackCalled = false
intercept[SparkException] {
df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
}
assert(!SimpleTextRelation.callbackCalled,
"the callback of writer should not be called if job failed before writing")
val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
}
}
test("call failure callbacks before close writer - partitioned") {
SimpleTextRelation.failCommitter = false
withTempPath { file =>
// fail the job in the middle of writing
val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), col("id"))
SimpleTextRelation.callbackCalled = false
SimpleTextRelation.failWriter = true
intercept[SparkException] {
df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
}
assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
}
}
}
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{sources, Row, SQLContext}
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
@@ -199,6 +200,15 @@ object SimpleTextRelation {
// Used to test filter push-down
var pushedFilters: Set[Filter] = Set.empty
// Used to test failed committer
var failCommitter = false
// Used to test failed writer
var failWriter = false
// Used to test failure callback
var callbackCalled = false
}
/**
@@ -229,9 +239,25 @@ class CommitFailureTestRelation(
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context) {
var failed = false
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
failed = true
SimpleTextRelation.callbackCalled = true
}
override def write(row: Row): Unit = {
if (SimpleTextRelation.failWriter) {
sys.error("Intentional task writer failure for testing purpose.")
}
super.write(row)
}
override def close(): Unit = {
if (SimpleTextRelation.failCommitter) {
sys.error("Intentional task commitment failure for testing purpose.")
}
super.close()
sys.error("Intentional task commitment failure for testing purpose.")
}
}
}
......