diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index e9932c09107dbbf7c5fd9877cadc415c8dd273e4..bd3aad66317487e8156b2611a858139c87c65d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection -import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode} @@ -127,8 +127,11 @@ private[sql] case class InsertIntoHadoopFsRelation( val needsConversion = relation.needConversion val dataSchema = relation.dataSchema + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) writerContainer.commitJob() relation.refresh() @@ -139,9 +142,10 @@ private[sql] case class InsertIntoHadoopFsRelation( } def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - + // If anything below fails, we should abort the task. 
try { + writerContainer.executorSideSetup(taskContext) + if (needsConversion) { val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) while (iterator.hasNext) { @@ -154,6 +158,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.outputWriterForRow(row).write(row) } } + writerContainer.commitTask() } catch { case cause: Throwable => logError("Aborting task.", cause) @@ -191,8 +196,11 @@ private[sql] case class InsertIntoHadoopFsRelation( val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name)) val codegenEnabled = df.sqlContext.conf.codegenEnabled + // This call shouldn't be put into the `try` block below because it only initializes and + // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. + writerContainer.driverSideSetup() + try { - writerContainer.driverSideSetup() df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) writerContainer.commitJob() relation.refresh() @@ -203,32 +211,39 @@ private[sql] case class InsertIntoHadoopFsRelation( } def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = { - writerContainer.executorSideSetup(taskContext) - - val partitionProj = newProjection(codegenEnabled, partitionOutput, output) - val dataProj = newProjection(codegenEnabled, dataOutput, output) - - if (needsConversion) { - val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = partitionProj(row) - val dataPart = dataProj(row) - val convertedDataPart = converter(dataPart).asInstanceOf[Row] - writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) - } - } else { - val partitionSchema = StructType.fromAttributes(partitionOutput) - val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) - while (iterator.hasNext) { - val row = iterator.next() - val partitionPart = 
converter(partitionProj(row)).asInstanceOf[Row] - val dataPart = dataProj(row) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + // If anything below fails, we should abort the task. + try { + writerContainer.executorSideSetup(taskContext) + + val partitionProj = newProjection(codegenEnabled, partitionOutput, output) + val dataProj = newProjection(codegenEnabled, dataOutput, output) + + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = partitionProj(row) + val dataPart = dataProj(row) + val convertedDataPart = converter(dataPart).asInstanceOf[Row] + writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart) + } + } else { + val partitionSchema = StructType.fromAttributes(partitionOutput) + val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema) + while (iterator.hasNext) { + val row = iterator.next() + val partitionPart = converter(partitionProj(row)).asInstanceOf[Row] + val dataPart = dataProj(row) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } - } - writerContainer.commitTask() + writerContainer.commitTask() + } catch { case cause: Throwable => + logError("Aborting task.", cause) + writerContainer.abortTask() + throw new SparkException("Task failed while writing rows.", cause) + } } } @@ -283,7 +298,12 @@ private[sql] abstract class BaseWriterContainer( setupIDs(0, 0, 0) setupConf() taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + + // This preparation must happen before initializing output format and output committer, since + // their initialization involves the job configuration, which can be potentially decorated in + // `relation.prepareJobForWrite`. 
outputWriterFactory = relation.prepareJobForWrite(job) + outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupJob(jobContext) @@ -359,7 +379,9 @@ private[sql] abstract class BaseWriterContainer( } def abortTask(): Unit = { - outputCommitter.abortTask(taskAttemptContext) + if (outputCommitter != null) { + outputCommitter.abortTask(taskAttemptContext) + } logError(s"Task attempt $taskAttemptId aborted.") } @@ -369,7 +391,9 @@ private[sql] abstract class BaseWriterContainer( } def abortJob(): Unit = { - outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + if (outputCommitter != null) { + outputCommitter.abortJob(jobContext, JobStatus.State.FAILED) + } logError(s"Job $jobId aborted.") } } @@ -390,6 +414,7 @@ private[sql] class DefaultWriterContainer( override def commitTask(): Unit = { try { + assert(writer != null, "OutputWriter instance should have been initialized") writer.close() super.commitTask() } catch { @@ -401,7 +426,9 @@ private[sql] class DefaultWriterContainer( override def abortTask(): Unit = { try { - writer.close() + if (writer != null) { + writer.close() + } } finally { super.abortTask() } @@ -445,6 +472,7 @@ private[sql] class DynamicPartitionWriterContainer( override def commitTask(): Unit = { try { outputWriters.values.foreach(_.close()) + outputWriters.clear() super.commitTask() } catch { case cause: Throwable => super.abortTask() @@ -455,6 +483,7 @@ private[sql] class DynamicPartitionWriterContainer( override def abortTask(): Unit = { try { outputWriters.values.foreach(_.close()) + outputWriters.clear() } finally { super.abortTask() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 8787663a98f8fbffa163a386522939131cf64343..76469d7a3d6a54ef78e46e1c02fd1b006d8a59da 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -594,4 +594,19 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { checkAnswer(read.format("parquet").load(path), df) } } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we intentionally cause an + // exception to be thrown from the `ParquetRelation2.prepareJobForWrite()` method to trigger + // the bug. Please refer to SPARK-8079 for more details. + range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } }