diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala new file mode 100644 index 0000000000000000000000000000000000000000..fb8020585cf89eeb581900a559f4c1bdffaafabd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.util.Utils + + +/** + * An interface to define how a single Spark job commits its outputs. Two notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. Implementations should have a constructor with either 2 or 3 arguments: + * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 3. A committer should not be reused across multiple Spark jobs. + * + * The proper call sequence is: + * + * 1. Driver calls setupJob. + * 2. As part of each task's execution, executor calls setupTask and then commitTask + * (or abortTask if task failed). + * 3. When all necessary tasks completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the job should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Setups up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. 
file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or killed) before it can call abort. + */ + def abortTask(taskContext: TaskAttemptContext): Unit +} + + +object FileCommitProtocol { + class TaskCommitMessage(val obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(null) + + /** + * Instantiates a FileCommitProtocol using the given className. + */ + def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + : FileCommitProtocol = { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala new file mode 100644 index 0000000000000000000000000000000000000000..66ccb6d437708a31f3dc9353ee849966f2dcaa61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import java.util.Date + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.SparkHadoopWriter +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapReduceCommitProtocol(jobId: String, path: String) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + context.getOutputFormatClass.newInstance().getOutputCommitter(context) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val filename = f"part-$split%05d-$jobId$ext" + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". + case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriter.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + committer = setupCommitter(taskAttemptContext) + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + committer = setupCommitter(taskContext) + committer.setupTask(taskContext) + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + EmptyTaskCommitMessage + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + } 
+} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala deleted file mode 100644 index f5dd5ce22919d540ff5bb071c735a26b80942bb5..0000000000000000000000000000000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileCommitProtocol.scala +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import java.util.{Date, UUID} - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl - -import org.apache.spark.SparkHadoopWriter -import org.apache.spark.internal.Logging -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.Utils - - -object FileCommitProtocol { - class TaskCommitMessage(val obj: Any) extends Serializable - - object EmptyTaskCommitMessage extends TaskCommitMessage(null) - - /** - * Instantiates a FileCommitProtocol using the given className. - */ - def instantiate(className: String, outputPath: String, isAppend: Boolean): FileCommitProtocol = { - try { - val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - - // First try the one with argument (outputPath: String, isAppend: Boolean). - // If that doesn't exist, try the one with (outputPath: String). - try { - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[Boolean]) - ctor.newInstance(outputPath, isAppend.asInstanceOf[java.lang.Boolean]) - } catch { - case _: NoSuchMethodException => - val ctor = clazz.getDeclaredConstructor(classOf[String]) - ctor.newInstance(outputPath) - } - } catch { - case e: ClassNotFoundException => - throw e - } - } -} - - -/** - * An interface to define how a single Spark job commits its outputs. Two notes: - * - * 1. Implementations must be serializable, as the committer instance instantiated on the driver - * will be used for tasks on executors. - * 2. A committer should not be reused across multiple Spark jobs. - * - * The proper call sequence is: - * - * 1. Driver calls setupJob. - * 2. As part of each task's execution, executor calls setupTask and then commitTask - * (or abortTask if task failed). - * 3. When all necessary tasks completed successfully, the driver calls commitJob. If the job - * failed to execute (e.g. too many failed tasks), the job should call abortJob. - */ -abstract class FileCommitProtocol { - import FileCommitProtocol._ - - /** - * Setups up a job. 
Must be called on the driver before any other methods can be invoked. - */ - def setupJob(jobContext: JobContext): Unit - - /** - * Commits a job after the writes succeed. Must be called on the driver. - */ - def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit - - /** - * Aborts a job after the writes fail. Must be called on the driver. - * - * Calling this function is a best-effort attempt, because it is possible that the driver - * just crashes (or killed) before it can call abort. - */ - def abortJob(jobContext: JobContext): Unit - - /** - * Sets up a task within a job. - * Must be called before any other task related methods can be invoked. - */ - def setupTask(taskContext: TaskAttemptContext): Unit - - /** - * Notifies the commit protocol to add a new file, and gets back the full path that should be - * used. Must be called on the executors when running tasks. - * - * Note that the returned temp file may have an arbitrary path. The commit protocol only - * promises that the file will be at the location specified by the arguments after job commit. - * - * A full file path consists of the following parts: - * 1. the base path - * 2. some sub-directory within the base path, used to specify partitioning - * 3. file prefix, usually some unique job id with the task id - * 4. bucket id - * 5. source specific file extension, e.g. ".snappy.parquet" - * - * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest - * are left to the commit protocol implementation to decide. - */ - def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String - - /** - * Commits a task after the writes succeed. Must be called on the executors when running tasks. - */ - def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage - - /** - * Aborts a task after the writes have failed. Must be called on the executors when running tasks. - * - * Calling this function is a best-effort attempt, because it is possible that the executor - * just crashes (or killed) before it can call abort. - */ - def abortTask(taskContext: TaskAttemptContext): Unit -} - - -/** - * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter - * (from the newer mapreduce API, not the old mapred API). - * - * Unlike Hadoop's OutputCommitter, this implementation is serializable. - */ -class HadoopCommitProtocolWrapper(path: String, isAppend: Boolean) - extends FileCommitProtocol with Serializable with Logging { - - import FileCommitProtocol._ - - /** OutputCommitter from Hadoop is not serializable so marking it transient. */ - @transient private var committer: OutputCommitter = _ - - /** UUID used to identify the job in file name. */ - private val uuid: String = UUID.randomUUID().toString - - private def setupCommitter(context: TaskAttemptContext): Unit = { - committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) - - if (!isAppend) { - // If we are appending data to an existing dir, we will only use the output committer - // associated with the file output format since it is not safe to use a custom - // committer for appending. For example, in S3, direct parquet output committer may - // leave partial data in the destination dir when the appending job fails. - // See SPARK-8578 for more details. 
- val configuration = context.getConfiguration - val clazz = - configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) - - if (clazz != null) { - logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") - - // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat - // has an associated output committer. To override this output committer, - // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. - // If a data source needs to override the output committer, it needs to set the - // output committer in prepareForWrite method. - if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { - // The specified output committer is a FileOutputCommitter. - // So, we will use the FileOutputCommitter-specified constructor. - val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(new Path(path), context) - } else { - // The specified output committer is just an OutputCommitter. - // So, we will use the no-argument constructor. - val ctor = clazz.getDeclaredConstructor() - committer = ctor.newInstance() - } - } - } - logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") - } - - override def newTaskTempFile( - taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. - val split = taskContext.getTaskAttemptID.getTaskID.getId - val filename = f"part-$split%05d-$uuid$ext" - - val stagingDir: String = committer match { - // For FileOutputCommitter it has its own staging path called "work path". 
- case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) - case _ => path - } - - dir.map { d => - new Path(new Path(stagingDir, d), filename).toString - }.getOrElse { - new Path(stagingDir, filename).toString - } - } - - override def setupJob(jobContext: JobContext): Unit = { - // Setup IDs - val jobId = SparkHadoopWriter.createJobID(new Date, 0) - val taskId = new TaskID(jobId, TaskType.MAP, 0) - val taskAttemptId = new TaskAttemptID(taskId, 0) - - // Set up the configuration object - jobContext.getConfiguration.set("mapred.job.id", jobId.toString) - jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) - jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) - jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) - jobContext.getConfiguration.setInt("mapred.task.partition", 0) - - val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) - setupCommitter(taskAttemptContext) - - committer.setupJob(jobContext) - } - - override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { - committer.commitJob(jobContext) - } - - override def abortJob(jobContext: JobContext): Unit = { - committer.abortJob(jobContext, JobStatus.State.FAILED) - } - - override def setupTask(taskContext: TaskAttemptContext): Unit = { - setupCommitter(taskContext) - committer.setupTask(taskContext) - } - - override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { - val attemptId = taskContext.getTaskAttemptID - SparkHadoopMapRedUtil.commitTask( - committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - EmptyTaskCommitMessage - } - - override def abortTask(taskContext: TaskAttemptContext): Unit = { - committer.abortTask(taskContext) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index bc00a0a749c0931f77a225e4b0633b3b9767c783..e404dcd5452b9e5d72db87e271495f10cdbef3e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -29,6 +29,8 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -37,7 +39,6 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SQLExecution, UnsafeKVExternalSorter} -import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 230c74a47ba2a4ac213275a17a7c4de999cee025..927c0c5b95a17693399d4d3e0810045cf20251a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -21,6 +21,7 @@ import java.io.IOException import org.apache.hadoop.fs.Path +import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -86,8 +87,9 @@ case class InsertIntoHadoopFsRelationCommand( if (doInsertion) { val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, - outputPath.toString, - isAppend) + jobId = java.util.UUID.randomUUID().toString, + outputPath = outputPath.toString, + isAppend = isAppend) FileFormatWriter.write( sparkSession = sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala new file mode 100644 index 0000000000000000000000000000000000000000..9b9ed28412cac904d92b5225f77c2994cc912ad7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol +import org.apache.spark.sql.internal.SQLConf + +/** + * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual + * Hadoop output committer using an option specified in SQLConf. + */ +class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String, isAppend: Boolean) + extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { + + override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + + if (!isAppend) { + // If we are appending data to an existing dir, we will only use the output committer + // associated with the file output format since it is not safe to use a custom + // committer for appending. 
For example, in S3, direct parquet output committer may + // leave partial data in the destination dir when the appending job fails. + // See SPARK-8578 for more details. + val configuration = context.getConfiguration + val clazz = + configuration.getClass(SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) + + if (clazz != null) { + logInfo(s"Using user defined output committer class ${clazz.getCanonicalName}") + + // Every output format based on org.apache.hadoop.mapreduce.lib.output.OutputFormat + // has an associated output committer. To override this output committer, + // we will first try to use the output committer set in SQLConf.OUTPUT_COMMITTER_CLASS. + // If a data source needs to override the output committer, it needs to set the + // output committer in prepareForWrite method. + if (classOf[FileOutputCommitter].isAssignableFrom(clazz)) { + // The specified output committer is a FileOutputCommitter. + // So, we will use the FileOutputCommitter-specified constructor. + val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) + committer = ctor.newInstance(new Path(path), context) + } else { + // The specified output committer is just an OutputCommitter. + // So, we will use the no-argument constructor. + val ctor = clazz.getDeclaredConstructor() + committer = ctor.newInstance() + } + } + } + logInfo(s"Using output committer class ${committer.getClass.getCanonicalName}") + committer + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index daec2b545097120c6405ab2b851f085bcf092f9b..e849cafef41843fb359caea9cd09c37fecc15c6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.execution.streaming import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.{FileCommitProtocol, FileFormat, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} object FileStreamSink { // The name of the subdirectory that is used to store metadata about which files are valid. 
@@ -54,7 +55,11 @@ class FileStreamSink( logInfo(s"Skipping already committed batch $batchId") } else { val committer = FileCommitProtocol.instantiate( - sparkSession.sessionState.conf.streamingFileCommitProtocolClass, path, isAppend = false) + className = sparkSession.sessionState.conf.streamingFileCommitProtocolClass, + jobId = batchId.toString, + outputPath = path, + isAppend = false) + committer match { case manifestCommitter: ManifestFileCommitProtocol => manifestCommitter.setupManifestOptions(fileLog, batchId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index 510312267a98d6c24b1b7591fa3f7e6f05c92d1d..1fe13fa1623fc452a2a8ee105863e9a3ae39560a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -25,8 +25,8 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.spark.internal.Logging -import org.apache.spark.sql.execution.datasources.FileCommitProtocol -import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage /** * A [[FileCommitProtocol]] that tracks the list of valid files in a manifest file, used in @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.FileCommitProtocol.TaskCommitM * * @param path path to write the final output to. */ -class ManifestFileCommitProtocol(path: String) +class ManifestFileCommitProtocol(jobId: String, path: String) extends FileCommitProtocol with Serializable with Logging { // Track the list of files added by a task, only used on the executors. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7bb3ac02fa5d04a1b16bda01099956c6bf817a3d..7b8ed65054c3ca7569e3e000264b0bcf4fc81888 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.execution.datasources.HadoopCommitProtocolWrapper +import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol import org.apache.spark.util.Utils @@ -385,7 +385,7 @@ object SQLConf { SQLConfigBuilder("spark.sql.sources.commitProtocolClass") .internal() .stringConf - .createWithDefault(classOf[HadoopCommitProtocolWrapper].getName) + .createWithDefault(classOf[SQLHadoopMapReduceCommitProtocol].getName) val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold")
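
For reference, a minimal sketch (not part of this patch) of what a third-party committer could look like against the relocated API. The class name `DirectWriteCommitProtocol` and its package are hypothetical; the body just writes files straight to their final location with no staging. The important part is the `(jobId: String, path: String)` constructor, which is the 2-argument shape that `FileCommitProtocol.instantiate` falls back to when no `(jobId, path, isAppend)` constructor is declared.

```scala
package com.example.io  // hypothetical package, for illustration only

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}

import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.internal.io.FileCommitProtocol.{EmptyTaskCommitMessage, TaskCommitMessage}

// A no-op committer sketch: files are written directly to their final location,
// so commit/abort have nothing to do. Real implementations would stage and move files.
class DirectWriteCommitProtocol(jobId: String, path: String)
  extends FileCommitProtocol with Serializable {

  override def setupJob(jobContext: JobContext): Unit = {}

  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {}

  override def abortJob(jobContext: JobContext): Unit = {}

  override def setupTask(taskContext: TaskAttemptContext): Unit = {}

  override def newTaskTempFile(
      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
    // Compose the parts described in the newTaskTempFile scaladoc: base path,
    // optional partition sub-directory ("dir"), a prefix built from the job id
    // and task id, and the caller-supplied bucket/extension suffix ("ext").
    val taskId = taskContext.getTaskAttemptID.getTaskID.getId
    val filename = f"part-$taskId%05d-$jobId$ext"
    dir.map(d => new Path(new Path(path, d), filename).toString)
      .getOrElse(new Path(path, filename).toString)
  }

  override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage =
    EmptyTaskCommitMessage

  override def abortTask(taskContext: TaskAttemptContext): Unit = {}
}
```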
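
And a sketch of how such a class would be picked up, mirroring the call sites changed in this diff (`InsertIntoHadoopFsRelationCommand` and `FileStreamSink`); the class name and output path below are placeholders.

```scala
import org.apache.spark.internal.io.FileCommitProtocol

// Resolve the committer reflectively, as the write path now does: the
// 3-argument (jobId, outputPath, isAppend) constructor is tried first,
// then the 2-argument (jobId, outputPath) one.
val committer: FileCommitProtocol = FileCommitProtocol.instantiate(
  className = "com.example.io.DirectWriteCommitProtocol",  // hypothetical class from the sketch above
  jobId = java.util.UUID.randomUUID().toString,
  outputPath = "/tmp/example-output",
  isAppend = false)
```

For SQL writes the class name comes from the `spark.sql.sources.commitProtocolClass` entry updated at the bottom of this diff; its default is now `SQLHadoopMapReduceCommitProtocol`, which accepts the extra `jobId` argument and forwards `(jobId, path)` to `HadoopMapReduceCommitProtocol`.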