diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 65f6f741f7f4ec4057fa12a15bca4b1f6d10a9f0..7e96040bc4688a1c68d0e5df734d1527a7eb9fab 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -53,6 +53,9 @@ private[spark] class TaskContextImpl(
   // Whether the task has completed.
   @volatile private var completed: Boolean = false
 
+  // Whether the task has failed.
+  @volatile private var failed: Boolean = false
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
@@ -63,10 +66,13 @@ private[spark] class TaskContextImpl(
     this
   }
 
-  /** Marks the task as completed and triggers the failure listeners. */
+  /** Marks the task as failed and triggers the failure listeners. */
   private[spark] def markTaskFailed(error: Throwable): Unit = {
+    // failure callbacks should only be called once
+    if (failed) return
+    failed = true
     val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
+    // Process failure callbacks in the reverse order of registration
     onFailureCallbacks.reverse.foreach { listener =>
       try {
         listener.onTaskFailure(this, error)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index e00b9f6cfda0e20cc73ff3cd05ff3d23ccefb5dc..91460dc4067fb6f0d97a5b080902f37801c58c5d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -1101,7 +1101,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
       require(writer != null, "Unable to obtain RecordWriter")
       var recordsWritten = 0L
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val pair = iter.next()
           writer.write(pair._1, pair._2)
@@ -1190,7 +1190,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
 
       writer.open()
       var recordsWritten = 0L
-      Utils.tryWithSafeFinally {
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
         while (iter.hasNext) {
           val record = iter.next()
           writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 6103a10ccc50e503af4d80c3b8257714178f609b..cfe247c668d7805eba23cfb8d6242b0aed56b6de 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1241,7 +1241,6 @@ private[spark] object Utils extends Logging {
    * exception from the original `out.write` call.
    */
   def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
-    // It would be nice to find a method on Try that did this
     var originalThrowable: Throwable = null
     try {
       block
@@ -1267,6 +1266,44 @@
     }
   }
 
+  /**
+   * Execute a block of code, calling the failure callbacks before the finally block if any
+   * exception happens. Exceptions thrown from the finally block do not suppress the original
+   * exception.
+   *
+   * This is primarily an issue with `finally { out.close() }` blocks, where
+   * close needs to be called to clean up `out`, but if an exception happened
+   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+   * fail as well. This would then suppress the original/likely more meaningful
+   * exception from the original `out.write` call.
+   */
+  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+    var originalThrowable: Throwable = null
+    try {
+      block
+    } catch {
+      case t: Throwable =>
+        // Purposefully not using NonFatal, because even fatal exceptions
+        // we don't want to have our finallyBlock suppress
+        originalThrowable = t
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+        throw originalThrowable
+    } finally {
+      try {
+        finallyBlock
+      } catch {
+        case t: Throwable =>
+          if (originalThrowable != null) {
+            originalThrowable.addSuppressed(t)
+            logWarning(s"Suppressing exception in finally: " + t.getMessage, t)
+            throw originalThrowable
+          } else {
+            throw t
+          }
+      }
+    }
+  }
+
   /** Default filtering function for finding call sites using `getCallSite`. */
   private def sparkInternalExclusionFunction(className: String): Boolean = {
     // A regular expression to match classes of the internal Spark API's
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 16e2d2e636c11eb3066b78c39638dc35cb69fc4b..7d51538d92597f580a5606c21235da7c108b4c25 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.rdd
 
+import java.io.IOException
+
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
@@ -29,7 +31,8 @@ import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
   RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext}
 import org.apache.hadoop.util.Progressable
 
-import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite}
+import org.apache.spark._
+import org.apache.spark.Partitioner
 import org.apache.spark.util.Utils
 
 class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
@@ -533,6 +536,38 @@
     assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
   }
 
+  test("failure callbacks should be called before calling writer.close() in saveAsNewAPIHadoopFile") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+
+    FakeWriterWithCallback.calledBy = ""
+    FakeWriterWithCallback.exception = null
+    val e = intercept[SparkException] {
+      pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored")
+    }
+    assert(e.getMessage contains "failed to write")
+
+    assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+    assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+    assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+  }
+
+  test("failure callbacks should be called before calling writer.close() in saveAsHadoopFile") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1)
+    val conf = new JobConf()
+
+    FakeWriterWithCallback.calledBy = ""
+    FakeWriterWithCallback.exception = null
+    val e = intercept[SparkException] {
+      pairs.saveAsHadoopFile(
+        "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
+    }
+    assert(e.getMessage contains "failed to write")
+
+    assert(FakeWriterWithCallback.calledBy === "write,callback,close")
+    assert(FakeWriterWithCallback.exception != null, "exception should be captured")
+    assert(FakeWriterWithCallback.exception.getMessage contains "failed to write")
+  }
+
   test("lookup") {
     val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7)))
 
@@ -776,6 +811,60 @@ class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
   }
 }
 
+object FakeWriterWithCallback {
+  var calledBy: String = ""
+  var exception: Throwable = _
+
+  def onFailure(ctx: TaskContext, e: Throwable): Unit = {
+    calledBy += "callback,"
+    exception = e
+  }
+}
+
+class FakeWriterWithCallback extends FakeWriter {
+
+  override def close(p1: Reporter): Unit = {
+    FakeWriterWithCallback.calledBy += "close"
+  }
+
+  override def write(p1: Integer, p2: Integer): Unit = {
+    FakeWriterWithCallback.calledBy += "write,"
+    TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+      FakeWriterWithCallback.onFailure(t, e)
+    }
+    throw new IOException("failed to write")
+  }
+}
+
+class FakeFormatWithCallback() extends FakeOutputFormat {
+  override def getRecordWriter(
+    ignored: FileSystem,
+    job: JobConf, name: String,
+    progress: Progressable): RecordWriter[Integer, Integer] = {
+    new FakeWriterWithCallback()
+  }
+}
+
+class NewFakeWriterWithCallback extends NewFakeWriter {
+  override def close(p1: NewTaskAttempContext): Unit = {
+    FakeWriterWithCallback.calledBy += "close"
+  }
+
+  override def write(p1: Integer, p2: Integer): Unit = {
+    FakeWriterWithCallback.calledBy += "write,"
+    TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+      FakeWriterWithCallback.onFailure(t, e)
+    }
+    throw new IOException("failed to write")
+  }
+}
+
+class NewFakeFormatWithCallback() extends NewFakeFormat {
+  override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+    new NewFakeWriterWithCallback()
+  }
+}
+
 class ConfigTestFormat() extends NewFakeFormat() with Configurable {
 
   var setConfCalled = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index c3db2a0af4bd1a6144fd28e731fcafc6953a1c78..097e9c912b5743d147601e1195cac34830d2e27b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -247,11 +247,9 @@ private[sql] class DefaultWriterContainer(
     executorSideSetup(taskContext)
     val configuration = taskAttemptContext.getConfiguration
     configuration.set("spark.sql.sources.output.path", outputPath)
-    val writer = newOutputWriter(getWorkPath)
+    var writer = newOutputWriter(getWorkPath)
     writer.initConverter(dataSchema)
 
-    var writerClosed = false
-
     // If anything below fails, we should abort the task.
     try {
       while (iterator.hasNext) {
@@ -263,16 +261,17 @@
     } catch {
       case cause: Throwable =>
         logError("Aborting task.", cause)
+        // Call the failure callbacks first, so that we have a chance to clean up the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
         abortTask()
         throw new SparkException("Task failed while writing rows.", cause)
     }
 
     def commitTask(): Unit = {
       try {
-        assert(writer != null, "OutputWriter instance should have been initialized")
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
+          writer = null
         }
         super.commitTask()
       } catch {
@@ -285,9 +284,8 @@
     def abortTask(): Unit = {
       try {
-        if (!writerClosed) {
+        if (writer != null) {
           writer.close()
-          writerClosed = true
         }
       } finally {
         super.abortTask()
       }
@@ -393,57 +391,62 @@
     val getPartitionString =
       UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
 
-    // If anything below fails, we should abort the task.
-    try {
-      // Sorts the data before write, so that we only need one writer at the same time.
-      // TODO: inject a local sort operator in planning.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(dataColumns),
-        SparkEnv.get.blockManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes)
-
-      while (iterator.hasNext) {
-        val currentRow = iterator.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+    // Sorts the data before write, so that we only need one writer at the same time.
+    // TODO: inject a local sort operator in planning.
+    val sorter = new UnsafeKVExternalSorter(
+      sortingKeySchema,
+      StructType.fromAttributes(dataColumns),
+      SparkEnv.get.blockManager,
+      TaskContext.get().taskMemoryManager().pageSizeBytes)
+
+    while (iterator.hasNext) {
+      val currentRow = iterator.next()
+      sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
+    }
+    logInfo(s"Sorting complete. Writing out partition files one at a time.")
 
-      logInfo(s"Sorting complete. Writing out partition files one at a time.")
+    val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
+      identity
+    } else {
+      UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
+        case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
+      })
+    }
 
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+    val sortedIterator = sorter.sortedIterator()
 
-      val sortedIterator = sorter.sortedIterator()
+    // If anything below fails, we should abort the task.
+    var currentWriter: OutputWriter = null
+    try {
       var currentKey: UnsafeRow = null
-      var currentWriter: OutputWriter = null
-      try {
-        while (sortedIterator.next()) {
-          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-          if (currentKey != nextKey) {
-            if (currentWriter != null) {
-              currentWriter.close()
-            }
-            currentKey = nextKey.copy()
-            logDebug(s"Writing partition: $currentKey")
-
-            currentWriter = newOutputWriter(currentKey, getPartitionString)
+      while (sortedIterator.next()) {
+        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+        if (currentKey != nextKey) {
+          if (currentWriter != null) {
+            currentWriter.close()
+            currentWriter = null
          }
+          currentKey = nextKey.copy()
+          logDebug(s"Writing partition: $currentKey")
 
-          currentWriter.writeInternal(sortedIterator.getValue)
+          currentWriter = newOutputWriter(currentKey, getPartitionString)
        }
-      } finally {
-        if (currentWriter != null) { currentWriter.close() }
+        currentWriter.writeInternal(sortedIterator.getValue)
+      }
+      if (currentWriter != null) {
+        currentWriter.close()
+        currentWriter = null
      }
      commitTask()
    } catch {
      case cause: Throwable =>
        logError("Aborting task.", cause)
+        // Call the failure callbacks first, so that we have a chance to clean up the writer.
+        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
        abortTask()
        throw new SparkException("Task failed while writing rows.", cause)
    }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index 64c61a50925409b986cb3ed1ebbc8e7928fe63a9..64c27da4751e3e81db4aab592bd12a1ab6caa6c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -30,6 +31,7 @@ class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton
   val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
 
   test("SPARK-7684: commitTask() failure should fallback to abortTask()") {
+    SimpleTextRelation.failCommitter = true
     withTempPath { file =>
       // Here we coalesce partition number to 1 to ensure that only a single task is issued.  This
       // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary`
@@ -43,4 +45,59 @@
       assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
     }
   }
+
+  test("call failure callbacks before close writer - default") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val divideByZero = udf((x: Int) => { x / (x - 1)})
+      val df = sqlContext.range(0, 10).select(divideByZero(col("id")))
+
+      SimpleTextRelation.callbackCalled = false
+      intercept[SparkException] {
+        df.write.format(dataSourceName).save(file.getCanonicalPath)
+      }
+      assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+
+  test("failure callback of writer should not be called if failed before writing") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job before any writer is created
+      val divideByZero = udf((x: Int) => { x / (x - 1)})
+      val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), divideByZero(col("id")))
+
+      SimpleTextRelation.callbackCalled = false
+      intercept[SparkException] {
+        df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+      }
+      assert(!SimpleTextRelation.callbackCalled,
+        "the callback of writer should not be called if job failed before writing")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
+
+  test("call failure callbacks before close writer - partitioned") {
+    SimpleTextRelation.failCommitter = false
+    withTempPath { file =>
+      // fail the job in the middle of writing
+      val df = sqlContext.range(0, 10).select(col("id").mod(2).as("key"), col("id"))
+
+      SimpleTextRelation.callbackCalled = false
+      SimpleTextRelation.failWriter = true
+      intercept[SparkException] {
+        df.write.format(dataSourceName).partitionBy("key").save(file.getCanonicalPath)
+      }
+      assert(SimpleTextRelation.callbackCalled, "failure callback should be called")
+
+      val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf)
+      assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary")))
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 9fc437bf8815aad16a71e6b6d2b62a4caa4f1bbd..9cdf1fc585866ed63fd6f8002e5cfc1c0c2e7938 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
 import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{sources, Row, SQLContext}
 import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
@@ -199,6 +200,15 @@
 
   // Used to test filter push-down
   var pushedFilters: Set[Filter] = Set.empty
+
+  // Used to test failed committer
+  var failCommitter = false
+
+  // Used to test failed writer
+  var failWriter = false
+
+  // Used to test failure callback
+  var callbackCalled = false
 }
 
 /**
@@ -229,9 +239,25 @@ class CommitFailureTestRelation(
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
       new SimpleTextOutputWriter(path, context) {
+        var failed = false
+        TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
+          failed = true
+          SimpleTextRelation.callbackCalled = true
+        }
+
+        override def write(row: Row): Unit = {
+          if (SimpleTextRelation.failWriter) {
+            sys.error("Intentional task writer failure for testing purpose.")
+
+          }
+          super.write(row)
+        }
+
         override def close(): Unit = {
+          if (SimpleTextRelation.failCommitter) {
+            sys.error("Intentional task commitment failure for testing purpose.")
+          }
           super.close()
-          sys.error("Intentional task commitment failure for testing purpose.")
         }
       }
     }
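
Editor's note (illustrative sketch, not part of the patch): the "write,callback,close" ordering asserted by the new PairRDDFunctionsSuite tests comes from registering a task failure listener inside the task and from the save paths above marking the task as failed before the writer is closed. The snippet below only demonstrates the user-facing TaskContext.addTaskFailureListener hook that those tests rely on; the object name, app name, local[1] master and sample data are made up for the example.

import java.io.IOException

import org.apache.spark.{SparkConf, SparkContext, SparkException, TaskContext}

// Hypothetical demo (not in the patch): register a task failure listener and
// observe it firing when the task body throws.
object FailureCallbackExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[1]").setAppName("failure-callback-demo")
    val sc = new SparkContext(conf)
    try {
      sc.parallelize(1 to 10, 1).foreachPartition { iter =>
        // Register the listener on the running task, as the fake writers in
        // PairRDDFunctionsSuite do.
        TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
          println(s"task ${t.partitionId()} failed: ${e.getMessage}")
        }
        iter.foreach { i =>
          if (i == 5) throw new IOException("failed to write")
        }
      }
    } catch {
      case e: SparkException => println(s"job failed as expected: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}

In the Hadoop save paths changed above, Utils.tryWithSafeFinallyAndFailureCallbacks invokes markTaskFailed(t) before the finally block runs, which is why a listener like this now observes the failure before writer.close().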