Commit f313117b authored by Reynold Xin's avatar Reynold Xin Committed by Cheng Lian
[SPARK-18012][SQL] Simplify WriterContainer

## What changes were proposed in this pull request?
This patch refactors WriterContainer to simplify the logic and make control flow more obvious. The previous code setup made it pretty difficult to track the actual dependencies on variables and setups because the driver side and the executor side were using the same set of variables.

## How was this patch tested?
N/A - this should be covered by existing tests.

Author: Reynold Xin <>

Closes #15551 from rxin/writercontainer-refactor.
......@@ -20,18 +20,12 @@ package org.apache.spark.sql.execution.datasources
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.spark._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.internal.SQLConf
* A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending.
......@@ -40,20 +34,6 @@ import org.apache.spark.sql.internal.SQLConf
* implementation of [[HadoopFsRelation]] should use this UUID together with task id to generate
* unique file path for each task output file. This UUID is passed to executor side via a
* property named `spark.sql.sources.writeJobUUID`.
* Different writer containers, [[DefaultWriterContainer]] and [[DynamicPartitionWriterContainer]]
* are used to write to normal tables and tables with dynamic partitions.
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
* preparation work for the write job to be issued.
* 2. Issues a write job consisting of one or more executor side tasks, each of which writes all
* rows within an RDD partition.
* 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any
* exception is thrown during task commitment, also aborts that task.
* 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is
* thrown during job commitment, also aborts the job.
case class InsertIntoHadoopFsRelationCommand(
outputPath: Path,
......@@ -103,52 +83,17 @@ case class InsertIntoHadoopFsRelationCommand(
val isAppend = pathExists && (mode == SaveMode.Append)
if (doInsertion) {
val job = Job.getInstance(hadoopConf)
FileOutputFormat.setOutputPath(job, qualifiedOutputPath)
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = query.output.filterNot(partitionSet.contains)
val queryExecution = Dataset.ofRows(sparkSession, query).queryExecution
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
val relation =
fileFormat.prepareWrite(sparkSession, _, options, dataColumns.toStructType),
val writerContainer = if (partitionColumns.isEmpty && bucketSpec.isEmpty) {
new DefaultWriterContainer(relation, job, isAppend)
} else {
new DynamicPartitionWriterContainer(
partitionColumns = partitionColumns,
dataColumns = dataColumns,
inputSchema = query.output,
// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
try {
sparkSession.sparkContext.runJob(queryExecution.toRdd, writerContainer.writeRows _)
} catch { case cause: Throwable =>
logError("Aborting job.", cause)
throw new SparkException("Job aborted.", cause)
} else {
logInfo("Skipping insertion into a relation that already exists.")
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* See the License for the specific language governing permissions and
* limitations under the License.
package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
 * A container for all the details required when writing to a table.
 *
 * An instance is handed to the writer containers in this file, which read its fields as
 * described below.
 *
 * @param sparkSession the [[SparkSession]] associated with the write.
 *                     NOTE(review): not read by the writer containers shown in this file;
 *                     confirm whether it is still needed.
 * @param dataSchema schema handed to `OutputWriterFactory.newInstance` when a new
 *                   [[OutputWriter]] is created.
 * @param path output location as a string; the writer containers use it as `outputPath`, i.e.
 *             the fallback work path and the path given to a user-specified
 *             `FileOutputCommitter`.
 * @param prepareJobForWrite callback invoked with the Hadoop `Job` during driver side setup;
 *                           it returns the [[OutputWriterFactory]] used to create writers and,
 *                           per the comments in `driverSideSetup`, may decorate the job
 *                           configuration.
 * @param bucketSpec optional bucketing specification, read by
 *                   [[DynamicPartitionWriterContainer]].
 */
private[datasources] case class WriteRelation(
sparkSession: SparkSession,
dataSchema: StructType,
path: String,
prepareJobForWrite: Job => OutputWriterFactory,
bucketSpec: Option[BucketSpec])
object WriterContainer {
val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID"
private[datasources] abstract class BaseWriterContainer(
@transient val relation: WriteRelation,
@transient private val job: Job,
isAppend: Boolean)
extends Logging with Serializable {
protected val dataSchema = relation.dataSchema
protected val serializableConf =
new SerializableConfiguration(job.getConfiguration)
// This UUID is used to avoid output file name collision between different appending write jobs.
// These jobs may belong to different SparkContext instances. Concrete data source implementations
// may use this UUID to generate unique file names (e.g., `part-r-<task-id>-<job-uuid>.parquet`).
// The reason why this ID is used to identify a job rather than a single task output file is
// that, speculative tasks must generate the same output file name as the original task.
private val uniqueWriteJobId = UUID.randomUUID()
// This is only used on driver side.
@transient private val jobContext: JobContext = job
// The following fields are initialized and used on both driver and executor side.
@transient protected var outputCommitter: OutputCommitter = _
@transient private var jobId: JobID = _
@transient private var taskId: TaskID = _
@transient private var taskAttemptId: TaskAttemptID = _
@transient protected var taskAttemptContext: TaskAttemptContext = _
protected val outputPath: String = relation.path
protected var outputWriterFactory: OutputWriterFactory = _
private var outputFormatClass: Class[_ <: OutputFormat[_, _]] = _
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit
def driverSideSetup(): Unit = {
setupIDs(0, 0, 0)
// This UUID is sent to executor side together with the serialized `Configuration` object within
// the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate
// unique task output files.
job.getConfiguration.set(WriterContainer.DATASOURCE_WRITEJOBUUID, uniqueWriteJobId.toString)
// Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor
// clones the Configuration object passed in. If we initialize the TaskAttemptContext first,
// configurations made in prepareJobForWrite(job) are not populated into the TaskAttemptContext.
// Also, the `prepareJobForWrite` call must happen before initializing output format and output
// committer, since their initialization involves the job configuration, which can be potentially
// decorated in `prepareJobForWrite`.
outputWriterFactory = relation.prepareJobForWrite(job)
taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId)
outputFormatClass = job.getOutputFormatClass
outputCommitter = newOutputCommitter(taskAttemptContext)
def executorSideSetup(taskContext: TaskContext): Unit = {
setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber())
taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId)
outputCommitter = newOutputCommitter(taskAttemptContext)
protected def getWorkPath: String = {
outputCommitter match {
// FileOutputCommitter writes to a temporary location returned by `getWorkPath`.
case f: MapReduceFileOutputCommitter => f.getWorkPath.toString
case _ => outputPath
protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = {
try {
outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext)
} catch {
case e: org.apache.hadoop.fs.FileAlreadyExistsException =>
if (outputCommitter.getClass.getName.contains("Direct")) {
// SPARK-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry
// attempts, the task will fail because the output file is created from a prior attempt.
// This often means the most visible error to the user is misleading. Augment the error
// to tell the user to look for the actual error.
throw new SparkException("The output file already exists but this could be due to a " +
"failure from an earlier attempt. Look through the earlier logs or stage page for " +
"the first error.\n File exists error: " + e, e)
} else {
throw e
private def newOutputCommitter(context: TaskAttemptContext): OutputCommitter = {
val defaultOutputCommitter = outputFormatClass.newInstance().getOutputCommitter(context)
if (isAppend) {
// If we are appending data to an existing dir, we will only use the output committer
// associated with the file output format since it is not safe to use a custom
// committer for appending. For example, in S3, direct parquet output committer may
// leave partial data in the destination dir when the appending job fails.
// See SPARK-8578 for more details
s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " +
"for appending.")
} else {
val configuration = context.getConfiguration
val committerClass = configuration.getClass(
SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter])
Option(committerClass).map { clazz =>
logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}")
// Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat
// has an associated output committer. To override this output committer,
// we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS.
// If a data source needs to override the output committer, it needs to set the
// output committer in prepareForWrite method.
if (classOf[MapReduceFileOutputCommitter].isAssignableFrom(clazz)) {
// The specified output committer is a FileOutputCommitter.
// So, we will use the FileOutputCommitter-specified constructor.
val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
ctor.newInstance(new Path(outputPath), context)
} else {
// The specified output committer is just an OutputCommitter.
// So, we will use the no-argument constructor.
val ctor = clazz.getDeclaredConstructor()
}.getOrElse {
// If output committer class is not set, we will use the one associated with the
// file output format.
s"Using output committer class ${defaultOutputCommitter.getClass.getCanonicalName}")
private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
this.jobId = SparkHadoopWriter.createJobID(new Date, jobId)
this.taskId = new TaskID(this.jobId, TaskType.MAP, splitId)
this.taskAttemptId = new TaskAttemptID(taskId, attemptId)
private def setupConf(): Unit = {
serializableConf.value.set("", jobId.toString)
serializableConf.value.set("", taskAttemptId.getTaskID.toString)
serializableConf.value.set("", taskAttemptId.toString)
serializableConf.value.setBoolean("", true)
serializableConf.value.setInt("mapred.task.partition", 0)
def commitTask(): Unit = {
SparkHadoopMapRedUtil.commitTask(outputCommitter, taskAttemptContext, jobId.getId, taskId.getId)
def abortTask(): Unit = {
if (outputCommitter != null) {
logError(s"Task attempt $taskAttemptId aborted.")
def commitJob(): Unit = {
logInfo(s"Job $jobId committed.")
def abortJob(): Unit = {
if (outputCommitter != null) {
outputCommitter.abortJob(jobContext, JobStatus.State.FAILED)
logError(s"Job $jobId aborted.")
* A writer that writes all of the rows in a partition to a single file.
private[datasources] class DefaultWriterContainer(
relation: WriteRelation,
job: Job,
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
var writer = newOutputWriter(getWorkPath)
// If anything below fails, we should abort the task.
try {
Utils.tryWithSafeFinallyAndFailureCallbacks {
while (iterator.hasNext) {
val internalRow =
}(catchBlock = abortTask())
} catch {
case t: Throwable =>
throw new SparkException("Task failed while writing rows", t)
def commitTask(): Unit = {
try {
if (writer != null) {
writer = null
} catch {
case cause: Throwable =>
// This exception will be handled in `InsertIntoHadoopFsRelation.insert$writeRows`, and
// will cause `abortTask()` to be invoked.
throw new RuntimeException("Failed to commit task", cause)
def abortTask(): Unit = {
try {
if (writer != null) {
} finally {
* A writer that dynamically opens files based on the given partition columns. Internally this is
* done by maintaining a HashMap of open files until `maxFiles` is reached. If this occurs, the
* writer externally sorts the remaining rows and then writes them out one file at a time.
private[datasources] class DynamicPartitionWriterContainer(
relation: WriteRelation,
job: Job,
partitionColumns: Seq[Attribute],
dataColumns: Seq[Attribute],
inputSchema: Seq[Attribute],
defaultPartitionName: String,
maxOpenFiles: Int,
isAppend: Boolean)
extends BaseWriterContainer(relation, job, isAppend) {
private val bucketSpec = relation.bucketSpec
private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap {
spec => => inputSchema.find( == c).get)
private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap {
spec => => inputSchema.find( == c).get)
private def bucketIdExpression: Option[Expression] = { spec =>
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
// Expressions that given a partition key build a string like: col1=val/col2=val/...
private def partitionStringExpression: Seq[Expression] = {
partitionColumns.zipWithIndex.flatMap { case (c, i) =>
val escaped =
PartitioningUtils.escapePathName _,
Seq(Cast(c, StringType)),
val str = If(IsNull(c), Literal(defaultPartitionName), escaped)
val partitionName = Literal( + "=") :: str :: Nil
if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName
private def getBucketIdFromKey(key: InternalRow): Option[Int] = { _ =>
* Opens and returns a new OutputWriter given a partition key and optional bucket id.
* If bucket id is specified, we will append it to the end of the file name, but before the
* file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
private def newOutputWriter(
key: InternalRow,
getPartitionString: UnsafeProjection): OutputWriter = {
val path = if (partitionColumns.nonEmpty) {
val partitionPath = getPartitionString(key).getString(0)
new Path(getWorkPath, partitionPath).toString
} else {
val bucketId = getBucketIdFromKey(key)
val newWriter = super.newOutputWriter(path, bucketId)
def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
val sortingExpressions: Seq[Expression] = partitionColumns ++ bucketIdExpression ++ sortColumns
val getSortingKey = UnsafeProjection.create(sortingExpressions, inputSchema)
val sortingKeySchema = StructType( {
case a: Attribute => StructField(, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except bucket id.
case _ => StructField("bucketId", IntegerType, nullable = false)
// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema)
// Returns the partition path given a partition key.
val getPartitionString =
UnsafeProjection.create(Concat(partitionStringExpression) :: Nil, partitionColumns)
// Sorts the data before write, so that we only need one writer at the same time.
// TODO: inject a local sort operator in planning.
val sorter = new UnsafeKVExternalSorter(
while (iterator.hasNext) {
val currentRow =
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
logInfo(s"Sorting complete. Writing out partition files one at a time.")
val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length) {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
val sortedIterator = sorter.sortedIterator()
// If anything below fails, we should abort the task.
var currentWriter: OutputWriter = null
try {
Utils.tryWithSafeFinallyAndFailureCallbacks {
var currentKey: UnsafeRow = null
while ( {
val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
if (currentKey != nextKey) {
if (currentWriter != null) {
currentWriter = null
currentKey = nextKey.copy()
logDebug(s"Writing partition: $currentKey")
currentWriter = newOutputWriter(currentKey, getPartitionString)
if (currentWriter != null) {
currentWriter = null
}(catchBlock = {
if (currentWriter != null) {
} catch {
case t: Throwable =>
throw new SparkException("Task failed while writing rows", t)
......@@ -339,13 +339,6 @@ object SQLConf {
.doc("The maximum number of concurrent files to open before falling back on sorting when " +
"writing out files using dynamic partitioning.")
val BUCKETING_ENABLED = SQLConfigBuilder("spark.sql.sources.bucketing.enabled")
.doc("When false, we will treat bucketed table as normal table")
......@@ -733,8 +726,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def partitionColumnTypeInferenceEnabled: Boolean =
def partitionMaxFiles: Int = getConf(PARTITION_MAX_FILES)
def parallelPartitionDiscoveryThreshold: Int =
