From fde6945417355ae57500b67d034c9cad4f20d240 Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian@databricks.com>
Date: Tue, 31 Mar 2015 07:48:37 +0800
Subject: [PATCH] [SPARK-6369] [SQL] Uses commit coordinator to help committing
 Hive and Parquet tables

This PR leverages the output commit coordinator introduced in #4066 to help committing Hive and Parquet tables.

This PR extracts output commit code in `SparkHadoopWriter.commit` to `SparkHadoopMapRedUtil.commitTask`, and reuses it for committing Parquet and Hive tables on executor side.

TODO

- [ ] Add tests

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/5139)
<!-- Reviewable:end -->

Author: Cheng Lian <lian@databricks.com>

Closes #5139 from liancheng/spark-6369 and squashes the following commits:

72eb628 [Cheng Lian] Fixes typo in javadoc
9a4b82b [Cheng Lian] Adds javadoc and addresses @aarondav's comments
dfdf3ef [Cheng Lian] Uses commit coordinator to help committing Hive and Parquet tables
---
 .../org/apache/spark/SparkHadoopWriter.scala  | 52 +----------
 .../spark/mapred/SparkHadoopMapRedUtil.scala  | 91 ++++++++++++++++++-
 .../sql/parquet/ParquetTableOperations.scala  | 11 ++-
 .../apache/spark/sql/parquet/newParquet.scala |  4 +-
 .../hive/execution/InsertIntoHiveTable.scala  |  1 -
 .../spark/sql/hive/hiveWriterContainers.scala | 17 +---
 6 files changed, 103 insertions(+), 73 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 6eb4537d10..2ec42d3aea 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,7 +26,6 @@ import org.apache.hadoop.mapred._
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.executor.CommitDeniedException
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.rdd.HadoopRDD
 
@@ -104,55 +103,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
   }
 
   def commit() {
-    val taCtxt = getTaskContext()
-    val cmtr = getOutputCommitter()
-
-    // Called after we have decided to commit
-    def performCommit(): Unit = {
-      try {
-        cmtr.commitTask(taCtxt)
-        logInfo (s"$taID: Committed")
-      } catch {
-        case e: IOException =>
-          logError("Error committing the output of task: " + taID.value, e)
-          cmtr.abortTask(taCtxt)
-          throw e
-      }
-    }
-
-    // First, check whether the task's output has already been committed by some other attempt
-    if (cmtr.needsTaskCommit(taCtxt)) {
-      // The task output needs to be committed, but we don't know whether some other task attempt
-      // might be racing to commit the same output partition. Therefore, coordinate with the driver
-      // in order to determine whether this attempt can commit (see SPARK-4879).
-      val shouldCoordinateWithDriver: Boolean = {
-        val sparkConf = SparkEnv.get.conf
-        // We only need to coordinate with the driver if there are multiple concurrent task
-        // attempts, which should only occur if speculation is enabled
-        val speculationEnabled = sparkConf.getBoolean("spark.speculation", false)
-        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
-        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
-      }
-      if (shouldCoordinateWithDriver) {
-        val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
-        val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID)
-        if (canCommit) {
-          performCommit()
-        } else {
-          val msg = s"$taID: Not committed because the driver did not authorize commit"
-          logInfo(msg)
-          // We need to abort the task so that the driver can reschedule new attempts, if necessary
-          cmtr.abortTask(taCtxt)
-          throw new CommitDeniedException(msg, jobID, splitID, attemptID)
-        }
-      } else {
-        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
-        performCommit()
-      }
-    } else {
-      // Some other attempt committed the output, so we do nothing and signal success
-      logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}")
-    }
+    SparkHadoopMapRedUtil.commitTask(
+      getOutputCommitter(), getTaskContext(), jobID, splitID, attemptID)
   }
 
   def commitJob() {
diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 87c2aa4810..818f7a4c8d 100644
--- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -17,9 +17,15 @@
 
 package org.apache.spark.mapred
 
+import java.io.IOException
 import java.lang.reflect.Modifier
 
-import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext}
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext}
+import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter}
+
+import org.apache.spark.executor.CommitDeniedException
+import org.apache.spark.{Logging, SparkEnv, TaskContext}
 
 private[spark]
 trait SparkHadoopMapRedUtil {
@@ -65,3 +71,86 @@ trait SparkHadoopMapRedUtil {
     }
   }
 }
