Commit 776b8f17 authored by Wenchen Fan

[SPARK-19563][SQL] avoid unnecessary sort in FileFormatWriter

## What changes were proposed in this pull request?

In `FileFormatWriter`, we sort the input rows by partition columns, bucket id, and sort columns when writing data out partitioned or bucketed.

However, if the data is already sorted, we sort it again, which is unnecessary.

This PR removes the sorting logic from `FileFormatWriter` and uses `SortExec` instead. The `SortExec` is not added if the data is already sorted.
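
Concretely, the required ordering for the writer is partition columns, then bucket id, then sort columns. It is compared against the child plan's `outputOrdering` using `semanticEquals`, and a local `SortExec` is added only when they don't match. A condensed sketch of that check (the helper name `planForWrite` is only for illustration; the real code lives in `FileFormatWriter.write`, see the diff below):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
import org.apache.spark.sql.execution.{SortExec, SparkPlan}

// Condensed sketch of the check; `planForWrite` is an illustrative helper, not part of the patch.
def planForWrite(
    plan: SparkPlan,
    partitionColumns: Seq[Expression],
    bucketIdExpression: Option[Expression],
    sortColumns: Seq[Expression]): SparkPlan = {
  // The writer needs rows grouped by partition dir, then bucket id, then the bucket's sort columns.
  val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
  // Ordering the child plan already guarantees (only the expressions matter, not the direction).
  val actualOrdering = plan.outputOrdering.map(_.child)
  val orderingMatched =
    requiredOrdering.length <= actualOrdering.length &&
      requiredOrdering.zip(actualOrdering).forall {
        case (required, actual) => required.semanticEquals(actual)
      }
  if (orderingMatched) {
    plan // input is already ordered the way the writer needs it, reuse as-is
  } else {
    SortExec(requiredOrdering.map(SortOrder(_, Ascending)), global = false, child = plan)
  }
}
```

The sort is local (`global = false`) because each write task only needs its own rows grouped, and the direction is irrelevant for grouping, so `Ascending` is used throughout.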

## How was this patch tested?

I ran a micro benchmark manually:
```
val df = spark.range(10000000).select($"id", $"id" % 10 as "part").sort("part")
spark.time(df.write.partitionBy("part").parquet("/tmp/test"))
```
The write took about 6.4 seconds before this PR and about 5.7 seconds afterwards.
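
For reference, a sketch of the write paths affected by this check (the output paths and table name below are placeholders; these variants were not benchmarked):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("write-sort-demo").getOrCreate()
import spark.implicits._

val df = spark.range(10000000).select($"id", $"id" % 10 as "part")

// Input already sorted by the partition column: no extra sort is planned before the write.
spark.time(df.sort("part").write.partitionBy("part").parquet("/tmp/test_sorted"))

// Unsorted input: FileFormatWriter still adds a local SortExec before writing.
spark.time(df.write.partitionBy("part").parquet("/tmp/test_unsorted"))

// Bucketed + sorted output: the required ordering also includes the bucket id and sort columns.
df.write.bucketBy(8, "id").sortBy("id").saveAsTable("write_sort_demo_bucketed")
```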

close https://github.com/apache/spark/pull/16724

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16898 from cloud-fan/writer.
parent 65fe902e
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 
 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
 
+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )
 
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }
+
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)
 
       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
          (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
            executeTask(
              description = description,
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)
 
     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
       } else {
         new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
    * multiple directories (partitions) or files (bucketing).
    */
   private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {
 
     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _
 
-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }
 
     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
      */
     private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)
 
-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }
 
       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)
 
       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }
 
-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
-
-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
-
-      val sortedIterator = sorter.sortedIterator()
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)
+
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)
+
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)
 
       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")
 
           recordsInFile = 0
           fileCounter = 0
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-            recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+            recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
           recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")
 
           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }
 
-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()
...