From efc254a82bc3331d78023f00d29d4c4318dfb734 Mon Sep 17 00:00:00 2001
From: Eric Liang <ekl@databricks.com>
Date: Mon, 31 Oct 2016 19:46:55 -0700
Subject: [PATCH] [SPARK-18087][SQL] Optimize insert to not require REPAIR
 TABLE

## What changes were proposed in this pull request?

When inserting into datasource tables with partitions managed by the Hive metastore, we need to notify the metastore of newly added partitions. Previously this was implemented via `msck repair table`, which is more expensive than necessary.

This change optimizes the insertion path to register only the partitions that were actually written, as sketched below.
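
For reference, a minimal, self-contained Scala sketch of the idea (the object name, the simplified `unescape` helper, and the SQL-rendering helper are illustrative only, not code from this patch): write tasks report the partition path fragments they touched, the driver deduplicates and parses them, and only those partitions are added via `AlterTableAddPartitionCommand` with `ifNotExists = true`, which corresponds roughly to an `ALTER TABLE ... ADD IF NOT EXISTS PARTITION ...` statement.

```scala
// Sketch of the optimized flow: parse the partition directories a job wrote and
// register just those, instead of running a full MSCK REPAIR TABLE scan.
object AddOnlyUpdatedPartitionsSketch {
  type TablePartitionSpec = Map[String, String]

  // Decode %XX escape sequences (assumes well-formed escapes), e.g. "a%2Fb" -> "a/b".
  // Simplified stand-in for Spark's internal path unescaping.
  private def unescape(s: String): String = {
    val sb = new StringBuilder
    var i = 0
    while (i < s.length) {
      if (s.charAt(i) == '%' && i + 2 < s.length) {
        sb.append(Integer.parseInt(s.substring(i + 1, i + 3), 16).toChar)
        i += 3
      } else {
        sb.append(s.charAt(i))
        i += 1
      }
    }
    sb.toString
  }

  // Mirrors the new PartitioningUtils.parsePathFragment helper:
  // "fieldOne=1/fieldTwo=2" -> Map("fieldOne" -> "1", "fieldTwo" -> "2")
  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
    pathFragment.split("/").map { kv =>
      val pair = kv.split("=", 2)
      (unescape(pair(0)), unescape(pair(1)))
    }.toMap
  }

  // Roughly the statement that AlterTableAddPartitionCommand(..., ifNotExists = true)
  // corresponds to for the partitions actually written by the job.
  def addPartitionsSql(table: String, updatedPaths: Seq[String]): String = {
    val specs = updatedPaths.distinct.map(parsePathFragment).map { spec =>
      spec.map { case (k, v) => s"$k='$v'" }.mkString("PARTITION (", ", ", ")")
    }
    s"ALTER TABLE $table ADD IF NOT EXISTS ${specs.mkString(" ")}"
  }

  def main(args: Array[String]): Unit = {
    println(addPartitionsSql("db.tbl", Seq("fieldOne=1/fieldTwo=2", "fieldOne=1/fieldTwo=3")))
    // ALTER TABLE db.tbl ADD IF NOT EXISTS PARTITION (fieldOne='1', fieldTwo='2')
    //   PARTITION (fieldOne='1', fieldTwo='3')
  }
}
```
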
## How was this patch tested?

Existing tests (I verified manually that the tests fail if the repair operation is omitted).

Author: Eric Liang <ekl@databricks.com>

Closes #15633 from ericl/spark-18087.
---
 .../execution/datasources/DataSource.scala    |  2 +-
 .../datasources/DataSourceStrategy.scala      | 27 ++++++++++-------
 .../InsertIntoHadoopFsRelationCommand.scala   |  3 +-
 .../datasources/PartitioningUtils.scala       | 12 ++++++++
 .../execution/datasources/WriteOutput.scala   | 29 +++++++++++++------
 5 files changed, 52 insertions(+), 21 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 996109865f..d980e6a15a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -528,7 +528,7 @@ case class DataSource(
             columns,
             bucketSpec,
             format,
-            () => Unit, // No existing table needs to be refreshed.
+            _ => Unit, // No existing table needs to be refreshed.
             options,
             data.logicalPlan,
             mode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index f0bcf94ead..34b77cab65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, Inte
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -179,24 +180,30 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }
 
+      def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
+        if (l.catalogTable.isDefined &&
+            l.catalogTable.get.partitionColumnNames.nonEmpty &&
+            l.catalogTable.get.partitionProviderIsHive) {
+          val metastoreUpdater = AlterTableAddPartitionCommand(
+            l.catalogTable.get.identifier,
+            updatedPartitions.map(p => (p, None)),
+            ifNotExists = true)
+          metastoreUpdater.run(t.sparkSession)
+        }
+        t.location.refresh()
+      }
+
       val insertCmd = InsertIntoHadoopFsRelationCommand(
         outputPath,
         query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver),
         t.bucketSpec,
         t.fileFormat,
-        () => t.location.refresh(),
+        refreshPartitionsCallback,
         t.options,
         query,
         mode)
 
-      if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
-          l.catalogTable.get.partitionProviderIsHive) {
-        // TODO(ekl) we should be more efficient here and only recover the newly added partitions
-        val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier)
-        Union(insertCmd, recoverPartitionCmd)
-      } else {
-        insertCmd
-      }
+      insertCmd
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 22dbe71495..a1221d0ae6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -40,7 +41,7 @@ case class InsertIntoHadoopFsRelationCommand(
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
-    refreshFunction: () => Unit,
+    refreshFunction: (Seq[TablePartitionSpec]) => Unit,
     options: Map[String, String],
     @transient query: LogicalPlan,
     mode: SaveMode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index f66e8b4e2b..b51b41869b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.util.Shell
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.types._
 
@@ -244,6 +245,17 @@ object PartitioningUtils {
     }
   }
 
+  /**
+   * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
+   * for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
+   */
+  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
+    pathFragment.split("/").map { kv =>
+      val pair = kv.split("=", 2)
+      (unescapePathName(pair(0)), unescapePathName(pair(1)))
+    }.toMap
+  }
+
   /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
index bd56e511d0..0eb86fdd6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
+import scala.collection.mutable
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -30,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
@@ -85,7 +88,7 @@ object WriteOutput extends Logging {
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
-      refreshFunction: () => Unit,
+      refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String],
       isAppend: Boolean): Unit = {
 
@@ -120,7 +123,7 @@ object WriteOutput extends Logging {
       val committer = setupDriverCommitter(job, outputPath.toString, isAppend)
 
       try {
-        sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
@@ -128,11 +131,11 @@ object WriteOutput extends Logging {
               sparkPartitionId = taskContext.partitionId(),
               sparkAttemptNumber = taskContext.attemptNumber(),
               iterator = iter)
-          })
+          }).flatten.distinct
 
         committer.commitJob(job)
         logInfo(s"Job ${job.getJobID} committed.")
-        refreshFunction()
+        refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment))
       } catch { case cause: Throwable =>
         logError(s"Aborting job ${job.getJobID}.", cause)
         committer.abortJob(job, JobStatus.State.FAILED)
@@ -147,7 +150,7 @@ object WriteOutput extends Logging {
       sparkStageId: Int,
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
-      iterator: Iterator[InternalRow]): Unit = {
+      iterator: Iterator[InternalRow]): Set[String] = {
 
     val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId)
     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -187,11 +190,12 @@ object WriteOutput extends Logging {
     try {
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Execute the task to write rows out
-        writeTask.execute(iterator)
+        val outputPaths = writeTask.execute(iterator)
         writeTask.releaseResources()
 
         // Commit the task
         SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId)
+        outputPaths
       })(catchBlock = {
         // If there is an error, release resource and then abort the task
         try {
@@ -213,7 +217,7 @@ object WriteOutput extends Logging {
    * automatically trigger task aborts.
    */
   private trait ExecuteWriteTask {
-    def execute(iterator: Iterator[InternalRow]): Unit
+    def execute(iterator: Iterator[InternalRow]): Set[String]
     def releaseResources(): Unit
 
     final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
@@ -240,11 +244,12 @@ object WriteOutput extends Logging {
       outputWriter
     }
 
-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       while (iter.hasNext) {
         val internalRow = iter.next()
         outputWriter.writeInternal(internalRow)
       }
+      Set.empty
     }
 
     override def releaseResources(): Unit = {
@@ -327,7 +332,7 @@ object WriteOutput extends Logging {
       newWriter
     }
 
-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
         description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -375,6 +380,7 @@ object WriteOutput extends Logging {
 
       // If anything below fails, we should abort the task.
       var currentKey: UnsafeRow = null
+      val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
@@ -386,6 +392,10 @@ object WriteOutput extends Logging {
           logDebug(s"Writing partition: $currentKey")
 
           currentWriter = newOutputWriter(currentKey, getPartitionString)
+          val partitionPath = getPartitionString(currentKey).getString(0)
+          if (partitionPath.nonEmpty) {
+            updatedPartitions.add(partitionPath)
+          }
         }
         currentWriter.writeInternal(sortedIterator.getValue)
       }
@@ -393,6 +403,7 @@ object WriteOutput extends Logging {
         currentWriter.close()
         currentWriter = null
       }
+      updatedPartitions.toSet
     }
 
     override def releaseResources(): Unit = {
-- 
GitLab