diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index be13cbc51a9d3f2f636e9e48d57f414aa3043b6a..644358493e2ebc616002ec04ae682350dce93e64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 
 /** A helper object for writing FileFormat data out to a location. */
@@ -64,9 +63,9 @@ object FileFormatWriter extends Logging {
       val serializableHadoopConf: SerializableConfiguration,
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
-      val partitionColumns: Seq[Attribute],
       val dataColumns: Seq[Attribute],
-      val bucketSpec: Option[BucketSpec],
+      val partitionColumns: Seq[Attribute],
+      val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
@@ -108,9 +107,21 @@ object FileFormatWriter extends Logging {
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
+    val allColumns = queryExecution.logical.output
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
 
+    val bucketIdExpression = bucketSpec.map { spec =>
+      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
+      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+      // guarantee the data distribution is same between shuffle and bucketed data source, which
+      // enables us to only shuffle one side when join a bucketed table and a normal one.
+      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+    }
+    val sortColumns = bucketSpec.toSeq.flatMap {
+      spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
+    }
+
+    // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
       fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
 
@@ -119,23 +130,45 @@ object FileFormatWriter extends Logging {
       uuid = UUID.randomUUID().toString,
       serializableHadoopConf = new SerializableConfiguration(job.getConfiguration),
       outputWriterFactory = outputWriterFactory,
-      allColumns = queryExecution.logical.output,
-      partitionColumns = partitionColumns,
+      allColumns = allColumns,
       dataColumns = dataColumns,
-      bucketSpec = bucketSpec,
+      partitionColumns = partitionColumns,
+      bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
       maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
         .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
     )
 
+    // We should first sort by partition columns, then bucket id, and finally sorting columns.
+    val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
+    // the sort order doesn't matter
+    val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child)
+    val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
+      false
+    } else {
+      requiredOrdering.zip(actualOrdering).forall {
+        case (requiredOrder, childOutputOrder) =>
+          requiredOrder.semanticEquals(childOutputOrder)
+      }
+    }
+
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
       // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
       committer.setupJob(job)
 
       try {
-        val ret = sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val rdd = if (orderingMatched) {
+          queryExecution.toRdd
+        } else {
+          SortExec(
+            requiredOrdering.map(SortOrder(_, Ascending)),
+            global = false,
+            child = queryExecution.executedPlan).execute()
+        }
+
+        val ret = sparkSession.sparkContext.runJob(rdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
@@ -189,7 +222,7 @@ object FileFormatWriter extends Logging {
     committer.setupTask(taskAttemptContext)
 
     val writeTask =
-      if (description.partitionColumns.isEmpty && description.bucketSpec.isEmpty) {
+      if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) {
         new SingleDirectoryWriteTask(description, taskAttemptContext, committer)
       } else {
         new DynamicPartitionWriteTask(description, taskAttemptContext, committer)
@@ -287,31 +320,16 @@ object FileFormatWriter extends Logging {
   * multiple directories (partitions) or files (bucketing).
   */
  private class DynamicPartitionWriteTask(
-      description: WriteJobDescription,
+      desc: WriteJobDescription,
       taskAttemptContext: TaskAttemptContext,
       committer: FileCommitProtocol) extends ExecuteWriteTask {

     // currentWriter is initialized whenever we see a new key
     private var currentWriter: OutputWriter = _

-    private val bucketColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.bucketColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private val sortColumns: Seq[Attribute] = description.bucketSpec.toSeq.flatMap {
-      spec => spec.sortColumnNames.map(c => description.allColumns.find(_.name == c).get)
-    }
-
-    private def bucketIdExpression: Option[Expression] = description.bucketSpec.map { spec =>
-      // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
-      // guarantee the data distribution is same between shuffle and bucketed data source, which
-      // enables us to only shuffle one side when join a bucketed table and a normal one.
-      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
-    }
-
-    /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
-    private def partitionStringExpression: Seq[Expression] = {
-      description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+    /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
+    private def partitionPathExpression: Seq[Expression] = {
+      desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
         // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
@@ -325,35 +343,46 @@ object FileFormatWriter extends Logging {
     }

     /**
-     * Open and returns a new OutputWriter given a partition key and optional bucket id.
+     * Opens a new OutputWriter given a partition key and optional bucket id.
      * If bucket id is specified, we will append it to the end of the file name, but before the
      * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
      *
-     * @param key vaues for fields consisting of partition keys for the current row
-     * @param partString a function that projects the partition values into a string
+     * @param partColsAndBucketId a row consisting of partition columns and a bucket id for the
+     *                            current row.
+     * @param getPartitionPath a function that projects the partition values into a path string.
      * @param fileCounter the number of files that have been written in the past for this specific
      *                    partition. This is used to limit the max number of records written for a
      *                    single file. The value should start from 0.
+     * @param updatedPartitions the set of updated partition paths, we should add the new partition
+     *                          path of this writer to it.
     */
    private def newOutputWriter(
-        key: InternalRow, partString: UnsafeProjection, fileCounter: Int): Unit = {
-      val partDir =
-        if (description.partitionColumns.isEmpty) None else Option(partString(key).getString(0))
+        partColsAndBucketId: InternalRow,
+        getPartitionPath: UnsafeProjection,
+        fileCounter: Int,
+        updatedPartitions: mutable.Set[String]): Unit = {
+      val partDir = if (desc.partitionColumns.isEmpty) {
+        None
+      } else {
+        Option(getPartitionPath(partColsAndBucketId).getString(0))
+      }
+      partDir.foreach(updatedPartitions.add)

-      // If the bucket spec is defined, the bucket column is right after the partition columns
-      val bucketId = if (description.bucketSpec.isDefined) {
-        BucketingUtils.bucketIdToString(key.getInt(description.partitionColumns.length))
+      // If the bucketId expression is defined, the bucketId column is right after the partition
+      // columns.
+      val bucketId = if (desc.bucketIdExpression.isDefined) {
+        BucketingUtils.bucketIdToString(partColsAndBucketId.getInt(desc.partitionColumns.length))
       } else {
         ""
       }

       // This must be in a form that matches our bucketing format. See BucketingUtils.
       val ext = f"$bucketId.c$fileCounter%03d" +
-        description.outputWriterFactory.getFileExtension(taskAttemptContext)
+        desc.outputWriterFactory.getFileExtension(taskAttemptContext)

       val customPath = partDir match {
         case Some(dir) =>
-          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+          desc.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
         case _ =>
           None
       }
@@ -363,80 +392,42 @@ object FileFormatWriter extends Logging {
         committer.newTaskTempFile(taskAttemptContext, partDir, ext)
       }

-      currentWriter = description.outputWriterFactory.newInstance(
+      currentWriter = desc.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.dataColumns.toStructType,
+        dataSchema = desc.dataColumns.toStructType,
         context = taskAttemptContext)
     }

     override def execute(iter: Iterator[InternalRow]): Set[String] = {
-      // We should first sort by partition columns, then bucket id, and finally sorting columns.
-      val sortingExpressions: Seq[Expression] =
-        description.partitionColumns ++ bucketIdExpression ++ sortColumns
-      val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
-
-      val sortingKeySchema = StructType(sortingExpressions.map {
-        case a: Attribute => StructField(a.name, a.dataType, a.nullable)
-        // The sorting expressions are all `Attribute` except bucket id.
-        case _ => StructField("bucketId", IntegerType, nullable = false)
-      })
-
-      // Returns the data columns to be written given an input row
-      val getOutputRow = UnsafeProjection.create(
-        description.dataColumns, description.allColumns)
-
-      // Returns the partition path given a partition key.
-      val getPartitionStringFunc = UnsafeProjection.create(
-        Seq(Concat(partitionStringExpression)), description.partitionColumns)
-
-      // Sorts the data before write, so that we only need one writer at the same time.
-      val sorter = new UnsafeKVExternalSorter(
-        sortingKeySchema,
-        StructType.fromAttributes(description.dataColumns),
-        SparkEnv.get.blockManager,
-        SparkEnv.get.serializerManager,
-        TaskContext.get().taskMemoryManager().pageSizeBytes,
-        SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
-          UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
-
-      while (iter.hasNext) {
-        val currentRow = iter.next()
-        sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
-      }
+      val getPartitionColsAndBucketId = UnsafeProjection.create(
+        desc.partitionColumns ++ desc.bucketIdExpression, desc.allColumns)

-      val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
-        identity
-      } else {
-        UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
-          case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
-        })
-      }
+      // Generates the partition path given the row generated by `getPartitionColsAndBucketId`.
+      val getPartPath = UnsafeProjection.create(
+        Seq(Concat(partitionPathExpression)), desc.partitionColumns)

-      val sortedIterator = sorter.sortedIterator()
+      // Returns the data columns to be written given an input row
+      val getOutputRow = UnsafeProjection.create(desc.dataColumns, desc.allColumns)

       // If anything below fails, we should abort the task.
       var recordsInFile: Long = 0L
       var fileCounter = 0
-      var currentKey: UnsafeRow = null
+      var currentPartColsAndBucketId: UnsafeRow = null
       val updatedPartitions = mutable.Set[String]()

-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          // See a new key - write to a new partition (new file).
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
+      for (row <- iter) {
+        val nextPartColsAndBucketId = getPartitionColsAndBucketId(row)
+        if (currentPartColsAndBucketId != nextPartColsAndBucketId) {
+          // See a new partition or bucket - write to a new partition dir (or a new bucket file).
+          currentPartColsAndBucketId = nextPartColsAndBucketId.copy()
+          logDebug(s"Writing partition: $currentPartColsAndBucketId")

           recordsInFile = 0
           fileCounter = 0

           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
-          val partitionPath = getPartitionStringFunc(currentKey).getString(0)
-          if (partitionPath.nonEmpty) {
-            updatedPartitions.add(partitionPath)
-          }
-        } else if (description.maxRecordsPerFile > 0 &&
-            recordsInFile >= description.maxRecordsPerFile) {
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
+        } else if (desc.maxRecordsPerFile > 0 &&
+            recordsInFile >= desc.maxRecordsPerFile) {
           // Exceeded the threshold in terms of the number of records per file.
           // Create a new file by increasing the file counter.
           recordsInFile = 0
@@ -445,10 +436,10 @@ object FileFormatWriter extends Logging {
             s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

           releaseResources()
-          newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
+          newOutputWriter(currentPartColsAndBucketId, getPartPath, fileCounter, updatedPartitions)
         }

-        currentWriter.write(sortedIterator.getValue)
+        currentWriter.write(getOutputRow(row))
         recordsInFile += 1
       }
       releaseResources()
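
A note on the bucket id expression introduced in write(): HashPartitioning.partitionIdExpression reduces to a positive modulo of the Murmur3 hash of the bucket columns, which is what keeps bucketed files aligned with shuffle partitioning. The standalone sketch below (hypothetical object and column names, not part of this patch) reproduces that mapping through the public DataFrame API, assuming the SQL built-ins hash and pmod mirror the internal expression:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object BucketIdSketch {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration.
    val spark = SparkSession.builder().master("local[2]").appName("bucket-id-sketch").getOrCreate()
    import spark.implicits._

    val numBuckets = 8
    val df = Seq(1, 2, 3, 42).toDF("id")

    // `hash` is Spark's Murmur3 hash; pmod(hash(id), numBuckets) is assumed here to mirror
    // HashPartitioning(bucketColumns, numBuckets).partitionIdExpression used by the writer.
    df.withColumn("bucket_id", expr(s"pmod(hash(id), $numBuckets)")).show()

    spark.stop()
  }
}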
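
For the ordering check itself, the user-facing effect is that partitionBy/bucketBy/sortBy supply partitionColumns, bucketIdExpression and sortColumns, so a local SortExec on (partition columns, bucket id, sort columns) is only injected when the child plan's output ordering does not already cover that sequence. A minimal usage sketch, assuming a local SparkSession and a throwaway table name t chosen for illustration:

import org.apache.spark.sql.SparkSession

object BucketedWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("bucketed-write-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, "a", "2017-01-01"),
      (2, "b", "2017-01-02")
    ).toDF("id", "name", "dt")

    df.write
      .partitionBy("dt")   // becomes partitionColumns in the write job description
      .bucketBy(4, "id")   // becomes bucketIdExpression, i.e. pmod(hash(id), 4)
      .sortBy("name")      // becomes sortColumns
      .saveAsTable("t")    // the writer sorts by (dt, bucket id, name) only when needed

    spark.stop()
  }
}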