From 16fc49617e1dfcbe9122b224f7f63b7bfddb36ce Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian@databricks.com>
Date: Sat, 6 Jun 2015 17:23:12 +0800
Subject: [PATCH] [SPARK-8079] [SQL] Makes InsertIntoHadoopFsRelation job/task
 abortion more robust

As described in SPARK-8079, when writing a DataFrame to a `HadoopFsRelation`, if `HadoopFsRelation.prepareForWriteJob` throws exception, an unexpected NPE will be thrown during job abortion. (This issue doesn't bring much damage since the job is failing anyway.)

This PR makes the job/task abortion logic in `InsertIntoHadoopFsRelation` more robust to avoid such confusing exceptions.

Author: Cheng Lian <lian@databricks.com>

Closes #6612 from liancheng/spark-8079 and squashes the following commits:

87cd81e [Cheng Lian] Addresses @rxin's comment
1864c75 [Cheng Lian] Addresses review comments
9e6dbb3 [Cheng Lian] Makes InsertIntoHadoopFsRelation job/task abortion more robust
---
 .../apache/spark/sql/sources/commands.scala   | 93 ++++++++++++-------
 .../sql/sources/hadoopFsRelationSuites.scala  | 15 +++
 2 files changed, 76 insertions(+), 32 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
index e9932c0910..bd3aad6631 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.execution.RunnableCommand
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext, SaveMode}
@@ -127,8 +127,11 @@ private[sql] case class InsertIntoHadoopFsRelation(
     val needsConversion = relation.needConversion
     val dataSchema = relation.dataSchema
 
+    // This call shouldn't be put into the `try` block below because it only initializes and
+    // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
+    writerContainer.driverSideSetup()
+
     try {
-      writerContainer.driverSideSetup()
       df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
       writerContainer.commitJob()
       relation.refresh()
@@ -139,9 +142,10 @@ private[sql] case class InsertIntoHadoopFsRelation(
     }
 
     def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
-      writerContainer.executorSideSetup(taskContext)
-
+      // If anything below fails, we should abort the task.
       try {
+        writerContainer.executorSideSetup(taskContext)
+
         if (needsConversion) {
           val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
           while (iterator.hasNext) {
@@ -154,6 +158,7 @@ private[sql] case class InsertIntoHadoopFsRelation(
             writerContainer.outputWriterForRow(row).write(row)
           }
         }
+
         writerContainer.commitTask()
       } catch { case cause: Throwable =>
         logError("Aborting task.", cause)
@@ -191,8 +196,11 @@ private[sql] case class InsertIntoHadoopFsRelation(
     val (partitionOutput, dataOutput) = output.partition(a => partitionColumns.contains(a.name))
     val codegenEnabled = df.sqlContext.conf.codegenEnabled
 
+    // This call shouldn't be put into the `try` block below because it only initializes and
+    // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
+    writerContainer.driverSideSetup()
+
     try {
-      writerContainer.driverSideSetup()
       df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _)
       writerContainer.commitJob()
       relation.refresh()
@@ -203,32 +211,39 @@ private[sql] case class InsertIntoHadoopFsRelation(
     }
 
     def writeRows(taskContext: TaskContext, iterator: Iterator[Row]): Unit = {
-      writerContainer.executorSideSetup(taskContext)
-
-      val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
-      val dataProj = newProjection(codegenEnabled, dataOutput, output)
-
-      if (needsConversion) {
-        val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
-        while (iterator.hasNext) {
-          val row = iterator.next()
-          val partitionPart = partitionProj(row)
-          val dataPart = dataProj(row)
-          val convertedDataPart = converter(dataPart).asInstanceOf[Row]
-          writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart)
-        }
-      } else {
-        val partitionSchema = StructType.fromAttributes(partitionOutput)
-        val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema)
-        while (iterator.hasNext) {
-          val row = iterator.next()
-          val partitionPart = converter(partitionProj(row)).asInstanceOf[Row]
-          val dataPart = dataProj(row)
-          writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+      // If anything below fails, we should abort the task.
+      try {
+        writerContainer.executorSideSetup(taskContext)
+
+        val partitionProj = newProjection(codegenEnabled, partitionOutput, output)
+        val dataProj = newProjection(codegenEnabled, dataOutput, output)
+
+        if (needsConversion) {
+          val converter = CatalystTypeConverters.createToScalaConverter(dataSchema)
+          while (iterator.hasNext) {
+            val row = iterator.next()
+            val partitionPart = partitionProj(row)
+            val dataPart = dataProj(row)
+            val convertedDataPart = converter(dataPart).asInstanceOf[Row]
+            writerContainer.outputWriterForRow(partitionPart).write(convertedDataPart)
+          }
+        } else {
+          val partitionSchema = StructType.fromAttributes(partitionOutput)
+          val converter = CatalystTypeConverters.createToScalaConverter(partitionSchema)
+          while (iterator.hasNext) {
+            val row = iterator.next()
+            val partitionPart = converter(partitionProj(row)).asInstanceOf[Row]
+            val dataPart = dataProj(row)
+            writerContainer.outputWriterForRow(partitionPart).write(dataPart)
+          }
         }
-      }
 
-      writerContainer.commitTask()
+        writerContainer.commitTask()
+      } catch { case cause: Throwable =>
+        logError("Aborting task.", cause)
+        writerContainer.abortTask()
+        throw new SparkException("Task failed while writing rows.", cause)
+      }
     }
   }
 
@@ -283,7 +298,12 @@ private[sql] abstract class BaseWriterContainer(
     setupIDs(0, 0, 0)
     setupConf()
     taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId)
+
+    // This preparation must happen before initializing output format and output committer, since
+    // their initialization involves the job configuration, which can be potentially decorated in
+    // `relation.prepareJobForWrite`.
     outputWriterFactory = relation.prepareJobForWrite(job)
+
     outputFormatClass = job.getOutputFormatClass
     outputCommitter = newOutputCommitter(taskAttemptContext)
     outputCommitter.setupJob(jobContext)
@@ -359,7 +379,9 @@ private[sql] abstract class BaseWriterContainer(
   }
 
   def abortTask(): Unit = {
-    outputCommitter.abortTask(taskAttemptContext)
+    if (outputCommitter != null) {
+      outputCommitter.abortTask(taskAttemptContext)
+    }
     logError(s"Task attempt $taskAttemptId aborted.")
   }
 
@@ -369,7 +391,9 @@ private[sql] abstract class BaseWriterContainer(
   }
 
   def abortJob(): Unit = {
-    outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
+    if (outputCommitter != null) {
+      outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
+    }
     logError(s"Job $jobId aborted.")
   }
 }
@@ -390,6 +414,7 @@ private[sql] class DefaultWriterContainer(
 
   override def commitTask(): Unit = {
     try {
+      assert(writer != null, "OutputWriter instance should have been initialized")
       writer.close()
       super.commitTask()
     } catch {
@@ -401,7 +426,9 @@ private[sql] class DefaultWriterContainer(
 
   override def abortTask(): Unit = {
     try {
-      writer.close()
+      if (writer != null) {
+        writer.close()
+      }
     } finally {
       super.abortTask()
     }
@@ -445,6 +472,7 @@ private[sql] class DynamicPartitionWriterContainer(
   override def commitTask(): Unit = {
     try {
       outputWriters.values.foreach(_.close())
+      outputWriters.clear()
       super.commitTask()
     } catch { case cause: Throwable =>
       super.abortTask()
@@ -455,6 +483,7 @@ private[sql] class DynamicPartitionWriterContainer(
   override def abortTask(): Unit = {
     try {
       outputWriters.values.foreach(_.close())
+      outputWriters.clear()
     } finally {
       super.abortTask()
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 8787663a98..76469d7a3d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -594,4 +594,19 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest {
       checkAnswer(read.format("parquet").load(path), df)
     }
   }
+
+  test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") {
+    withTempPath { dir =>
+      intercept[AnalysisException] {
+        // Parquet doesn't allow field names with spaces.  Here we are intentionally making an
+        // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger
+        // the bug.  Please refer to spark-8079 for more details.
+        range(1, 10)
+          .withColumnRenamed("id", "a b")
+          .write
+          .format("parquet")
+          .save(dir.getCanonicalPath)
+      }
+    }
+  }
 }
-- 
GitLab