diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 996109865fdc7dd233756d77661c5705ebc540c9..d980e6a15aabe41b89d673c1bbb2d49753a6e1d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -528,7 +528,7 @@ case class DataSource(
         columns,
         bucketSpec,
         format,
-        () => Unit, // No existing table needs to be refreshed.
+        _ => Unit, // No existing table needs to be refreshed.
         options,
         data.logicalPlan,
         mode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index f0bcf94eadc960dfdb6c78f47faafcad02c2a5c6..34b77cab65def7de11ed2047f34c563ed8b2be59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, Inte
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -179,24 +180,30 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }
 
+      def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
+        if (l.catalogTable.isDefined &&
+            l.catalogTable.get.partitionColumnNames.nonEmpty &&
+            l.catalogTable.get.partitionProviderIsHive) {
+          val metastoreUpdater = AlterTableAddPartitionCommand(
+            l.catalogTable.get.identifier,
+            updatedPartitions.map(p => (p, None)),
+            ifNotExists = true)
+          metastoreUpdater.run(t.sparkSession)
+        }
+        t.location.refresh()
+      }
+
       val insertCmd = InsertIntoHadoopFsRelationCommand(
         outputPath,
         query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver),
         t.bucketSpec,
         t.fileFormat,
-        () => t.location.refresh(),
+        refreshPartitionsCallback,
         t.options,
         query,
         mode)
 
-      if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
-          l.catalogTable.get.partitionProviderIsHive) {
-        // TODO(ekl) we should be more efficient here and only recover the newly added partitions
-        val recoverPartitionCmd =
-          AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier)
-        Union(insertCmd, recoverPartitionCmd)
-      } else {
-        insertCmd
-      }
+      insertCmd
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 22dbe7149531c3d0c8f6c63c1919f2a66f4b3bf3..a1221d0ae6d27ad067b97b00f923ad4ae2dcb990 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -40,7 +41,7 @@ case class InsertIntoHadoopFsRelationCommand(
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
-    refreshFunction: () => Unit,
+    refreshFunction: (Seq[TablePartitionSpec]) => Unit,
    options: Map[String, String],
     @transient query: LogicalPlan,
     mode: SaveMode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index f66e8b4e2b5515c257b38f38e3d30d859559bbef..b51b41869bf06241480c448a961561bf80dd94c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.util.Shell
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.types._
@@ -244,6 +245,17 @@ object PartitioningUtils {
     }
   }
 
+  /**
+   * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
+   * for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
+   */
+  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
+    pathFragment.split("/").map { kv =>
+      val pair = kv.split("=", 2)
+      (unescapePathName(pair(0)), unescapePathName(pair(1)))
+    }.toMap
+  }
+
   /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
   * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
index bd56e511d0ccf9440e9a17f2829a0254dbeb051c..0eb86fdd6caa8a29567836983c6588e0357689a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
+import scala.collection.mutable
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -30,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
@@ -85,7 +88,7 @@ object WriteOutput extends Logging {
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
-      refreshFunction: () => Unit,
+      refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String],
       isAppend: Boolean): Unit = {
 
@@ -120,7 +123,7 @@ object WriteOutput extends Logging {
     val committer = setupDriverCommitter(job, outputPath.toString, isAppend)
 
     try {
-      sparkSession.sparkContext.runJob(queryExecution.toRdd,
+      val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd,
         (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
           executeTask(
             description = description,
@@ -128,11 +131,11 @@
             sparkPartitionId = taskContext.partitionId(),
             sparkAttemptNumber = taskContext.attemptNumber(),
             iterator = iter)
-        })
+        }).flatten.distinct
 
       committer.commitJob(job)
       logInfo(s"Job ${job.getJobID} committed.")
-      refreshFunction()
+      refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment))
     } catch { case cause: Throwable =>
       logError(s"Aborting job ${job.getJobID}.", cause)
       committer.abortJob(job, JobStatus.State.FAILED)
@@ -147,7 +150,7 @@
       sparkStageId: Int,
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
-      iterator: Iterator[InternalRow]): Unit = {
+      iterator: Iterator[InternalRow]): Set[String] = {
 
     val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId)
     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -187,11 +190,12 @@
     try {
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Execute the task to write rows out
-        writeTask.execute(iterator)
+        val outputPaths = writeTask.execute(iterator)
         writeTask.releaseResources()
 
         // Commit the task
         SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId)
+        outputPaths
       })(catchBlock = {
         // If there is an error, release resource and then abort the task
         try {
@@ -213,7 +217,7 @@
    * automatically trigger task aborts.
    */
   private trait ExecuteWriteTask {
-    def execute(iterator: Iterator[InternalRow]): Unit
+    def execute(iterator: Iterator[InternalRow]): Set[String]
     def releaseResources(): Unit
 
     final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
@@ -240,11 +244,12 @@
       outputWriter
     }
 
-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       while (iter.hasNext) {
         val internalRow = iter.next()
         outputWriter.writeInternal(internalRow)
       }
+      Set.empty
     }
 
     override def releaseResources(): Unit = {
@@ -327,7 +332,7 @@
       newWriter
     }
 
-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
         description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -375,6 +380,7 @@
 
       // If anything below fails, we should abort the task.
       var currentKey: UnsafeRow = null
+      val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
@@ -386,6 +392,10 @@
 
           logDebug(s"Writing partition: $currentKey")
           currentWriter = newOutputWriter(currentKey, getPartitionString)
+          val partitionPath = getPartitionString(currentKey).getString(0)
+          if (partitionPath.nonEmpty) {
+            updatedPartitions.add(partitionPath)
+          }
         }
         currentWriter.writeInternal(sortedIterator.getValue)
       }
@@ -393,6 +403,7 @@
         currentWriter.close()
         currentWriter = null
       }
+      updatedPartitions.toSet
     }
 
     override def releaseResources(): Unit = {
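
Note (illustrative sketch, not part of the patch): with this change each write task returns the set of partition directory fragments it wrote, and the driver flattens and de-duplicates them, parses them via PartitioningUtils.parsePathFragment, and hands the resulting Seq[TablePartitionSpec] to the new refreshFunction, which registers only the touched partitions in the metastore (AlterTableAddPartitionCommand with ifNotExists = true) instead of running a full AlterTableRecoverPartitionsCommand scan. The standalone Scala sketch below mirrors that flow outside Spark; simpleParseFragment is a hypothetical stand-in for parsePathFragment and skips the path-name unescaping the real method performs.

// Standalone sketch of the new refresh flow; none of these names are Spark APIs.
object PartitionRefreshSketch {
  // In Spark, TablePartitionSpec is an alias for Map[String, String].
  type TablePartitionSpec = Map[String, String]

  // Simplified stand-in for PartitioningUtils.parsePathFragment:
  // "year=2016/month=10" becomes Map("year" -> "2016", "month" -> "10").
  def simpleParseFragment(fragment: String): TablePartitionSpec =
    fragment.split("/").map { kv =>
      val pair = kv.split("=", 2)
      pair(0) -> pair(1)
    }.toMap

  def main(args: Array[String]): Unit = {
    // What each write task now returns: the partition directories it touched.
    val taskOutputs: Seq[Set[String]] = Seq(
      Set("year=2016/month=10"),
      Set("year=2016/month=10", "year=2016/month=11"))

    // Stand-in for the refresh callback; the real one updates the Hive metastore
    // and refreshes the file index.
    val refreshFunction: Seq[TablePartitionSpec] => Unit =
      specs => specs.foreach(spec => println(s"refreshing partition: $spec"))

    // Driver side: flatten, de-duplicate, parse, then invoke the callback once.
    val updatedPartitions = taskOutputs.flatten.distinct
    refreshFunction(updatedPartitions.map(simpleParseFragment))
  }
}

Running the sketch prints one "refreshing partition" line per distinct fragment (two here), regardless of how many tasks wrote into the same directory.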