+
+object SparkHadoopMapRedUtil extends Logging {
+  /**
+   * Commits a task output.  Before committing the task output, we need to know whether some other
+   * task attempt might be racing to commit the same output partition. Therefore, coordinate with
+   * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for
+   * details).
+   *
+   * Output commit coordinator is only contacted when the following two configurations are both set
+   * to `true`:
+   *
+   *  - `spark.speculation`
+   *  - `spark.hadoop.outputCommitCoordination.enabled`
+   */
+  def commitTask(
+      committer: MapReduceOutputCommitter,
+      mrTaskContext: MapReduceTaskAttemptContext,
+      jobId: Int,
+      splitId: Int,
+      attemptId: Int): Unit = {
+
+    val mrTaskAttemptID = mrTaskContext.getTaskAttemptID
+
+    // Called after we have decided to commit
+    def performCommit(): Unit = {
+      try {
+        committer.commitTask(mrTaskContext)
+        logInfo(s"$mrTaskAttemptID: Committed")
+      } catch {
+        case cause: IOException =>
+          logError(s"Error committing the output of task: $mrTaskAttemptID", cause)
+          committer.abortTask(mrTaskContext)
+          throw cause
+      }
+    }
+
+    // First, check whether the task's output has already been committed by some other attempt
+    if (committer.needsTaskCommit(mrTaskContext)) {
+      val shouldCoordinateWithDriver: Boolean = {
+        val sparkConf = SparkEnv.get.conf
+        // We only need to coordinate with the driver if there are multiple concurrent task
+        // attempts, which should only occur if speculation is enabled
+        val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false)
+        // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
+        sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+      }
+
+      if (shouldCoordinateWithDriver) {
+        val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
+        val canCommit = outputCommitCoordinator.canCommit(jobId, splitId, attemptId)
+
+        if (canCommit) {
+          performCommit()
+        } else {
+          val message =
+            s"$mrTaskAttemptID: Not committed because the driver did not authorize commit"
+          logInfo(message)
+          // We need to abort the task so that the driver can reschedule new attempts, if necessary
+          committer.abortTask(mrTaskContext)
+          throw new CommitDeniedException(message, jobId, splitId, attemptId)
+        }
+      } else {
+        // Speculation is disabled or a user has chosen to manually bypass the commit coordination
+        performCommit()
+      }
+    } else {
+      // Some other attempt committed the output, so we do nothing and signal success
+      logInfo(s"No need to commit output of task because needsTaskCommit=false: $mrTaskAttemptID")
+    }
+  }
+
+  def commitTask(
+      committer: MapReduceOutputCommitter,
+      mrTaskContext: MapReduceTaskAttemptContext,
+      sparkTaskContext: TaskContext): Unit = {
+    commitTask(
+      committer,
+      mrTaskContext,
+      sparkTaskContext.stageId(),
+      sparkTaskContext.partitionId(),
+      sparkTaskContext.attemptNumber())
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 5130d8ad5e..1c868da23e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -19,10 +19,9 @@ package org.apache.spark.sql.parquet
 
 import java.io.IOException
 import java.lang.{Long => JLong}
-import java.text.SimpleDateFormat
-import java.text.NumberFormat
+import java.text.{NumberFormat, SimpleDateFormat}
 import java.util.concurrent.{Callable, TimeUnit}
-import java.util.{ArrayList, Collections, Date, List => JList}
+import java.util.{Date, List => JList}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -43,12 +42,13 @@ import parquet.io.ParquetDecodingException
 import parquet.schema.MessageType
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _}
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.{Logging, SerializableWritable, TaskContext}
 
 /**
@@ -356,7 +356,7 @@ private[sql] case class InsertIntoParquetTable(
       } finally {
         writer.close(hadoopContext)
       }
-      committer.commitTask(hadoopContext)
+      SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context)
       1
     }
     val jobFormat = new AppendingParquetOutputFormat(taskIdOffset)
@@ -512,6 +512,7 @@ private[parquet] class FilteringParquetRowInputFormat
 
     import parquet.filter2.compat.FilterCompat.Filter
     import parquet.filter2.compat.RowGroupFilter
+
     import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.blockLocationCache
 
     val cacheMetadata = configuration.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 53f765ee26..19800ad88c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -42,6 +42,7 @@ import parquet.hadoop.{ParquetInputFormat, _}
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
 import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD}
 import org.apache.spark.sql.catalyst.expressions
@@ -669,7 +670,8 @@ private[sql] case class ParquetRelation2(
       } finally {
         writer.close(hadoopContext)
       }
-      committer.commitTask(hadoopContext)
+
+      SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context)
     }
     val jobFormat = new AppendingParquetOutputFormat(taskIdOffset)
     /* apparently we need a TaskAttemptID to construct an OutputCommitter;
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index da53d30354..cdf012b511 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -72,7 +72,6 @@ case class InsertIntoHiveTable(
     val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName
     assert(outputFileFormatClassName != null, "Output format class not set")
     conf.value.set("mapred.output.format.class", outputFileFormatClassName)
-    conf.value.setOutputCommitter(classOf[FileOutputCommitter])
 
     FileOutputFormat.setOutputPath(
       conf.value,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index ba2bf67aed..8398da2681 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.hive
 
-import java.io.IOException
 import java.text.NumberFormat
 import java.util.Date
 
@@ -118,19 +117,7 @@ private[hive] class SparkHiveWriterContainer(
   }
 
   protected def commit() {
-    if (committer.needsTaskCommit(taskContext)) {
-      try {
-        committer.commitTask(taskContext)
-        logInfo (taID + ": Committed")
-      } catch {
-        case e: IOException =>
-          logError("Error committing the output of task: " + taID.value, e)
-          committer.abortTask(taskContext)
-          throw e
-      }
-    } else {
-      logInfo("No need to commit output of task: " + taID.value)
-    }
+    SparkHadoopMapRedUtil.commitTask(committer, taskContext, jobID, splitID, attemptID)
   }
 
   private def setIDs(jobId: Int, splitId: Int, attemptId: Int) {
@@ -213,7 +200,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
       .zip(row.toSeq.takeRight(dynamicPartColNames.length))
       .map { case (col, rawVal) =>
         val string = if (rawVal == null) null else String.valueOf(rawVal)
-        val colString = 
+        val colString =
           if (string == null || string.isEmpty) {
             defaultPartName
           } else {
-- 
GitLab