From 064d4315f246450043a52882fcf59e95d79701e8 Mon Sep 17 00:00:00 2001
From: Eric Liang <ekl@databricks.com>
Date: Thu, 10 Nov 2016 17:00:43 -0800
Subject: [PATCH] [SPARK-18185] Fix all forms of INSERT / OVERWRITE TABLE for
 Datasource tables

## What changes were proposed in this pull request?

As of the current 2.1 branch, INSERT OVERWRITE with dynamic partitions against a Datasource table overwrites the entire table instead of only the partitions matching the static keys, as Hive does. It also does not respect custom partition locations.
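
For concreteness, here is a hedged sketch of the intended semantics (assuming a `SparkSession` in scope as `spark`; the table and column names are illustrative only, not taken from this patch):

```scala
// Illustrative only: a two-level partitioned Datasource table.
spark.sql("CREATE TABLE t (id LONG, p1 INT, p2 INT) USING parquet PARTITIONED BY (p1, p2)")
spark.sql("INSERT INTO t PARTITION (p1=1, p2) SELECT id, id FROM range(10)")
spark.sql("INSERT INTO t PARTITION (p1=2, p2) SELECT id, id FROM range(10)")

// With this patch, the overwrite below replaces only the partitions under p1=1;
// the partitions under p1=2 are left untouched (previously the whole table was cleared).
spark.sql("INSERT OVERWRITE TABLE t PARTITION (p1=1, p2) SELECT id, id FROM range(5)")
```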

This PR adds support for all of these operations to Datasource tables managed by the Hive metastore. It is implemented as follows:
- At planning time, the full set of partitions affected by an INSERT or OVERWRITE command is read from the Hive metastore.
- The planner identifies any partitions with custom locations and includes them in the write task metadata.
- FileFormatWriter tasks consult this map of custom locations when determining where to write dynamic partition output (a simplified sketch of the lookup follows this list).
- When the write job finishes, the set of written partitions is compared against the initial set of matched partitions, and the Hive metastore is updated to reflect the newly added and removed partitions.
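
The following is a simplified, self-contained sketch of the per-partition custom-location lookup; the helper name and the flat `Map` types are illustrative (the real logic lives in `FileFormatWriter` and `PartitioningUtils.parsePathFragment` in this patch, and path escaping is omitted here):

```scala
// Resolve the output directory for one dynamic partition of a write task.
// `customLocations` is the map computed at planning time from the metastore;
// `partitionDir` is a Hive-style fragment such as "p1=1/p2=3".
def resolveOutputDir(
    tableLocation: String,
    partitionDir: Option[String],
    customLocations: Map[Map[String, String], String]): String = {
  partitionDir match {
    case Some(dir) =>
      // parse "p1=1/p2=3" into Map("p1" -> "1", "p2" -> "3")
      val spec = dir.split("/").map { kv =>
        val Array(k, v) = kv.split("=", 2)
        k -> v
      }.toMap
      // prefer the custom location registered in the metastore, otherwise fall
      // back to the default <table>/<fragment> layout
      customLocations.getOrElse(spec, s"$tableLocation/$dir")
    case None =>
      tableLocation
  }
}
```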

This required adding a method to `FileCommitProtocol` for staging files with absolute output paths. Such files are not handled by the Hadoop output committer; instead they are moved to their final locations when the job commits.
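
A hedged usage sketch from a write task's point of view; the helper name and parameters are illustrative and not code from this patch, but the two committer calls match the `FileCommitProtocol` API added below:

```scala
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.spark.internal.io.FileCommitProtocol

// Pick the staging path for one output file of a partition.
def newOutputFile(
    committer: FileCommitProtocol,
    taskCtx: TaskAttemptContext,
    partitionDir: Option[String],  // e.g. Some("p1=1/p2=3"), relative to the table path
    customDir: Option[String],     // custom metastore location for this partition, if any
    ext: String): String = {
  customDir match {
    case Some(dir) =>
      // for HadoopMapReduceCommitProtocol this is staged under <path>/_temporary-<jobId>
      // and renamed to `dir` when the job commits
      committer.newTaskTempFileAbsPath(taskCtx, dir, ext)
    case None =>
      // handled by the regular Hadoop output committer
      committer.newTaskTempFile(taskCtx, partitionDir, ext)
  }
}
```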

The overwrite behavior of legacy Datasource tables also changes: the entire table is no longer overwritten when a partial partition spec is present; only the directories matching the static partition keys are cleared.
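
A minimal sketch of how the static keys translate into the directory prefix that gets cleared, mirroring in simplified form the new `deleteMatchingPartitions` helper in `InsertIntoHadoopFsRelationCommand` (path escaping is omitted):

```scala
// Simplified sketch; the real code escapes partition names and values.
def staticPrefix(partitionColumns: Seq[String], staticKeys: Map[String, String]): String = {
  val parts = partitionColumns
    .takeWhile(staticKeys.contains)  // static keys form a prefix of the partition columns
    .map(c => s"$c=${staticKeys(c)}")
  if (parts.isEmpty) "" else parts.mkString("/", "/", "")
}

// staticPrefix(Seq("p1", "p2"), Map("p1" -> "1")) == "/p1=1"
// so "INSERT OVERWRITE ... PARTITION (p1=1, p2)" clears only <table>/p1=1
// instead of the whole table directory.
```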

cc cloud-fan yhuai

## How was this patch tested?

Unit tests, existing tests.

Author: Eric Liang <ekl@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>

Closes #15814 from ericl/sc-5027.

(cherry picked from commit a3356343cbf58b930326f45721fb4ecade6f8029)
Signed-off-by: Reynold Xin <rxin@databricks.com>
---
 .../internal/io/FileCommitProtocol.scala      |  15 ++
 .../io/HadoopMapReduceCommitProtocol.scala    |  63 ++++++-
 .../sql/catalyst/parser/AstBuilder.scala      |  12 +-
 .../plans/logical/basicLogicalOperators.scala |  10 +-
 .../sql/catalyst/parser/PlanParserSuite.scala |   4 +-
 .../execution/datasources/DataSource.scala    |  20 ++-
 .../datasources/DataSourceStrategy.scala      |  94 +++++++---
 .../datasources/FileFormatWriter.scala        |  26 ++-
 .../InsertIntoHadoopFsRelationCommand.scala   |  61 ++++++-
 .../datasources/PartitioningUtils.scala       |  10 ++
 .../execution/streaming/FileStreamSink.scala  |   2 +-
 .../ManifestFileCommitProtocol.scala          |   6 +
 .../PartitionProviderCompatibilitySuite.scala | 161 +++++++++++++++++-
 13 files changed, 411 insertions(+), 73 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index fb8020585c..afd2250c93 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -82,9 +82,24 @@ abstract class FileCommitProtocol {
    *
    * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest
    * are left to the commit protocol implementation to decide.
+   *
+   * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+   * if a task is going to write out multiple files to the same dir. The file commit protocol only
+   * guarantees that files written by different tasks will not conflict.
    */
   def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String
 
+  /**
+   * Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
+   * Depending on the implementation, there may be weaker guarantees around adding files this way.
+   *
+   * Important: it is the caller's responsibility to add uniquely identifying content to "ext"
+   * if a task is going to write out multiple files to the same dir. The file commit protocol only
+   * guarantees that files written by different tasks will not conflict.
+   */
+  def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String
+
   /**
    * Commits a task after the writes succeed. Must be called on the executors when running tasks.
    */
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 66ccb6d437..c99b75e523 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.internal.io
 
-import java.util.Date
+import java.util.{Date, UUID}
+
+import scala.collection.mutable
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -42,17 +44,26 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
   /** OutputCommitter from Hadoop is not serializable so marking it transient. */
   @transient private var committer: OutputCommitter = _
 
+  /**
+   * Tracks files staged by this task for absolute output paths. These outputs are not managed by
+   * the Hadoop OutputCommitter, so we must move these to their final locations on job commit.
+   *
+   * The mapping is from the temp output path to the final desired output path of the file.
+   */
+  @transient private var addedAbsPathFiles: mutable.Map[String, String] = null
+
+  /**
+   * The staging directory for all files committed with absolute output paths.
+   */
+  private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId)
+
   protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
     context.getOutputFormatClass.newInstance().getOutputCommitter(context)
   }
 
   override def newTaskTempFile(
       taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
-    // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
-    // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
-    // the file name is fine and won't overflow.
-    val split = taskContext.getTaskAttemptID.getTaskID.getId
-    val filename = f"part-$split%05d-$jobId$ext"
+    val filename = getFilename(taskContext, ext)
 
     val stagingDir: String = committer match {
       // For FileOutputCommitter it has its own staging path called "work path".
@@ -67,6 +78,28 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
     }
   }
 
+  override def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+    val filename = getFilename(taskContext, ext)
+    val absOutputPath = new Path(absoluteDir, filename).toString
+
+    // Include a UUID here to prevent file collisions for one task writing to different dirs.
+    // In principle we could include hash(absoluteDir) instead but this is simpler.
+    val tmpOutputPath = new Path(
+      absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString
+
+    addedAbsPathFiles(tmpOutputPath) = absOutputPath
+    tmpOutputPath
+  }
+
+  private def getFilename(taskContext: TaskAttemptContext, ext: String): String = {
+    // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+    // Note that %05d does not truncate the split number, so if we have more than 100000 tasks,
+    // the file name is fine and won't overflow.
+    val split = taskContext.getTaskAttemptID.getTaskID.getId
+    f"part-$split%05d-$jobId$ext"
+  }
+
   override def setupJob(jobContext: JobContext): Unit = {
     // Setup IDs
     val jobId = SparkHadoopWriter.createJobID(new Date, 0)
@@ -87,25 +120,41 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
 
   override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
     committer.commitJob(jobContext)
+    val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]])
+      .foldLeft(Map[String, String]())(_ ++ _)
+    logDebug(s"Committing files staged for absolute locations $filesToMove")
+    val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+    for ((src, dst) <- filesToMove) {
+      fs.rename(new Path(src), new Path(dst))
+    }
+    fs.delete(absPathStagingDir, true)
   }
 
   override def abortJob(jobContext: JobContext): Unit = {
     committer.abortJob(jobContext, JobStatus.State.FAILED)
+    val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration)
+    fs.delete(absPathStagingDir, true)
   }
 
   override def setupTask(taskContext: TaskAttemptContext): Unit = {
     committer = setupCommitter(taskContext)
     committer.setupTask(taskContext)
+    addedAbsPathFiles = mutable.Map[String, String]()
   }
 
   override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
     val attemptId = taskContext.getTaskAttemptID
     SparkHadoopMapRedUtil.commitTask(
       committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId)
-    EmptyTaskCommitMessage
+    new TaskCommitMessage(addedAbsPathFiles.toMap)
   }
 
   override def abortTask(taskContext: TaskAttemptContext): Unit = {
     committer.abortTask(taskContext)
+    // best effort cleanup of other staged files
+    for ((src, _) <- addedAbsPathFiles) {
+      val tmp = new Path(src)
+      tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false)
+    }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 4b151c81d8..2006844923 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -172,24 +172,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
     val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
     val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
 
-    val dynamicPartitionKeys = partitionKeys.filter(_._2.isEmpty)
+    val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty)
     if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) {
       throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " +
         "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
     }
     val overwrite = ctx.OVERWRITE != null
-    val overwritePartition =
-      if (overwrite && partitionKeys.nonEmpty && dynamicPartitionKeys.isEmpty) {
-        Some(partitionKeys.map(t => (t._1, t._2.get)))
-      } else {
-        None
-      }
+    val staticPartitionKeys: Map[String, String] =
+      partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get))
 
     InsertIntoTable(
       UnresolvedRelation(tableIdent, None),
       partitionKeys,
       query,
-      OverwriteOptions(overwrite, overwritePartition),
+      OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty),
       ctx.EXISTS != null)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 65ceab2ce2..574caf039d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -350,13 +350,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
  * Options for writing new data into a table.
  *
  * @param enabled whether to overwrite existing data in the table.
- * @param specificPartition only data in the specified partition will be overwritten.
+ * @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions
+ *                            that match this partial partition spec. If empty, all partitions
+ *                            will be overwritten.
  */
 case class OverwriteOptions(
     enabled: Boolean,
-    specificPartition: Option[CatalogTypes.TablePartitionSpec] = None) {
-  if (specificPartition.isDefined) {
-    assert(enabled, "Overwrite must be enabled when specifying a partition to overwrite.")
+    staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) {
+  if (staticPartitionKeys.nonEmpty) {
+    assert(enabled, "Overwrite must be enabled when specifying specific partitions.")
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 7400f3430e..e5f1f7b3bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -185,9 +185,9 @@ class PlanParserSuite extends PlanTest {
         OverwriteOptions(
           overwrite,
           if (overwrite && partition.nonEmpty) {
-            Some(partition.map(kv => (kv._1, kv._2.get)))
+            partition.map(kv => (kv._1, kv._2.get))
           } else {
-            None
+            Map.empty
           }),
         ifNotExists)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 5d663949df..65422f1495 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -417,15 +417,17 @@ case class DataSource(
         // will be adjusted within InsertIntoHadoopFsRelation.
         val plan =
           InsertIntoHadoopFsRelationCommand(
-            outputPath,
-            columns,
-            bucketSpec,
-            format,
-            _ => Unit, // No existing table needs to be refreshed.
-            options,
-            data.logicalPlan,
-            mode,
-            catalogTable)
+            outputPath = outputPath,
+            staticPartitionKeys = Map.empty,
+            customPartitionLocations = Map.empty,
+            partitionColumns = columns,
+            bucketSpec = bucketSpec,
+            fileFormat = format,
+            refreshFunction = _ => Unit, // No existing table needs to be refreshed.
+            options = options,
+            query = data.logicalPlan,
+            mode = mode,
+            catalogTable = catalogTable)
         sparkSession.sessionState.executePlan(plan).toRdd
         // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
         copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 739aeac877..4f19a2d00b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -24,10 +24,10 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -182,41 +182,53 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }
 
-      val overwritingSinglePartition =
-        overwrite.specificPartition.isDefined &&
+      val partitionSchema = query.resolve(
+        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+      val partitionsTrackedByCatalog =
         t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
+        l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
         l.catalogTable.get.tracksPartitionsInCatalog
 
-      val effectiveOutputPath = if (overwritingSinglePartition) {
-        val partition = t.sparkSession.sessionState.catalog.getPartition(
-          l.catalogTable.get.identifier, overwrite.specificPartition.get)
-        new Path(partition.location)
-      } else {
-        outputPath
-      }
-
-      val effectivePartitionSchema = if (overwritingSinglePartition) {
-        Nil
-      } else {
-        query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
+      var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
+      var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty
+
+      // When partitions are tracked by the catalog, compute all custom partition locations that
+      // may be relevant to the insertion job.
+      if (partitionsTrackedByCatalog) {
+        val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
+          l.catalogTable.get.identifier, Some(overwrite.staticPartitionKeys))
+        initialMatchingPartitions = matchingPartitions.map(_.spec)
+        customPartitionLocations = getCustomPartitionLocations(
+          t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
       }
 
+      // Callback for updating metastore partition metadata after the insertion job completes.
+      // TODO(ekl) consider moving this into InsertIntoHadoopFsRelationCommand
       def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
-        if (l.catalogTable.isDefined && updatedPartitions.nonEmpty &&
-            l.catalogTable.get.partitionColumnNames.nonEmpty &&
-            l.catalogTable.get.tracksPartitionsInCatalog) {
-          val metastoreUpdater = AlterTableAddPartitionCommand(
-            l.catalogTable.get.identifier,
-            updatedPartitions.map(p => (p, None)),
-            ifNotExists = true)
-          metastoreUpdater.run(t.sparkSession)
+        if (partitionsTrackedByCatalog) {
+          val newPartitions = updatedPartitions.toSet -- initialMatchingPartitions
+          if (newPartitions.nonEmpty) {
+            AlterTableAddPartitionCommand(
+              l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
+              ifNotExists = true).run(t.sparkSession)
+          }
+          if (overwrite.enabled) {
+            val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
+            if (deletedPartitions.nonEmpty) {
+              AlterTableDropPartitionCommand(
+                l.catalogTable.get.identifier, deletedPartitions.toSeq,
+                ifExists = true, purge = true).run(t.sparkSession)
+            }
+          }
         }
         t.location.refresh()
       }
 
       val insertCmd = InsertIntoHadoopFsRelationCommand(
-        effectiveOutputPath,
-        effectivePartitionSchema,
+        outputPath,
+        if (overwrite.enabled) overwrite.staticPartitionKeys else Map.empty,
+        customPartitionLocations,
+        partitionSchema,
         t.bucketSpec,
         t.fileFormat,
         refreshPartitionsCallback,
@@ -227,6 +239,34 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
 
       insertCmd
   }
+
+  /**
+   * Given a set of input partitions, returns those that have locations that differ from the
+   * Hive default (e.g. /k1=v1/k2=v2). These partitions were manually assigned locations by
+   * the user.
+   *
+   * @return a mapping from partition specs to their custom locations
+   */
+  private def getCustomPartitionLocations(
+      spark: SparkSession,
+      table: CatalogTable,
+      basePath: Path,
+      partitions: Seq[CatalogTablePartition]): Map[TablePartitionSpec, String] = {
+    val hadoopConf = spark.sessionState.newHadoopConf
+    val fs = basePath.getFileSystem(hadoopConf)
+    val qualifiedBasePath = basePath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+    partitions.flatMap { p =>
+      val defaultLocation = qualifiedBasePath.suffix(
+        "/" + PartitioningUtils.getPathFragment(p.spec, table.partitionSchema)).toString
+      val catalogLocation = new Path(p.location).makeQualified(
+        fs.getUri, fs.getWorkingDirectory).toString
+      if (catalogLocation != defaultLocation) {
+        Some(p.spec -> catalogLocation)
+      } else {
+        None
+      }
+    }.toMap
+  }
 }
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 0f8ed9e23f..edcce103d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -47,6 +47,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 /** A helper object for writing FileFormat data out to a location. */
 object FileFormatWriter extends Logging {
 
+  /** Describes how output files should be placed in the filesystem. */
+  case class OutputSpec(
+    outputPath: String, customPartitionLocations: Map[TablePartitionSpec, String])
+
   /** A shared job description for all the write tasks. */
   private class WriteJobDescription(
       val uuid: String,  // prevent collision between different (appending) write jobs
@@ -56,7 +60,8 @@ object FileFormatWriter extends Logging {
       val partitionColumns: Seq[Attribute],
       val nonPartitionColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
-      val path: String)
+      val path: String,
+      val customPartitionLocations: Map[TablePartitionSpec, String])
     extends Serializable {
 
     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
@@ -83,7 +88,7 @@ object FileFormatWriter extends Logging {
       plan: LogicalPlan,
       fileFormat: FileFormat,
       committer: FileCommitProtocol,
-      outputPath: String,
+      outputSpec: OutputSpec,
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
@@ -93,7 +98,7 @@ object FileFormatWriter extends Logging {
     val job = Job.getInstance(hadoopConf)
     job.setOutputKeyClass(classOf[Void])
     job.setOutputValueClass(classOf[InternalRow])
-    FileOutputFormat.setOutputPath(job, new Path(outputPath))
+    FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))
 
     val partitionSet = AttributeSet(partitionColumns)
     val dataColumns = plan.output.filterNot(partitionSet.contains)
@@ -111,7 +116,8 @@ object FileFormatWriter extends Logging {
       partitionColumns = partitionColumns,
       nonPartitionColumns = dataColumns,
       bucketSpec = bucketSpec,
-      path = outputPath)
+      path = outputSpec.outputPath,
+      customPartitionLocations = outputSpec.customPartitionLocations)
 
     SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
       // This call shouldn't be put into the `try` block below because it only initializes and
@@ -308,7 +314,17 @@ object FileFormatWriter extends Logging {
       }
       val ext = bucketId + description.outputWriterFactory.getFileExtension(taskAttemptContext)
 
-      val path = committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+      val customPath = partDir match {
+        case Some(dir) =>
+          description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir))
+        case _ =>
+          None
+      }
+      val path = if (customPath.isDefined) {
+        committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext)
+      } else {
+        committer.newTaskTempFile(taskAttemptContext, partDir, ext)
+      }
       val newWriter = description.outputWriterFactory.newInstance(
         path = path,
         dataSchema = description.nonPartitionColumns.toStructType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index a0a8cb5024..28975e1546 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.io.IOException
 
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.sql._
@@ -32,19 +32,32 @@ import org.apache.spark.sql.execution.command.RunnableCommand
 /**
  * A command for writing data to a [[HadoopFsRelation]].  Supports both overwriting and appending.
  * Writing to dynamic partitions is also supported.
+ *
+ * @param staticPartitionKeys partial partitioning spec for write. This defines the scope of
+ *                            partition overwrites: when the spec is empty, all partitions are
+ *                            overwritten. When it covers a prefix of the partition keys, only
+ *                            partitions matching the prefix are overwritten.
+ * @param customPartitionLocations mapping of partition specs to their custom locations. The
+ *                                 caller should guarantee that exactly those table partitions
+ *                                 falling under the specified static partition keys are contained
+ *                                 in this map, and that no other partitions are.
  */
 case class InsertIntoHadoopFsRelationCommand(
     outputPath: Path,
+    staticPartitionKeys: TablePartitionSpec,
+    customPartitionLocations: Map[TablePartitionSpec, String],
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
-    refreshFunction: (Seq[TablePartitionSpec]) => Unit,
+    refreshFunction: Seq[TablePartitionSpec] => Unit,
     options: Map[String, String],
     @transient query: LogicalPlan,
     mode: SaveMode,
     catalogTable: Option[CatalogTable])
   extends RunnableCommand {
 
+  import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
+
   override protected def innerChildren: Seq[LogicalPlan] = query :: Nil
 
   override def run(sparkSession: SparkSession): Seq[Row] = {
@@ -66,10 +79,7 @@ case class InsertIntoHadoopFsRelationCommand(
       case (SaveMode.ErrorIfExists, true) =>
         throw new AnalysisException(s"path $qualifiedOutputPath already exists.")
       case (SaveMode.Overwrite, true) =>
-        if (!fs.delete(qualifiedOutputPath, true /* recursively */)) {
-          throw new IOException(s"Unable to clear output " +
-            s"directory $qualifiedOutputPath prior to writing to it")
-        }
+        deleteMatchingPartitions(fs, qualifiedOutputPath)
         true
       case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
         true
@@ -93,7 +103,8 @@ case class InsertIntoHadoopFsRelationCommand(
         plan = query,
         fileFormat = fileFormat,
         committer = committer,
-        outputPath = qualifiedOutputPath.toString,
+        outputSpec = FileFormatWriter.OutputSpec(
+          qualifiedOutputPath.toString, customPartitionLocations),
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = bucketSpec,
@@ -105,4 +116,40 @@ case class InsertIntoHadoopFsRelationCommand(
 
     Seq.empty[Row]
   }
+
+  /**
+   * Deletes all partition files that match the specified static prefix. Partitions with custom
+   * locations are also cleared based on the custom locations map given to this class.
+   */
+  private def deleteMatchingPartitions(fs: FileSystem, qualifiedOutputPath: Path): Unit = {
+    val staticPartitionPrefix = if (staticPartitionKeys.nonEmpty) {
+      "/" + partitionColumns.flatMap { p =>
+        staticPartitionKeys.get(p.name) match {
+          case Some(value) =>
+            Some(escapePathName(p.name) + "=" + escapePathName(value))
+          case None =>
+            None
+        }
+      }.mkString("/")
+    } else {
+      ""
+    }
+    // first clear the path determined by the static partition keys (e.g. /table/foo=1)
+    val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
+    if (fs.exists(staticPrefixPath) && !fs.delete(staticPrefixPath, true /* recursively */)) {
+      throw new IOException(s"Unable to clear output " +
+        s"directory $staticPrefixPath prior to writing to it")
+    }
+    // now clear all custom partition locations (e.g. /custom/dir/where/foo=2/bar=4)
+    for ((spec, customLoc) <- customPartitionLocations) {
+      assert(
+        (staticPartitionKeys.toSet -- spec).isEmpty,
+        "Custom partition location did not match static partitioning keys")
+      val path = new Path(customLoc)
+      if (fs.exists(path) && !fs.delete(path, true)) {
+        throw new IOException(s"Unable to clear partition " +
+          s"directory $path prior to writing to it")
+      }
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index a28b04ca3f..bf9f318780 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -62,6 +62,7 @@ object PartitioningUtils {
   }
 
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+  import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName
   import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.unescapePathName
 
   /**
@@ -252,6 +253,15 @@ object PartitioningUtils {
     }.toMap
   }
 
+  /**
+   * This is the inverse of parsePathFragment().
+   */
+  def getPathFragment(spec: TablePartitionSpec, partitionSchema: StructType): String = {
+    partitionSchema.map { field =>
+      escapePathName(field.name) + "=" + escapePathName(spec(field.name))
+    }.mkString("/")
+  }
+
   /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index e849cafef4..f1c5f9ab50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -80,7 +80,7 @@ class FileStreamSink(
         plan = data.logicalPlan,
         fileFormat = fileFormat,
         committer = committer,
-        outputPath = path,
+        outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
         hadoopConf = hadoopConf,
         partitionColumns = partitionColumns,
         bucketSpec = None,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
index 1fe13fa162..92191c8b64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
@@ -96,6 +96,12 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
     file
   }
 
+  override def newTaskTempFileAbsPath(
+      taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = {
+    throw new UnsupportedOperationException(
+      s"$this does not support adding files with an absolute path")
+  }
+
   override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = {
     if (addedFiles.nonEmpty) {
       val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
index ac435bf619..a1aa07456f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionProviderCompatibilitySuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.{AnalysisException, QueryTest}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
 
 class PartitionProviderCompatibilitySuite
   extends QueryTest with TestHiveSingleton with SQLTestUtils {
@@ -135,7 +136,7 @@ class PartitionProviderCompatibilitySuite
     }
   }
 
-  test("insert overwrite partition of legacy datasource table overwrites entire table") {
+  test("insert overwrite partition of legacy datasource table") {
     withSQLConf(SQLConf.HIVE_MANAGE_FILESOURCE_PARTITIONS.key -> "false") {
       withTable("test") {
         withTempDir { dir =>
@@ -144,9 +145,9 @@ class PartitionProviderCompatibilitySuite
             """insert overwrite table test
               |partition (partCol=1)
               |select * from range(100)""".stripMargin)
-          assert(spark.sql("select * from test").count() == 100)
+          assert(spark.sql("select * from test").count() == 104)
 
-          // Dynamic partitions case
+          // Overwriting entire table
           spark.sql("insert overwrite table test select id, id from range(10)".stripMargin)
           assert(spark.sql("select * from test").count() == 10)
         }
@@ -186,4 +187,158 @@ class PartitionProviderCompatibilitySuite
       }
     }
   }
+
+  /**
+   * Runs a test against a multi-level partitioned table, then validates that the custom locations
+   * were respected by the output writer.
+   *
+   * The initial partitioning structure is:
+   *   /P1=0/P2=0  -- custom location a
+   *   /P1=0/P2=1  -- custom location b
+   *   /P1=1/P2=0  -- custom location c
+   *   /P1=1/P2=1  -- default location
+   */
+  private def testCustomLocations(testFn: => Unit): Unit = {
+    val base = Utils.createTempDir(namePrefix = "base")
+    val a = Utils.createTempDir(namePrefix = "a")
+    val b = Utils.createTempDir(namePrefix = "b")
+    val c = Utils.createTempDir(namePrefix = "c")
+    try {
+      spark.sql(s"""
+        |create table test (id long, P1 int, P2 int)
+        |using parquet
+        |options (path "${base.getAbsolutePath}")
+        |partitioned by (P1, P2)""".stripMargin)
+      spark.sql(s"alter table test add partition (P1=0, P2=0) location '${a.getAbsolutePath}'")
+      spark.sql(s"alter table test add partition (P1=0, P2=1) location '${b.getAbsolutePath}'")
+      spark.sql(s"alter table test add partition (P1=1, P2=0) location '${c.getAbsolutePath}'")
+      spark.sql(s"alter table test add partition (P1=1, P2=1)")
+
+      testFn
+
+      // Now validate the partition custom locations were respected
+      val initialCount = spark.sql("select * from test").count()
+      val numA = spark.sql("select * from test where P1=0 and P2=0").count()
+      val numB = spark.sql("select * from test where P1=0 and P2=1").count()
+      val numC = spark.sql("select * from test where P1=1 and P2=0").count()
+      Utils.deleteRecursively(a)
+      spark.sql("refresh table test")
+      assert(spark.sql("select * from test where P1=0 and P2=0").count() == 0)
+      assert(spark.sql("select * from test").count() == initialCount - numA)
+      Utils.deleteRecursively(b)
+      spark.sql("refresh table test")
+      assert(spark.sql("select * from test where P1=0 and P2=1").count() == 0)
+      assert(spark.sql("select * from test").count() == initialCount - numA - numB)
+      Utils.deleteRecursively(c)
+      spark.sql("refresh table test")
+      assert(spark.sql("select * from test where P1=1 and P2=0").count() == 0)
+      assert(spark.sql("select * from test").count() == initialCount - numA - numB - numC)
+    } finally {
+      Utils.deleteRecursively(base)
+      Utils.deleteRecursively(a)
+      Utils.deleteRecursively(b)
+      Utils.deleteRecursively(c)
+      spark.sql("drop table test")
+    }
+  }
+
+  test("sanity check table setup") {
+    testCustomLocations {
+      assert(spark.sql("select * from test").count() == 0)
+      assert(spark.sql("show partitions test").count() == 4)
+    }
+  }
+
+  test("insert into partial dynamic partitions") {
+    testCustomLocations {
+      spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 12)
+      spark.sql("insert into test partition (P1=0, P2) select id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 20)
+      assert(spark.sql("show partitions test").count() == 12)
+      spark.sql("insert into test partition (P1=1, P2) select id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 30)
+      assert(spark.sql("show partitions test").count() == 20)
+      spark.sql("insert into test partition (P1=2, P2) select id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 40)
+      assert(spark.sql("show partitions test").count() == 30)
+    }
+  }
+
+  test("insert into fully dynamic partitions") {
+    testCustomLocations {
+      spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 12)
+      spark.sql("insert into test partition (P1, P2) select id, id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 20)
+      assert(spark.sql("show partitions test").count() == 12)
+    }
+  }
+
+  test("insert into static partition") {
+    testCustomLocations {
+      spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert into test partition (P1=0, P2=0) select id from range(10)")
+      assert(spark.sql("select * from test").count() == 20)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert into test partition (P1=1, P2=1) select id from range(10)")
+      assert(spark.sql("select * from test").count() == 30)
+      assert(spark.sql("show partitions test").count() == 4)
+    }
+  }
+
+  test("overwrite partial dynamic partitions") {
+    testCustomLocations {
+      spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 12)
+      spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(5)")
+      assert(spark.sql("select * from test").count() == 5)
+      assert(spark.sql("show partitions test").count() == 7)
+      spark.sql("insert overwrite table test partition (P1=0, P2) select id, id from range(1)")
+      assert(spark.sql("select * from test").count() == 1)
+      assert(spark.sql("show partitions test").count() == 3)
+      spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 11)
+      assert(spark.sql("show partitions test").count() == 11)
+      spark.sql("insert overwrite table test partition (P1=1, P2) select id, id from range(1)")
+      assert(spark.sql("select * from test").count() == 2)
+      assert(spark.sql("show partitions test").count() == 2)
+      spark.sql("insert overwrite table test partition (P1=3, P2) select id, id from range(100)")
+      assert(spark.sql("select * from test").count() == 102)
+      assert(spark.sql("show partitions test").count() == 102)
+    }
+  }
+
+  test("overwrite fully dynamic partitions") {
+    testCustomLocations {
+      spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 10)
+      spark.sql("insert overwrite table test partition (P1, P2) select id, id, id from range(5)")
+      assert(spark.sql("select * from test").count() == 5)
+      assert(spark.sql("show partitions test").count() == 5)
+    }
+  }
+
+  test("overwrite static partition") {
+    testCustomLocations {
+      spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(10)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert overwrite table test partition (P1=0, P2=0) select id from range(5)")
+      assert(spark.sql("select * from test").count() == 5)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert overwrite table test partition (P1=1, P2=1) select id from range(5)")
+      assert(spark.sql("select * from test").count() == 10)
+      assert(spark.sql("show partitions test").count() == 4)
+      spark.sql("insert overwrite table test partition (P1=1, P2=2) select id from range(5)")
+      assert(spark.sql("select * from test").count() == 15)
+      assert(spark.sql("show partitions test").count() == 5)
+    }
+  }
 }
-- 
GitLab