diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 996109865fdc7dd233756d77661c5705ebc540c9..d980e6a15aabe41b89d673c1bbb2d49753a6e1d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -528,7 +528,7 @@ case class DataSource(
             columns,
             bucketSpec,
             format,
-            () => Unit, // No existing table needs to be refreshed.
+            _ => (), // No existing table needs to be refreshed.
             options,
             data.logicalPlan,
             mode)
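
In `DataSource.scala`, the placeholder callback has to match the new one-argument signature even though it still does nothing. A minimal standalone sketch of such a no-op, assuming the `TablePartitionSpec` alias (`Map[String, String]`) from `CatalogTypes`; the value name is illustrative:

```scala
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec

// Accepts the partition specs written by the job and ignores them,
// matching the new (Seq[TablePartitionSpec]) => Unit callback type.
val noRefreshNeeded: Seq[TablePartitionSpec] => Unit = _ => ()
```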
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index f0bcf94eadc960dfdb6c78f47faafcad02c2a5c6..34b77cab65def7de11ed2047f34c563ed8b2be59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, Inte
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -179,24 +180,30 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }
 
+      def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
+        if (l.catalogTable.isDefined &&
+            l.catalogTable.get.partitionColumnNames.nonEmpty &&
+            l.catalogTable.get.partitionProviderIsHive) {
+          val metastoreUpdater = AlterTableAddPartitionCommand(
+            l.catalogTable.get.identifier,
+            updatedPartitions.map(p => (p, None)),
+            ifNotExists = true)
+          metastoreUpdater.run(t.sparkSession)
+        }
+        t.location.refresh()
+      }
+
       val insertCmd = InsertIntoHadoopFsRelationCommand(
         outputPath,
         query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver),
         t.bucketSpec,
         t.fileFormat,
-        () => t.location.refresh(),
+        refreshPartitionsCallback,
         t.options,
         query,
         mode)
 
-      if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
-          l.catalogTable.get.partitionProviderIsHive) {
-        // TODO(ekl) we should be more efficient here and only recover the newly added partitions
-        val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier)
-        Union(insertCmd, recoverPartitionCmd)
-      } else {
-        insertCmd
-      }
+      insertCmd
   }
 }
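
Rather than chasing every insert with an `AlterTableRecoverPartitionsCommand`, which rescans the entire table directory (the removed TODO above), the callback now registers only the partitions the job actually wrote. A hedged sketch of that metastore update in isolation; the table identifier and partition values are hypothetical:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.command.AlterTableAddPartitionCommand

// Add just the partitions that were written. A location of None lets the
// metastore derive each partition's path from the table root, and
// ifNotExists = true keeps the command idempotent across re-runs.
val addWrittenPartitions = AlterTableAddPartitionCommand(
  TableIdentifier("events"),
  Seq((Map("year" -> "2016", "month" -> "10"), None)),
  ifNotExists = true)
```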
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 22dbe7149531c3d0c8f6c63c1919f2a66f4b3bf3..a1221d0ae6d27ad067b97b00f923ad4ae2dcb990 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -40,7 +41,7 @@ case class InsertIntoHadoopFsRelationCommand(
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
-    refreshFunction: () => Unit,
+    refreshFunction: (Seq[TablePartitionSpec]) => Unit,
     options: Map[String, String],
     @transient query: LogicalPlan,
     mode: SaveMode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index f66e8b4e2b5515c257b38f38e3d30d859559bbef..b51b41869bf06241480c448a961561bf80dd94c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.util.Shell
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.types._
 
@@ -244,6 +245,17 @@ object PartitioningUtils {
     }
   }
 
+  /**
+   * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, unescapes the keys and values
+   * and returns a parsed spec for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
+   */
+  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
+    pathFragment.split("/").map { kv =>
+      val pair = kv.split("=", 2)
+      (unescapePathName(pair(0)), unescapePathName(pair(1)))
+    }.toMap
+  }
+
   /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
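
A usage sketch of the new helper on its own doc-comment example (callable from Spark's `sql` internals, where `PartitioningUtils` lives):

```scala
import org.apache.spark.sql.execution.datasources.PartitioningUtils

// Parses one escaped path fragment back into a partition spec.
val spec = PartitioningUtils.parsePathFragment("fieldOne=1/fieldTwo=2")
// spec: Map("fieldOne" -> "1", "fieldTwo" -> "2")
```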
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
index bd56e511d0ccf9440e9a17f2829a0254dbeb051c..0eb86fdd6caa8a29567836983c6588e0357689a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources
 
 import java.util.{Date, UUID}
 
+import scala.collection.mutable
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -30,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
@@ -85,7 +88,7 @@ object WriteOutput extends Logging {
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
-      refreshFunction: () => Unit,
+      refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String],
       isAppend: Boolean): Unit = {
 
@@ -120,7 +123,7 @@ object WriteOutput extends Logging {
       val committer = setupDriverCommitter(job, outputPath.toString, isAppend)
 
       try {
-        sparkSession.sparkContext.runJob(queryExecution.toRdd,
+        val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd,
           (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
             executeTask(
               description = description,
@@ -128,11 +131,11 @@ object WriteOutput extends Logging {
               sparkPartitionId = taskContext.partitionId(),
               sparkAttemptNumber = taskContext.attemptNumber(),
               iterator = iter)
-          })
+          }).flatten.distinct
 
         committer.commitJob(job)
         logInfo(s"Job ${job.getJobID} committed.")
-        refreshFunction()
+        refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment))
       } catch { case cause: Throwable =>
         logError(s"Aborting job ${job.getJobID}.", cause)
         committer.abortJob(job, JobStatus.State.FAILED)
@@ -147,7 +150,7 @@ object WriteOutput extends Logging {
       sparkStageId: Int,
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
-      iterator: Iterator[InternalRow]): Unit = {
+      iterator: Iterator[InternalRow]): Set[String] = {
 
     val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId)
     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -187,11 +190,12 @@ object WriteOutput extends Logging {
     try {
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Execute the task to write rows out
-        writeTask.execute(iterator)
+        val outputPaths = writeTask.execute(iterator)
         writeTask.releaseResources()
 
         // Commit the task
         SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId)
+        outputPaths
       })(catchBlock = {
         // If there is an error, release resource and then abort the task
         try {
@@ -213,7 +217,7 @@ object WriteOutput extends Logging {
    * automatically trigger task aborts.
    */
   private trait ExecuteWriteTask {
-    def execute(iterator: Iterator[InternalRow]): Unit
+    def execute(iterator: Iterator[InternalRow]): Set[String]
     def releaseResources(): Unit
 
     final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
@@ -240,11 +244,12 @@ object WriteOutput extends Logging {
       outputWriter
     }
 
-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       while (iter.hasNext) {
         val internalRow = iter.next()
         outputWriter.writeInternal(internalRow)
       }
+      Set.empty
     }
 
     override def releaseResources(): Unit = {
@@ -327,7 +332,7 @@ object WriteOutput extends Logging {
       newWriter
     }
 
-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
         description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -375,6 +380,7 @@ object WriteOutput extends Logging {
 
       // If anything below fails, we should abort the task.
       var currentKey: UnsafeRow = null
+      val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
@@ -386,6 +392,10 @@ object WriteOutput extends Logging {
           logDebug(s"Writing partition: $currentKey")
 
           currentWriter = newOutputWriter(currentKey, getPartitionString)
+          val partitionPath = getPartitionString(currentKey).getString(0)
+          if (partitionPath.nonEmpty) {
+            updatedPartitions.add(partitionPath)
+          }
         }
         currentWriter.writeInternal(sortedIterator.getValue)
       }
@@ -393,6 +403,7 @@ object WriteOutput extends Logging {
         currentWriter.close()
         currentWriter = null
       }
+      updatedPartitions.toSet
     }
 
     override def releaseResources(): Unit = {
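
End to end, each write task now reports the partition path fragments it created, and the driver flattens, deduplicates, and parses them before invoking the refresh callback. A minimal driver-side sketch of that aggregation with hypothetical per-task results:

```scala
// What two tasks might return from executeTask (hypothetical values).
val perTaskPaths: Array[Set[String]] = Array(
  Set("year=2016/month=10"),
  Set("year=2016/month=10", "year=2016/month=11"))

// The same flatten/distinct/parse pipeline used in WriteOutput above.
val updatedSpecs =
  perTaskPaths.flatten.distinct.map(PartitioningUtils.parsePathFragment)
// updatedSpecs: Array(Map("year" -> "2016", "month" -> "10"),
//                     Map("year" -> "2016", "month" -> "11"))
```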