From 3c2fc19d478256f8dc0ae7219fdd188030218c07 Mon Sep 17 00:00:00 2001
From: Xingbo Jiang <xingbo.jiang@databricks.com>
Date: Fri, 30 Jun 2017 20:30:26 +0800
Subject: [PATCH] [SPARK-18294][CORE] Implement commit protocol to support
 `mapred` package's committer

## What changes were proposed in this pull request?

This PR makes the following changes:

- Implement a new commit protocol `HadoopMapRedCommitProtocol` which supports the old `mapred` package's committers;
- Refactor `SparkHadoopWriter` and `SparkHadoopMapReduceWriter` into a single `SparkHadoopWriter`, so that writes through both the `mapred` and `mapreduce` APIs share one code path and a lot of duplicated code is removed.

After this change, it should be straightforward to support committers from both the old and the new Hadoop APIs at a high level; the sketch below illustrates the old-API path.
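
The following sketch is illustrative only and is not part of this patch: the `LoggingCommitter` class, the output path, and the key/value types are made up. It shows the kind of old-API committer configuration that `HadoopMapRedCommitProtocol` now picks up via `JobConf.getOutputCommitter` when `saveAsHadoopDataset` is called:

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, JobContext, TextOutputFormat}
import org.apache.spark.{SparkConf, SparkContext}

// A trivial old-API committer; any org.apache.hadoop.mapred.OutputCommitter would work here.
class LoggingCommitter extends FileOutputCommitter {
  override def commitJob(context: JobContext): Unit = {
    println("commitJob called through the old mapred API")
    super.commitJob(context)
  }
}

object OldApiCommitterExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("old-api-committer").setMaster("local[2]"))

    val conf = new JobConf(sc.hadoopConfiguration)
    conf.setOutputKeyClass(classOf[Text])
    conf.setOutputValueClass(classOf[IntWritable])
    conf.setOutputFormat(classOf[TextOutputFormat[Text, IntWritable]])
    FileOutputFormat.setOutputPath(conf, new Path("/tmp/spark-old-api-committer"))  // illustrative path
    conf.setOutputCommitter(classOf[LoggingCommitter])  // custom committer from the old mapred API

    sc.parallelize(Seq(("a", 1), ("b", 2)))
      .map { case (k, v) => (new Text(k), new IntWritable(v)) }
      .saveAsHadoopDataset(conf)  // now routed through SparkHadoopWriter + HadoopMapRedWriteConfigUtil

    sc.stop()
  }
}
```

For `saveAsNewAPIHadoopDataset` the flow is the same, except that `HadoopMapReduceWriteConfigUtil` and the existing `HadoopMapReduceCommitProtocol` are used.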

## How was this patch tested?
No major behavior change; covered by the existing test cases.

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Closes #18438 from jiangxb1987/SparkHadoopWriter.
---
 .../io/HadoopMapRedCommitProtocol.scala       |  36 ++
 .../internal/io/HadoopWriteConfigUtil.scala   |  79 ++++
 .../io/SparkHadoopMapReduceWriter.scala       | 181 --------
 .../spark/internal/io/SparkHadoopWriter.scala | 393 ++++++++++++++----
 .../apache/spark/rdd/PairRDDFunctions.scala   |  72 +---
 .../spark/rdd/PairRDDFunctionsSuite.scala     |   2 +-
 .../OutputCommitCoordinatorSuite.scala        |  35 +-
 7 files changed, 461 insertions(+), 337 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala
 create mode 100644 core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala
 delete mode 100644 core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala

diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala
new file mode 100644
index 0000000000..ddbd624b38
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapRedCommitProtocol.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.mapreduce.{TaskAttemptContext => NewTaskAttemptContext}
+
+/**
+ * A [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter
+ * (from the old mapred API).
+ *
+ * Unlike Hadoop's OutputCommitter, this implementation is serializable.
+ */
+class HadoopMapRedCommitProtocol(jobId: String, path: String)
+  extends HadoopMapReduceCommitProtocol(jobId, path) {
+
+  override def setupCommitter(context: NewTaskAttemptContext): OutputCommitter = {
+    val config = context.getConfiguration.asInstanceOf[JobConf]
+    config.getOutputCommitter
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala
new file mode 100644
index 0000000000..9b987e0e1b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopWriteConfigUtil.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.internal.io
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.mapreduce._
+
+import org.apache.spark.SparkConf
+
+/**
+ * Interface for creating the output format/committer/writer used when saving an RDD with a Hadoop
+ * OutputFormat (from either the old mapred API or the new mapreduce API).
+ *
+ * Notes:
+ * 1. Implementations should throw [[IllegalArgumentException]] when the wrong Hadoop API is
+ *    referenced;
+ * 2. Implementations must be serializable, as the instance instantiated on the driver
+ *    will be used for tasks on executors;
+ * 3. Implementations should have a constructor with exactly one argument:
+ *    (conf: SerializableConfiguration) or (conf: SerializableJobConf).
+ */
+abstract class HadoopWriteConfigUtil[K, V: ClassTag] extends Serializable {
+
+  // --------------------------------------------------------------------------
+  // Create JobContext/TaskAttemptContext
+  // --------------------------------------------------------------------------
+
+  def createJobContext(jobTrackerId: String, jobId: Int): JobContext
+
+  def createTaskAttemptContext(
+      jobTrackerId: String,
+      jobId: Int,
+      splitId: Int,
+      taskAttemptId: Int): TaskAttemptContext
+
+  // --------------------------------------------------------------------------
+  // Create committer
+  // --------------------------------------------------------------------------
+
+  def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol
+
+  // --------------------------------------------------------------------------
+  // Create writer
+  // --------------------------------------------------------------------------
+
+  def initWriter(taskContext: TaskAttemptContext, splitId: Int): Unit
+
+  def write(pair: (K, V)): Unit
+
+  def closeWriter(taskContext: TaskAttemptContext): Unit
+
+  // --------------------------------------------------------------------------
+  // Create OutputFormat
+  // --------------------------------------------------------------------------
+
+  def initOutputFormat(jobContext: JobContext): Unit
+
+  // --------------------------------------------------------------------------
+  // Verify hadoop config
+  // --------------------------------------------------------------------------
+
+  def assertConf(jobContext: JobContext, conf: SparkConf): Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala
deleted file mode 100644
index 376ff9bb19..0000000000
--- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala
+++ /dev/null
@@ -1,181 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.internal.io
-
-import java.text.SimpleDateFormat
-import java.util.{Date, Locale}
-
-import scala.reflect.ClassTag
-import scala.util.DynamicVariable
-
-import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapred.{JobConf, JobID}
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
-
-import org.apache.spark.{SparkConf, SparkException, TaskContext}
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.OutputMetrics
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
-import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
-
-/**
- * A helper object that saves an RDD using a Hadoop OutputFormat
- * (from the newer mapreduce API, not the old mapred API).
- */
-private[spark]
-object SparkHadoopMapReduceWriter extends Logging {
-
-  /**
-   * Basic work flow of this command is:
-   * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to
-   *    be issued.
-   * 2. Issues a write job consists of one or more executor side tasks, each of which writes all
-   *    rows within an RDD partition.
-   * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task;  If any
-   *    exception is thrown during task commitment, also aborts that task.
-   * 4. If all tasks are committed, commit the job, otherwise aborts the job;  If any exception is
-   *    thrown during job commitment, also aborts the job.
-   */
-  def write[K, V: ClassTag](
-      rdd: RDD[(K, V)],
-      hadoopConf: Configuration): Unit = {
-    // Extract context and configuration from RDD.
-    val sparkContext = rdd.context
-    val stageId = rdd.id
-    val sparkConf = rdd.conf
-    val conf = new SerializableConfiguration(hadoopConf)
-
-    // Set up a job.
-    val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date())
-    val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0)
-    val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId)
-    val format = jobContext.getOutputFormatClass
-
-    if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) {
-      // FileOutputFormat ignores the filesystem parameter
-      val jobFormat = format.newInstance
-      jobFormat.checkOutputSpecs(jobContext)
-    }
-
-    val committer = FileCommitProtocol.instantiate(
-      className = classOf[HadoopMapReduceCommitProtocol].getName,
-      jobId = stageId.toString,
-      outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"),
-      isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol]
-    committer.setupJob(jobContext)
-
-    // Try to write all RDD partitions as a Hadoop OutputFormat.
-    try {
-      val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => {
-        executeTask(
-          context = context,
-          jobTrackerId = jobTrackerId,
-          sparkStageId = context.stageId,
-          sparkPartitionId = context.partitionId,
-          sparkAttemptNumber = context.attemptNumber,
-          committer = committer,
-          hadoopConf = conf.value,
-          outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]],
-          iterator = iter)
-      })
-
-      committer.commitJob(jobContext, ret)
-      logInfo(s"Job ${jobContext.getJobID} committed.")
-    } catch {
-      case cause: Throwable =>
-        logError(s"Aborting job ${jobContext.getJobID}.", cause)
-        committer.abortJob(jobContext)
-        throw new SparkException("Job aborted.", cause)
-    }
-  }
-
-  /** Write an RDD partition out in a single Spark task. */
-  private def executeTask[K, V: ClassTag](
-      context: TaskContext,
-      jobTrackerId: String,
-      sparkStageId: Int,
-      sparkPartitionId: Int,
-      sparkAttemptNumber: Int,
-      committer: FileCommitProtocol,
-      hadoopConf: Configuration,
-      outputFormat: Class[_ <: OutputFormat[K, V]],
-      iterator: Iterator[(K, V)]): TaskCommitMessage = {
-    // Set up a task.
-    val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE,
-      sparkPartitionId, sparkAttemptNumber)
-    val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId)
-    committer.setupTask(taskContext)
-
-    val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
-
-    // Initiate the writer.
-    val taskFormat = outputFormat.newInstance()
-    // If OutputFormat is Configurable, we should set conf to it.
-    taskFormat match {
-      case c: Configurable => c.setConf(hadoopConf)
-      case _ => ()
-    }
-    var writer = taskFormat.getRecordWriter(taskContext)
-      .asInstanceOf[RecordWriter[K, V]]
-    require(writer != null, "Unable to obtain RecordWriter")
-    var recordsWritten = 0L
-
-    // Write all rows in RDD partition.
-    try {
-      val ret = Utils.tryWithSafeFinallyAndFailureCallbacks {
-        // Write rows out, release resource and commit the task.
-        while (iterator.hasNext) {
-          val pair = iterator.next()
-          writer.write(pair._1, pair._2)
-
-          // Update bytes written metric every few records
-          SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten)
-          recordsWritten += 1
-        }
-        if (writer != null) {
-          writer.close(taskContext)
-          writer = null
-        }
-        committer.commitTask(taskContext)
-      }(catchBlock = {
-        // If there is an error, release resource and then abort the task.
-        try {
-          if (writer != null) {
-            writer.close(taskContext)
-            writer = null
-          }
-        } finally {
-          committer.abortTask(taskContext)
-          logError(s"Task ${taskContext.getTaskAttemptID} aborted.")
-        }
-      })
-
-      outputMetrics.setBytesWritten(callback())
-      outputMetrics.setRecordsWritten(recordsWritten)
-
-      ret
-    } catch {
-      case t: Throwable =>
-        throw new SparkException("Task failed while writing rows", t)
-    }
-  }
-}
diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala
index acc9c38571..7d846f9354 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala
@@ -17,143 +17,374 @@
 
 package org.apache.spark.internal.io
 
-import java.io.IOException
-import java.text.{NumberFormat, SimpleDateFormat}
+import java.text.NumberFormat
 import java.util.{Date, Locale}
 
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.mapred._
-import org.apache.hadoop.mapreduce.TaskType
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext,
+OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
+TaskAttemptContext => NewTaskAttemptContext, TaskAttemptID => NewTaskAttemptID, TaskType}
+import org.apache.hadoop.mapreduce.task.{TaskAttemptContextImpl => NewTaskAttemptContextImpl}
 
-import org.apache.spark.SerializableWritable
+import org.apache.spark.{SerializableWritable, SparkConf, SparkException, TaskContext}
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
-import org.apache.spark.mapred.SparkHadoopMapRedUtil
-import org.apache.spark.rdd.HadoopRDD
-import org.apache.spark.util.SerializableJobConf
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.rdd.{HadoopRDD, RDD}
+import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils}
 
 /**
- * Internal helper class that saves an RDD using a Hadoop OutputFormat.
- *
- * Saves the RDD using a JobConf, which should contain an output key class, an output value class,
- * a filename to write to, etc, exactly like in a Hadoop MapReduce job.
+ * A helper object that saves an RDD using a Hadoop OutputFormat.
+ */
+private[spark]
+object SparkHadoopWriter extends Logging {
+  import SparkHadoopWriterUtils._
+
+  /**
+   * The basic workflow of this command is:
+   * 1. Driver-side setup: prepare the data source and the Hadoop configuration for the write job
+   *    to be issued.
+   * 2. Issue a write job consisting of one or more executor-side tasks, each of which writes all
+   *    rows within an RDD partition.
+   * 3. If no exception is thrown in a task, commit that task; otherwise abort it. If any
+   *    exception is thrown while committing the task, also abort it.
+   * 4. If all tasks commit successfully, commit the job; otherwise abort it. If any exception is
+   *    thrown while committing the job, also abort the job.
+   */
+  def write[K, V: ClassTag](
+      rdd: RDD[(K, V)],
+      config: HadoopWriteConfigUtil[K, V]): Unit = {
+    // Extract context and configuration from RDD.
+    val sparkContext = rdd.context
+    val stageId = rdd.id
+
+    // Set up a job.
+    val jobTrackerId = createJobTrackerID(new Date())
+    val jobContext = config.createJobContext(jobTrackerId, stageId)
+    config.initOutputFormat(jobContext)
+
+    // Assert the output format/key/value class is set in JobConf.
+    config.assertConf(jobContext, rdd.conf)
+
+    val committer = config.createCommitter(stageId)
+    committer.setupJob(jobContext)
+
+    // Try to write all RDD partitions as a Hadoop OutputFormat.
+    try {
+      val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => {
+        executeTask(
+          context = context,
+          config = config,
+          jobTrackerId = jobTrackerId,
+          sparkStageId = context.stageId,
+          sparkPartitionId = context.partitionId,
+          sparkAttemptNumber = context.attemptNumber,
+          committer = committer,
+          iterator = iter)
+      })
+
+      committer.commitJob(jobContext, ret)
+      logInfo(s"Job ${jobContext.getJobID} committed.")
+    } catch {
+      case cause: Throwable =>
+        logError(s"Aborting job ${jobContext.getJobID}.", cause)
+        committer.abortJob(jobContext)
+        throw new SparkException("Job aborted.", cause)
+    }
+  }
+
+  /** Write an RDD partition out in a single Spark task. */
+  private def executeTask[K, V: ClassTag](
+      context: TaskContext,
+      config: HadoopWriteConfigUtil[K, V],
+      jobTrackerId: String,
+      sparkStageId: Int,
+      sparkPartitionId: Int,
+      sparkAttemptNumber: Int,
+      committer: FileCommitProtocol,
+      iterator: Iterator[(K, V)]): TaskCommitMessage = {
+    // Set up a task.
+    val taskContext = config.createTaskAttemptContext(
+      jobTrackerId, sparkStageId, sparkPartitionId, sparkAttemptNumber)
+    committer.setupTask(taskContext)
+
+    val (outputMetrics, callback) = initHadoopOutputMetrics(context)
+
+    // Initiate the writer.
+    config.initWriter(taskContext, sparkPartitionId)
+    var recordsWritten = 0L
+
+    // Write all rows in RDD partition.
+    try {
+      val ret = Utils.tryWithSafeFinallyAndFailureCallbacks {
+        while (iterator.hasNext) {
+          val pair = iterator.next()
+          config.write(pair)
+
+          // Update bytes written metric every few records
+          maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten)
+          recordsWritten += 1
+        }
+
+        config.closeWriter(taskContext)
+        committer.commitTask(taskContext)
+      }(catchBlock = {
+        // If there is an error, release resource and then abort the task.
+        try {
+          config.closeWriter(taskContext)
+        } finally {
+          committer.abortTask(taskContext)
+          logError(s"Task ${taskContext.getTaskAttemptID} aborted.")
+        }
+      })
+
+      outputMetrics.setBytesWritten(callback())
+      outputMetrics.setRecordsWritten(recordsWritten)
+
+      ret
+    } catch {
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
+    }
+  }
+}
+
+/**
+ * A helper class that reads the old `mapred` JobConf and creates the output format/committer/writer.
  */
 private[spark]
-class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable {
+class HadoopMapRedWriteConfigUtil[K, V: ClassTag](conf: SerializableJobConf)
+  extends HadoopWriteConfigUtil[K, V] with Logging {
 
-  private val now = new Date()
-  private val conf = new SerializableJobConf(jobConf)
+  private var outputFormat: Class[_ <: OutputFormat[K, V]] = null
+  private var writer: RecordWriter[K, V] = null
 
-  private var jobID = 0
-  private var splitID = 0
-  private var attemptID = 0
-  private var jID: SerializableWritable[JobID] = null
-  private var taID: SerializableWritable[TaskAttemptID] = null
+  private def getConf: JobConf = conf.value
 
-  @transient private var writer: RecordWriter[AnyRef, AnyRef] = null
-  @transient private var format: OutputFormat[AnyRef, AnyRef] = null
-  @transient private var committer: OutputCommitter = null
-  @transient private var jobContext: JobContext = null
-  @transient private var taskContext: TaskAttemptContext = null
+  // --------------------------------------------------------------------------
+  // Create JobContext/TaskAttemptContext
+  // --------------------------------------------------------------------------
 
-  def preSetup() {
-    setIDs(0, 0, 0)
-    HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value)
+  override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = {
+    val jobAttemptId = new SerializableWritable(new JobID(jobTrackerId, jobId))
+    new JobContextImpl(getConf, jobAttemptId.value)
+  }
 
-    val jCtxt = getJobContext()
-    getOutputCommitter().setupJob(jCtxt)
+  override def createTaskAttemptContext(
+      jobTrackerId: String,
+      jobId: Int,
+      splitId: Int,
+      taskAttemptId: Int): NewTaskAttemptContext = {
+    // Update JobConf.
+    HadoopRDD.addLocalConfiguration(jobTrackerId, jobId, splitId, taskAttemptId, conf.value)
+    // Create taskContext.
+    val attemptId = new TaskAttemptID(jobTrackerId, jobId, TaskType.MAP, splitId, taskAttemptId)
+    new TaskAttemptContextImpl(getConf, attemptId)
   }
 
+  // --------------------------------------------------------------------------
+  // Create committer
+  // --------------------------------------------------------------------------
 
-  def setup(jobid: Int, splitid: Int, attemptid: Int) {
-    setIDs(jobid, splitid, attemptid)
-    HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now),
-      jobid, splitID, attemptID, conf.value)
+  override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = {
+    // Update JobConf.
+    HadoopRDD.addLocalConfiguration("", 0, 0, 0, getConf)
+    // Create commit protocol.
+    FileCommitProtocol.instantiate(
+      className = classOf[HadoopMapRedCommitProtocol].getName,
+      jobId = jobId.toString,
+      outputPath = getConf.get("mapred.output.dir"),
+      isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol]
   }
 
-  def open() {
+  // --------------------------------------------------------------------------
+  // Create writer
+  // --------------------------------------------------------------------------
+
+  override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = {
     val numfmt = NumberFormat.getInstance(Locale.US)
     numfmt.setMinimumIntegerDigits(5)
     numfmt.setGroupingUsed(false)
 
-    val outputName = "part-"  + numfmt.format(splitID)
-    val path = FileOutputFormat.getOutputPath(conf.value)
+    val outputName = "part-" + numfmt.format(splitId)
+    val path = FileOutputFormat.getOutputPath(getConf)
     val fs: FileSystem = {
       if (path != null) {
-        path.getFileSystem(conf.value)
+        path.getFileSystem(getConf)
       } else {
-        FileSystem.get(conf.value)
+        FileSystem.get(getConf)
       }
     }
 
-    getOutputCommitter().setupTask(getTaskContext())
-    writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
+    writer = getConf.getOutputFormat
+      .getRecordWriter(fs, getConf, outputName, Reporter.NULL)
+      .asInstanceOf[RecordWriter[K, V]]
+
+    require(writer != null, "Unable to obtain RecordWriter")
   }
 
-  def write(key: AnyRef, value: AnyRef) {
+  override def write(pair: (K, V)): Unit = {
+    require(writer != null, "Must call createWriter before write.")
+    writer.write(pair._1, pair._2)
+  }
+
+  override def closeWriter(taskContext: NewTaskAttemptContext): Unit = {
     if (writer != null) {
-      writer.write(key, value)
-    } else {
-      throw new IOException("Writer is null, open() has not been called")
+      writer.close(Reporter.NULL)
     }
   }
 
-  def close() {
-    writer.close(Reporter.NULL)
-  }
+  // --------------------------------------------------------------------------
+  // Create OutputFormat
+  // --------------------------------------------------------------------------
 
-  def commit() {
-    SparkHadoopMapRedUtil.commitTask(getOutputCommitter(), getTaskContext(), jobID, splitID)
+  override def initOutputFormat(jobContext: NewJobContext): Unit = {
+    if (outputFormat == null) {
+      outputFormat = getConf.getOutputFormat.getClass
+        .asInstanceOf[Class[_ <: OutputFormat[K, V]]]
+    }
   }
 
-  def commitJob() {
-    val cmtr = getOutputCommitter()
-    cmtr.commitJob(getJobContext())
+  private def getOutputFormat(): OutputFormat[K, V] = {
+    require(outputFormat != null, "Must call initOutputFormat first.")
+
+    outputFormat.newInstance()
   }
 
-  // ********* Private Functions *********
+  // --------------------------------------------------------------------------
+  // Verify hadoop config
+  // --------------------------------------------------------------------------
+
+  override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = {
+    val outputFormatInstance = getOutputFormat()
+    val keyClass = getConf.getOutputKeyClass
+    val valueClass = getConf.getOutputValueClass
+    if (outputFormatInstance == null) {
+      throw new SparkException("Output format class not set")
+    }
+    if (keyClass == null) {
+      throw new SparkException("Output key class not set")
+    }
+    if (valueClass == null) {
+      throw new SparkException("Output value class not set")
+    }
+    SparkHadoopUtil.get.addCredentials(getConf)
+
+    logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
+      valueClass.getSimpleName + ")")
 
-  private def getOutputFormat(): OutputFormat[AnyRef, AnyRef] = {
-    if (format == null) {
-      format = conf.value.getOutputFormat()
-        .asInstanceOf[OutputFormat[AnyRef, AnyRef]]
+    if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) {
+      // FileOutputFormat ignores the filesystem parameter
+      val ignoredFs = FileSystem.get(getConf)
+      getOutputFormat().checkOutputSpecs(ignoredFs, getConf)
     }
-    format
+  }
+}
+
+/**
+ * A helper class that reads the new `mapreduce` Configuration and creates the output
+ * format/committer/writer.
+ */
+private[spark]
+class HadoopMapReduceWriteConfigUtil[K, V: ClassTag](conf: SerializableConfiguration)
+  extends HadoopWriteConfigUtil[K, V] with Logging {
+
+  private var outputFormat: Class[_ <: NewOutputFormat[K, V]] = null
+  private var writer: NewRecordWriter[K, V] = null
+
+  private def getConf: Configuration = conf.value
+
+  // --------------------------------------------------------------------------
+  // Create JobContext/TaskAttemptContext
+  // --------------------------------------------------------------------------
+
+  override def createJobContext(jobTrackerId: String, jobId: Int): NewJobContext = {
+    val jobAttemptId = new NewTaskAttemptID(jobTrackerId, jobId, TaskType.MAP, 0, 0)
+    new NewTaskAttemptContextImpl(getConf, jobAttemptId)
+  }
+
+  override def createTaskAttemptContext(
+      jobTrackerId: String,
+      jobId: Int,
+      splitId: Int,
+      taskAttemptId: Int): NewTaskAttemptContext = {
+    val attemptId = new NewTaskAttemptID(
+      jobTrackerId, jobId, TaskType.REDUCE, splitId, taskAttemptId)
+    new NewTaskAttemptContextImpl(getConf, attemptId)
+  }
+
+  // --------------------------------------------------------------------------
+  // Create committer
+  // --------------------------------------------------------------------------
+
+  override def createCommitter(jobId: Int): HadoopMapReduceCommitProtocol = {
+    FileCommitProtocol.instantiate(
+      className = classOf[HadoopMapReduceCommitProtocol].getName,
+      jobId = jobId.toString,
+      outputPath = getConf.get("mapreduce.output.fileoutputformat.outputdir"),
+      isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol]
   }
 
-  private def getOutputCommitter(): OutputCommitter = {
-    if (committer == null) {
-      committer = conf.value.getOutputCommitter
+  // --------------------------------------------------------------------------
+  // Create writer
+  // --------------------------------------------------------------------------
+
+  override def initWriter(taskContext: NewTaskAttemptContext, splitId: Int): Unit = {
+    val taskFormat = getOutputFormat()
+    // If OutputFormat is Configurable, we should set conf to it.
+    taskFormat match {
+      case c: Configurable => c.setConf(getConf)
+      case _ => ()
     }
-    committer
+
+    writer = taskFormat.getRecordWriter(taskContext)
+      .asInstanceOf[NewRecordWriter[K, V]]
+
+    require(writer != null, "Unable to obtain RecordWriter")
+  }
+
+  override def write(pair: (K, V)): Unit = {
+    require(writer != null, "Must call createWriter before write.")
+    writer.write(pair._1, pair._2)
   }
 
-  private def getJobContext(): JobContext = {
-    if (jobContext == null) {
-      jobContext = new JobContextImpl(conf.value, jID.value)
+  override def closeWriter(taskContext: NewTaskAttemptContext): Unit = {
+    if (writer != null) {
+      writer.close(taskContext)
+      writer = null
+    } else {
+      logWarning("Writer has been closed.")
     }
-    jobContext
   }
 
-  private def getTaskContext(): TaskAttemptContext = {
-    if (taskContext == null) {
-      taskContext = newTaskAttemptContext(conf.value, taID.value)
+  // --------------------------------------------------------------------------
+  // Create OutputFormat
+  // --------------------------------------------------------------------------
+
+  override def initOutputFormat(jobContext: NewJobContext): Unit = {
+    if (outputFormat == null) {
+      outputFormat = jobContext.getOutputFormatClass
+        .asInstanceOf[Class[_ <: NewOutputFormat[K, V]]]
     }
-    taskContext
   }
 
-  protected def newTaskAttemptContext(
-      conf: JobConf,
-      attemptId: TaskAttemptID): TaskAttemptContext = {
-    new TaskAttemptContextImpl(conf, attemptId)
+  private def getOutputFormat(): NewOutputFormat[K, V] = {
+    require(outputFormat != null, "Must call initOutputFormat first.")
+
+    outputFormat.newInstance()
   }
 
-  private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
-    jobID = jobid
-    splitID = splitid
-    attemptID = attemptid
+  // --------------------------------------------------------------------------
+  // Verify hadoop config
+  // --------------------------------------------------------------------------
 
-    jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid))
-    taID = new SerializableWritable[TaskAttemptID](
-        new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID))
+  override def assertConf(jobContext: NewJobContext, conf: SparkConf): Unit = {
+    if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(conf)) {
+      getOutputFormat().checkOutputSpecs(jobContext)
+    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 58762cc083..4628fa8ba2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -27,7 +27,6 @@ import scala.reflect.ClassTag
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.io.SequenceFile.CompressionType
 import org.apache.hadoop.io.compress.CompressionCodec
 import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
@@ -36,13 +35,11 @@ import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewO
 import org.apache.spark._
 import org.apache.spark.Partitioner.defaultPartitioner
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter,
-  SparkHadoopWriterUtils}
+import org.apache.spark.internal.io._
 import org.apache.spark.internal.Logging
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf, Utils}
 import org.apache.spark.util.collection.CompactBuffer
 import org.apache.spark.util.random.StratifiedSamplingUtils
 
@@ -1082,9 +1079,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * result of using direct output committer with speculation enabled.
    */
   def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope {
-    SparkHadoopMapReduceWriter.write(
+    val config = new HadoopMapReduceWriteConfigUtil[K, V](new SerializableConfiguration(conf))
+    SparkHadoopWriter.write(
       rdd = self,
-      hadoopConf = conf)
+      config = config)
   }
 
   /**
@@ -1094,62 +1092,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * MapReduce job.
    */
   def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope {
-    // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
-    val hadoopConf = conf
-    val outputFormatInstance = hadoopConf.getOutputFormat
-    val keyClass = hadoopConf.getOutputKeyClass
-    val valueClass = hadoopConf.getOutputValueClass
-    if (outputFormatInstance == null) {
-      throw new SparkException("Output format class not set")
-    }
-    if (keyClass == null) {
-      throw new SparkException("Output key class not set")
-    }
-    if (valueClass == null) {
-      throw new SparkException("Output value class not set")
-    }
-    SparkHadoopUtil.get.addCredentials(hadoopConf)
-
-    logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
-      valueClass.getSimpleName + ")")
-
-    if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) {
-      // FileOutputFormat ignores the filesystem parameter
-      val ignoredFs = FileSystem.get(hadoopConf)
-      hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
-    }
-
-    val writer = new SparkHadoopWriter(hadoopConf)
-    writer.preSetup()
-
-    val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
-      // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
-      // around by taking a mod. We expect that no task will be attempted 2 billion times.
-      val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
-
-      val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context)
-
-      writer.setup(context.stageId, context.partitionId, taskAttemptId)
-      writer.open()
-      var recordsWritten = 0L
-
-      Utils.tryWithSafeFinallyAndFailureCallbacks {
-        while (iter.hasNext) {
-          val record = iter.next()
-          writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
-
-          // Update bytes written metric every few records
-          SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten)
-          recordsWritten += 1
-        }
-      }(finallyBlock = writer.close())
-      writer.commit()
-      outputMetrics.setBytesWritten(callback())
-      outputMetrics.setRecordsWritten(recordsWritten)
-    }
-
-    self.context.runJob(self, writeToFile)
-    writer.commitJob()
+    val config = new HadoopMapRedWriteConfigUtil[K, V](new SerializableJobConf(conf))
+    SparkHadoopWriter.write(
+      rdd = self,
+      config = config)
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 02df157be3..44dd955ce8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -561,7 +561,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
       pairs.saveAsHadoopFile(
         "ignored", pairs.keyClass, pairs.valueClass, classOf[FakeFormatWithCallback], conf)
     }
-    assert(e.getMessage contains "failed to write")
+    assert(e.getCause.getMessage contains "failed to write")
 
     assert(FakeWriterWithCallback.calledBy === "write,callback,close")
     assert(FakeWriterWithCallback.exception != null, "exception should be captured")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
index e51e6a0d3f..1579b614ea 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala
@@ -18,12 +18,14 @@
 package org.apache.spark.scheduler
 
 import java.io.File
+import java.util.Date
 import java.util.concurrent.TimeoutException
 
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID}
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.mapreduce.TaskType
 import org.mockito.Matchers
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
@@ -31,7 +33,7 @@ import org.mockito.stubbing.Answer
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark._
-import org.apache.spark.internal.io.SparkHadoopWriter
+import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapRedCommitProtocol, SparkHadoopWriterUtils}
 import org.apache.spark.rdd.{FakeOutputCommitter, RDD}
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -214,6 +216,8 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
  */
 private case class OutputCommitFunctions(tempDirPath: String) {
 
+  private val jobId = new SerializableWritable(SparkHadoopWriterUtils.createJobID(new Date, 0))
+
   // Mock output committer that simulates a successful commit (after commit is authorized)
   private def successfulOutputCommitter = new FakeOutputCommitter {
     override def commitTask(context: TaskAttemptContext): Unit = {
@@ -256,14 +260,23 @@ private case class OutputCommitFunctions(tempDirPath: String) {
     def jobConf = new JobConf {
       override def getOutputCommitter(): OutputCommitter = outputCommitter
     }
-    val sparkHadoopWriter = new SparkHadoopWriter(jobConf) {
-      override def newTaskAttemptContext(
-        conf: JobConf,
-        attemptId: TaskAttemptID): TaskAttemptContext = {
-        mock(classOf[TaskAttemptContext])
-      }
-    }
-    sparkHadoopWriter.setup(ctx.stageId, ctx.partitionId, ctx.attemptNumber)
-    sparkHadoopWriter.commit()
+
+    // Instantiate committer.
+    val committer = FileCommitProtocol.instantiate(
+      className = classOf[HadoopMapRedCommitProtocol].getName,
+      jobId = jobId.value.getId.toString,
+      outputPath = jobConf.get("mapred.output.dir"),
+      isAppend = false)
+
+    // Create TaskAttemptContext.
+    // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
+    // around by taking a mod. We expect that no task will be attempted 2 billion times.
+    val taskAttemptId = (ctx.taskAttemptId % Int.MaxValue).toInt
+    val attemptId = new TaskAttemptID(
+      new TaskID(jobId.value, TaskType.MAP, ctx.partitionId), taskAttemptId)
+    val taskContext = new TaskAttemptContextImpl(jobConf, attemptId)
+
+    committer.setupTask(taskContext)
+    committer.commitTask(taskContext)
   }
 }
-- 
GitLab