Commit 7f9ec19e authored by Reynold Xin

[SPARK-18021][SQL] Refactor file name specification for data sources

## What changes were proposed in this pull request?
Currently each data source OutputWriter is responsible for specifying the entire file name for each file output. This, however, does not make any sense because we rely on file naming schemes for certain behaviors in Spark SQL, e.g. bucket id. The current approach allows individual data sources to break the implementation of bucketing.

On the flip side, we also don't want to move file naming entirely out of data sources, because different data sources do want to specify different extensions.

This patch divides file name specification into two parts: the first part is a prefix specified by the caller of OutputWriter (in WriteOutput), and the second part is the suffix that can be specified by the OutputWriter itself. Note that a side effect of this change is that now all file based data sources also support bucketing automatically.
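
To make the new contract concrete, here is a minimal sketch of what a file-based data source looks like after this change. It follows the same pattern as the LibSVM and CSV writers in the diff below; the `MyOutputWriterFactory` / `MyOutputWriter` names and the `.myformat` extension are hypothetical, not part of the patch.

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

// Hypothetical data source: the factory receives both the staging directory and the
// file name prefix from WriteOutput, so it can no longer break Spark's naming scheme.
class MyOutputWriterFactory extends OutputWriterFactory {
  override def newInstance(
      stagingDir: String,
      fileNamePrefix: String,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter = {
    new MyOutputWriter(stagingDir, fileNamePrefix, dataSchema, context)
  }
}

class MyOutputWriter(
    stagingDir: String,
    fileNamePrefix: String,
    dataSchema: StructType,
    context: TaskAttemptContext) extends OutputWriter {

  private val recordWriter: RecordWriter[NullWritable, Text] = {
    new TextOutputFormat[NullWritable, Text]() {
      override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
        // The writer only appends its own suffix; the prefix already encodes the
        // task id, the job UUID and (if any) the bucket id.
        new Path(stagingDir, s"$fileNamePrefix.myformat$extension")
      }
    }.getRecordWriter(context)
  }

  override def write(row: Row): Unit =
    recordWriter.write(NullWritable.get(), new Text(row.mkString(",")))

  override def close(): Unit = recordWriter.close(context)
}
```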

There are also some other minor cleanups:

- Removed the UUID passed through generic Configuration string
- Some minor rewrites for better clarity
- Renamed "path" in multiple places to "stagingDir", to more accurately reflect its meaning

## How was this patch tested?
This should be covered by existing data source tests.
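
Because the bucket id is now baked into the prefix by WriteOutput rather than by each data source, a quick hypothetical way to exercise the new behavior is to bucket a source that previously rejected it; the table and column names below are illustrative only.

```scala
// Hypothetical smoke test: bucketed writes should now work for any file-based source.
spark.range(100)
  .selectExpr("id AS i", "CAST(id AS STRING) AS j")
  .write
  .bucketBy(4, "i")
  .format("json")
  .saveAsTable("bucketed_json_tbl")
```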

Author: Reynold Xin <rxin@databricks.com>

Closes #15562 from rxin/SPARK-18021.
parent 947f4f25
Showing with 99 additions and 143 deletions
@@ -40,7 +40,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
 
 private[libsvm] class LibSVMOutputWriter(
-    path: String,
+    stagingDir: String,
+    fileNamePrefix: String,
     dataSchema: StructType,
     context: TaskAttemptContext)
   extends OutputWriter {
@@ -50,11 +51,7 @@ private[libsvm] class LibSVMOutputWriter(
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-        val taskAttemptId = context.getTaskAttemptID
-        val split = taskAttemptId.getTaskID.getId
-        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
+        new Path(stagingDir, fileNamePrefix + extension)
       }
     }.getRecordWriter(context)
   }
@@ -132,12 +129,11 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       dataSchema: StructType): OutputWriterFactory = {
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") }
-        new LibSVMOutputWriter(path, dataSchema, context)
+        new LibSVMOutputWriter(stagingDir, fileNamePrefix, dataSchema, context)
      }
    }
  }
......
@@ -34,18 +34,23 @@ abstract class OutputWriterFactory extends Serializable {
   * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side
   * to instantiate new [[OutputWriter]]s.
   *
-  * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that
-  *        this may not point to the final output file. For example, `FileOutputFormat` writes to
-  *        temporary directories and then merge written files back to the final destination. In
-  *        this case, `path` points to a temporary output file under the temporary directory.
+  * @param stagingDir Base path (directory) of the file to which this [[OutputWriter]] is supposed
+  *                   to write. Note that this may not point to the final output file. For
+  *                   example, `FileOutputFormat` writes to temporary directories and then merge
+  *                   written files back to the final destination. In this case, `path` points to
+  *                   a temporary output file under the temporary directory.
+  * @param fileNamePrefix Prefix of the file name. The returned OutputWriter must make sure this
+  *                       prefix is used in the actual file name. For example, if the prefix is
+  *                       "part-1-2-3", then the file name must start with "part-1-2-3" but can
+  *                       end in arbitrary extension.
   * @param dataSchema Schema of the rows to be written. Partition columns are not included in the
   *        schema if the relation being written is partitioned.
   * @param context The Hadoop MapReduce task context.
   * @since 1.4.0
   */
  def newInstance(
-      path: String,
-      bucketId: Option[Int], // TODO: This doesn't belong here...
+      stagingDir: String,
+      fileNamePrefix: String,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter
......
@@ -46,6 +46,7 @@ object WriteOutput extends Logging {
 
   /** A shared job description for all the write tasks. */
   private class WriteJobDescription(
+      val uuid: String,  // prevent collision between different (appending) write jobs
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
@@ -102,6 +103,7 @@ object WriteOutput extends Logging {
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
 
     val description = new WriteJobDescription(
+      uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
       allColumns = plan.output,
@@ -213,6 +215,11 @@ object WriteOutput extends Logging {
   private trait ExecuteWriteTask {
     def execute(iterator: Iterator[InternalRow]): Unit
     def releaseResources(): Unit
+
+    final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
+      val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
+      f"part-r-$split%05d-$uuid$bucketString"
+    }
   }
 
   /** Writes data to a single directory (used for non-dynamic-partition writes). */
@@ -222,9 +229,11 @@ object WriteOutput extends Logging {
       stagingPath: String) extends ExecuteWriteTask {
 
     private[this] var outputWriter: OutputWriter = {
+      val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId
       val outputWriter = description.outputWriterFactory.newInstance(
-        path = stagingPath,
-        bucketId = None,
+        stagingDir = stagingPath,
+        fileNamePrefix = filePrefix(split, description.uuid, None),
         dataSchema = description.nonPartitionColumns.toStructType,
         context = taskAttemptContext)
       outputWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
@@ -287,29 +296,31 @@ object WriteOutput extends Logging {
       }
     }
 
-    private def getBucketIdFromKey(key: InternalRow): Option[Int] =
-      description.bucketSpec.map { _ => key.getInt(description.partitionColumns.length) }
-
     /**
      * Open and returns a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      */
-    private def newOutputWriter(
-        key: InternalRow,
-        getPartitionString: UnsafeProjection): OutputWriter = {
+    private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
       val path =
         if (description.partitionColumns.nonEmpty) {
-          val partitionPath = getPartitionString(key).getString(0)
+          val partitionPath = partString(key).getString(0)
           new Path(stagingPath, partitionPath).toString
         } else {
           stagingPath
         }
-      val bucketId = getBucketIdFromKey(key)
+
+      // If the bucket spec is defined, the bucket column is right after the partition columns
+      val bucketId = if (description.bucketSpec.isDefined) {
+        Some(key.getInt(description.partitionColumns.length))
+      } else {
+        None
+      }
+
+      val split = taskAttemptContext.getTaskAttemptID.getTaskID.getId
       val newWriter = description.outputWriterFactory.newInstance(
-        path = path,
-        bucketId = bucketId,
+        stagingDir = path,
+        fileNamePrefix = filePrefix(split, description.uuid, bucketId),
         dataSchema = description.nonPartitionColumns.toStructType,
        context = taskAttemptContext)
      newWriter.initConverter(description.nonPartitionColumns.toStructType)
@@ -319,7 +330,7 @@ object WriteOutput extends Logging {
     override def execute(iter: Iterator[InternalRow]): Unit = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
+          description.partitionColumns ++ bucketIdExpression ++ sortColumns
       val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
 
       val sortingKeySchema = StructType(sortingExpressions.map {
@@ -333,8 +344,8 @@ object WriteOutput extends Logging {
           description.nonPartitionColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
-      val getPartitionString =
-        UnsafeProjection.create(Seq(Concat(partitionStringExpression)), description.partitionColumns)
+      val getPartitionString = UnsafeProjection.create(
+        Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
       // Sorts the data before write, so that we only need one writer at the same time.
       val sorter = new UnsafeKVExternalSorter(
@@ -405,17 +416,6 @@ object WriteOutput extends Logging {
     job.getConfiguration.setBoolean("mapred.task.is.map", true)
     job.getConfiguration.setInt("mapred.task.partition", 0)
 
-    // This UUID is sent to executor side together with the serialized `Configuration` object within
-    // the `Job` instance. `OutputWriters` on the executor side should use this UUID to generate
-    // unique task output files.
-    // This UUID is used to avoid output file name collision between different appending write jobs.
-    // These jobs may belong to different SparkContext instances. Concrete data source
-    // implementations may use this UUID to generate unique file names (e.g.,
-    // `part-r-<task-id>-<job-uuid>.parquet`). The reason why this ID is used to identify a job
-    // rather than a single task output file is that, speculative tasks must generate the same
-    // output file name as the original task.
-    job.getConfiguration.set(WriterContainer.DATASOURCE_WRITEJOBUUID, UUID.randomUUID().toString)
-
     val taskAttemptContext = new TaskAttemptContextImpl(job.getConfiguration, taskAttemptId)
     val outputCommitter = newOutputCommitter(
       job.getOutputFormatClass, taskAttemptContext, path, isAppend)
@@ -474,7 +474,3 @@ object WriteOutput extends Logging {
       }
     }
   }
-
-object WriterContainer {
-  val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID"
-}
@@ -31,7 +31,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer}
+import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.types._
 
 object CSVRelation extends Logging {
@@ -170,17 +170,17 @@
 private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory {
   override def newInstance(
-      path: String,
-      bucketId: Option[Int],
+      stagingDir: String,
+      fileNamePrefix: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
-    if (bucketId.isDefined) sys.error("csv doesn't support bucketing")
-    new CsvOutputWriter(path, dataSchema, context, params)
+    new CsvOutputWriter(stagingDir, fileNamePrefix, dataSchema, context, params)
   }
 }
 
 private[csv] class CsvOutputWriter(
-    path: String,
+    stagingDir: String,
+    fileNamePrefix: String,
     dataSchema: StructType,
     context: TaskAttemptContext,
     params: CSVOptions) extends OutputWriter with Logging {
@@ -199,11 +199,7 @@ private[csv] class CsvOutputWriter(
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-        val taskAttemptId = context.getTaskAttemptID
-        val split = taskAttemptId.getTaskID.getId
-        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension")
+        new Path(stagingDir, s"$fileNamePrefix.csv$extension")
      }
    }.getRecordWriter(context)
  }
......
@@ -82,11 +82,11 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new JsonOutputWriter(path, parsedOptions, bucketId, dataSchema, context)
+        new JsonOutputWriter(stagingDir, parsedOptions, fileNamePrefix, dataSchema, context)
       }
     }
   }
@@ -153,9 +153,9 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
 }
 
 private[json] class JsonOutputWriter(
-    path: String,
+    stagingDir: String,
     options: JSONOptions,
-    bucketId: Option[Int],
+    fileNamePrefix: String,
     dataSchema: StructType,
     context: TaskAttemptContext)
   extends OutputWriter with Logging {
@@ -168,12 +168,7 @@ private[json] class JsonOutputWriter(
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-        val taskAttemptId = context.getTaskAttemptID
-        val split = taskAttemptId.getTaskID.getId
-        val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
-        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString.json$extension")
+        new Path(stagingDir, s"$fileNamePrefix.json$extension")
      }
    }.getRecordWriter(context)
  }
......
@@ -27,7 +27,7 @@ import scala.util.{Failure, Try}
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.parquet.{Log => ApacheParquetLog}
 import org.apache.parquet.filter2.compat.FilterCompat
@@ -45,7 +45,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -134,10 +133,10 @@ class ParquetFileFormat
     new OutputWriterFactory {
       override def newInstance(
           path: String,
-          bucketId: Option[Int],
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new ParquetOutputWriter(path, bucketId, context)
+        new ParquetOutputWriter(path, fileNamePrefix, context)
      }
    }
  }
......
@@ -26,7 +26,7 @@ import org.apache.parquet.hadoop.util.ContextUtil
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.{BucketingUtils, OutputWriter, OutputWriterFactory, WriterContainer}
+import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
@@ -122,13 +122,12 @@ private[parquet] class ParquetOutputWriterFactory(
   }
 
   /** Disable the use of the older API. */
-  def newInstance(
+  override def newInstance(
       path: String,
-      bucketId: Option[Int],
+      fileNamePrefix: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
-    throw new UnsupportedOperationException(
-      "this version of newInstance not supported for " +
+    throw new UnsupportedOperationException("this version of newInstance not supported for " +
       "ParquetOutputWriterFactory")
   }
 }
@@ -136,33 +135,16 @@ private[parquet] class ParquetOutputWriterFactory(
 // NOTE: This class is instantiated and used on executor side only, no need to be serializable.
 private[parquet] class ParquetOutputWriter(
-    path: String,
-    bucketId: Option[Int],
+    stagingDir: String,
+    fileNamePrefix: String,
     context: TaskAttemptContext)
   extends OutputWriter {
 
   private val recordWriter: RecordWriter[Void, InternalRow] = {
     val outputFormat = {
       new ParquetOutputFormat[InternalRow]() {
-        // Here we override `getDefaultWorkFile` for two reasons:
-        //
-        //  1. To allow appending. We need to generate unique output file names to avoid
-        //     overwriting existing files (either exist before the write job, or are just written
-        //     by other tasks within the same write job).
-        //
-        //  2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses
-        //     `FileOutputCommitter.getWorkPath()`, which points to the base directory of all
-        //     partitions in the case of dynamic partitioning.
         override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-          val configuration = context.getConfiguration
-          val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-          val taskAttemptId = context.getTaskAttemptID
-          val split = taskAttemptId.getTaskID.getId
-          val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
-          // It has the `.parquet` extension at the end because (de)compression tools
-          // such as gunzip would not be able to decompress this as the compression
-          // is not applied on this whole file but on each "page" in Parquet format.
-          new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension")
+          new Path(stagingDir, fileNamePrefix + extension)
        }
      }
    }
......
@@ -73,14 +73,11 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        if (bucketId.isDefined) {
-          throw new AnalysisException("Text doesn't support bucketing")
-        }
-        new TextOutputWriter(path, dataSchema, context)
+        new TextOutputWriter(stagingDir, fileNamePrefix, dataSchema, context)
       }
     }
   }
@@ -124,7 +121,11 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
   }
 }
 
-class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
+class TextOutputWriter(
+    stagingDir: String,
+    fileNamePrefix: String,
+    dataSchema: StructType,
+    context: TaskAttemptContext)
   extends OutputWriter {
 
   private[this] val buffer = new Text()
@@ -132,11 +133,7 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-        val taskAttemptId = context.getTaskAttemptID
-        val split = taskAttemptId.getTaskID.getId
-        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension")
+        new Path(stagingDir, s"$fileNamePrefix.txt$extension")
      }
    }.getRecordWriter(context)
  }
......
@@ -83,11 +83,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new OrcOutputWriter(path, bucketId, dataSchema, context)
+        new OrcOutputWriter(stagingDir, fileNamePrefix, dataSchema, context)
       }
     }
   }
@@ -210,8 +210,8 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)
 }
 
 private[orc] class OrcOutputWriter(
-    path: String,
-    bucketId: Option[Int],
+    stagingDir: String,
+    fileNamePrefix: String,
     dataSchema: StructType,
     context: TaskAttemptContext)
   extends OutputWriter {
@@ -226,10 +226,7 @@ private[orc] class OrcOutputWriter(
   private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
     recordWriterInstantiated = true
 
-    val uniqueWriteJobId = conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-    val taskAttemptId = context.getTaskAttemptID
-    val partition = taskAttemptId.getTaskID.getId
-    val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
     val compressionExtension = {
       val name = conf.get(OrcRelation.ORC_COMPRESSION)
       OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "")
@@ -237,12 +234,12 @@
     // It has the `.orc` extension at the end because (de)compression tools
     // such as gunzip would not be able to decompress this as the compression
     // is not applied on this whole file but on each "stream" in ORC format.
-    val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc"
+    val filename = s"$fileNamePrefix$compressionExtension.orc"
 
     new OrcOutputFormat().getRecordWriter(
-      new Path(path, filename).getFileSystem(conf),
+      new Path(stagingDir, filename).getFileSystem(conf),
       conf.asInstanceOf[JobConf],
-      new Path(path, filename).toString,
+      new Path(stagingDir, filename).toString,
      Reporter.NULL
    ).asInstanceOf[RecordWriter[NullWritable, Writable]]
  }
......
@@ -54,11 +54,6 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt"))
   }
 
-  test("write bucketed data to unsupported data source") {
-    val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i")
-    intercept[SparkException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt"))
-  }
-
   test("write bucketed data using save()") {
     val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
......
@@ -39,11 +39,11 @@ class CommitFailureTestSource extends SimpleTextSource {
       dataSchema: StructType): OutputWriterFactory =
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new SimpleTextOutputWriter(path, context) {
+        new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context) {
          var failed = false
          TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
            failed = true
......
@@ -23,7 +23,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
-import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.sql.{sources, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
@@ -51,11 +51,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
     SimpleTextRelation.lastHadoopConf = Option(job.getConfiguration)
     new OutputWriterFactory {
       override def newInstance(
-          path: String,
-          bucketId: Option[Int],
+          stagingDir: String,
+          fileNamePrefix: String,
           dataSchema: StructType,
           context: TaskAttemptContext): OutputWriter = {
-        new SimpleTextOutputWriter(path, context)
+        new SimpleTextOutputWriter(stagingDir, fileNamePrefix, context)
       }
     }
   }
@@ -120,9 +120,11 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
   }
 }
 
-class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends OutputWriter {
+class SimpleTextOutputWriter(
+    stagingDir: String, fileNamePrefix: String, context: TaskAttemptContext)
+  extends OutputWriter {
   private val recordWriter: RecordWriter[NullWritable, Text] =
-    new AppendingTextOutputFormat(new Path(path)).getRecordWriter(context)
+    new AppendingTextOutputFormat(new Path(stagingDir), fileNamePrefix).getRecordWriter(context)
 
   override def write(row: Row): Unit = {
     val serialized = row.toSeq.map { v =>
@@ -136,19 +138,15 @@ class SimpleTextOutputWriter(path: String, context: TaskAttemptContext) extends
   }
 }
 
-class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullWritable, Text] {
-  val numberFormat = NumberFormat.getInstance()
+class AppendingTextOutputFormat(stagingDir: Path, fileNamePrefix: String)
+  extends TextOutputFormat[NullWritable, Text] {
+  val numberFormat = NumberFormat.getInstance()
   numberFormat.setMinimumIntegerDigits(5)
   numberFormat.setGroupingUsed(false)
 
   override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-    val configuration = context.getConfiguration
-    val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-    val taskAttemptId = context.getTaskAttemptID
-    val split = taskAttemptId.getTaskID.getId
-    val name = FileOutputFormat.getOutputName(context)
-    new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId")
+    new Path(stagingDir, fileNamePrefix)
  }
}
......