Commit efc254a8 authored by Eric Liang, committed by Reynold Xin

[SPARK-18087][SQL] Optimize insert to not require REPAIR TABLE

## What changes were proposed in this pull request?

When inserting into datasource tables whose partitions are managed by the Hive metastore, we need to notify the metastore of newly added partitions. Previously this was implemented via `msck repair table`, which is more expensive than necessary because it rescans every partition of the table rather than only the partitions that were just written.

This patch optimizes the insertion path to register only the partitions that were actually updated.
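
For illustration, a hedged sketch of the user-visible effect (the table `t` and its partition column `p` below are invented for this example, not part of the patch): an append like the following used to trigger a full partition recovery of `t`, whereas with this change only the partitions the job actually wrote (here `p=2`) are registered with the metastore.

```scala
// Assumes a SparkSession `spark` with Hive support and an existing table
// `t(id BIGINT)` partitioned by `p`, with partitions tracked by the metastore.
spark.range(10)
  .selectExpr("id", "2 AS p")
  .write
  .mode("append")
  .insertInto("t")
```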
## How was this patch tested?

Existing tests (I verified manually that the tests fail if the partition-update step is omitted).

Author: Eric Liang <ekl@databricks.com>

Closes #15633 from ericl/spark-18087.
parent 6633b97b
@@ -528,7 +528,7 @@ case class DataSource(
         columns,
         bucketSpec,
         format,
-        () => Unit, // No existing table needs to be refreshed.
+        _ => Unit, // No existing table needs to be refreshed.
         options,
         data.logicalPlan,
         mode)
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, Inte
 import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SimpleCatalogRelation}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
@@ -34,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils, ExecutedCommandExec}
+import org.apache.spark.sql.execution.command.{AlterTableAddPartitionCommand, DDLUtils, ExecutedCommandExec}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -179,24 +180,30 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
           "Cannot overwrite a path that is also being read from.")
       }

+      def refreshPartitionsCallback(updatedPartitions: Seq[TablePartitionSpec]): Unit = {
+        if (l.catalogTable.isDefined &&
+            l.catalogTable.get.partitionColumnNames.nonEmpty &&
+            l.catalogTable.get.partitionProviderIsHive) {
+          val metastoreUpdater = AlterTableAddPartitionCommand(
+            l.catalogTable.get.identifier,
+            updatedPartitions.map(p => (p, None)),
+            ifNotExists = true)
+          metastoreUpdater.run(t.sparkSession)
+        }
+        t.location.refresh()
+      }
+
       val insertCmd = InsertIntoHadoopFsRelationCommand(
         outputPath,
         query.resolve(t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver),
         t.bucketSpec,
         t.fileFormat,
-        () => t.location.refresh(),
+        refreshPartitionsCallback,
         t.options,
         query,
         mode)

-      if (l.catalogTable.isDefined && l.catalogTable.get.partitionColumnNames.nonEmpty &&
-          l.catalogTable.get.partitionProviderIsHive) {
-        // TODO(ekl) we should be more efficient here and only recover the newly added partitions
-        val recoverPartitionCmd = AlterTableRecoverPartitionsCommand(l.catalogTable.get.identifier)
-        Union(insertCmd, recoverPartitionCmd)
-      } else {
-        insertCmd
-      }
+      insertCmd
     }
   }
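
A rough sketch of the data the new `refreshPartitionsCallback` hands to `AlterTableAddPartitionCommand` (the column names and values below are invented for illustration): each partition spec is paired with `None`, meaning no explicit location is supplied so the default partition path is used, and `ifNotExists = true` keeps repeated inserts into an already-registered partition from failing.

```scala
// Illustrative only: the shape of the callback's input and of the
// (spec, optional location) pairs passed to the metastore update.
val updatedPartitions: Seq[Map[String, String]] =
  Seq(Map("year" -> "2016", "month" -> "10"), Map("year" -> "2016", "month" -> "11"))

val partitionsToAdd: Seq[(Map[String, String], Option[String])] =
  updatedPartitions.map(p => (p, None))
// Seq((Map(year -> 2016, month -> 10), None), (Map(year -> 2016, month -> 11), None))
```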
@@ -23,6 +23,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -40,7 +41,7 @@ case class InsertIntoHadoopFsRelationCommand(
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
-    refreshFunction: () => Unit,
+    refreshFunction: (Seq[TablePartitionSpec]) => Unit,
     options: Map[String, String],
     @transient query: LogicalPlan,
     mode: SaveMode)
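
For reference, `TablePartitionSpec` is Spark's alias for a partition spec map (column name to value), so the command now accepts any `Seq[TablePartitionSpec] => Unit`. A minimal sketch of a conforming callback, with the alias inlined and a made-up partition value:

```scala
// Sketch only: TablePartitionSpec is an alias for Map[String, String] in Spark.
type TablePartitionSpec = Map[String, String]

val refreshFunction: Seq[TablePartitionSpec] => Unit = { updated =>
  updated.foreach { spec =>
    val fragment = spec.map { case (k, v) => s"$k=$v" }.mkString("/")
    println(s"partition written: $fragment")
  }
}

refreshFunction(Seq(Map("date" -> "2016-10-26")))
// prints: partition written: date=2016-10-26
```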
@@ -30,6 +30,7 @@ import org.apache.hadoop.util.Shell
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
 import org.apache.spark.sql.types._
@@ -244,6 +245,17 @@
     }
   }

+  /**
+   * Given a partition path fragment, e.g. `fieldOne=1/fieldTwo=2`, returns a parsed spec
+   * for that fragment, e.g. `Map(("fieldOne", "1"), ("fieldTwo", "2"))`.
+   */
+  def parsePathFragment(pathFragment: String): TablePartitionSpec = {
+    pathFragment.split("/").map { kv =>
+      val pair = kv.split("=", 2)
+      (unescapePathName(pair(0)), unescapePathName(pair(1)))
+    }.toMap
+  }
+
   /**
    * Normalize the column names in partition specification, w.r.t. the real partition column names
    * and case sensitivity. e.g., if the partition spec has a column named `monTh`, and there is a
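
`PartitioningUtils` is internal to Spark SQL, but the behavior of the new helper is easy to see with a standalone sketch (character unescaping is omitted here):

```scala
// A simplified stand-in for parsePathFragment: split a path fragment into a
// partition spec. The real helper also unescapes percent-encoded characters.
def parsePathFragmentSketch(pathFragment: String): Map[String, String] =
  pathFragment.split("/").map { kv =>
    val Array(key, value) = kv.split("=", 2)
    key -> value
  }.toMap

parsePathFragmentSketch("fieldOne=1/fieldTwo=2")
// Map("fieldOne" -> "1", "fieldTwo" -> "2")
```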
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources
 import java.util.{Date, UUID}

+import scala.collection.mutable
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
@@ -30,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
@@ -85,7 +88,7 @@ object WriteOutput extends Logging {
       hadoopConf: Configuration,
       partitionColumns: Seq[Attribute],
       bucketSpec: Option[BucketSpec],
-      refreshFunction: () => Unit,
+      refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String],
       isAppend: Boolean): Unit = {
@@ -120,7 +123,7 @@ object WriteOutput extends Logging {
     val committer = setupDriverCommitter(job, outputPath.toString, isAppend)

     try {
-      sparkSession.sparkContext.runJob(queryExecution.toRdd,
+      val updatedPartitions = sparkSession.sparkContext.runJob(queryExecution.toRdd,
        (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
          executeTask(
            description = description,
@@ -128,11 +131,11 @@
            sparkPartitionId = taskContext.partitionId(),
            sparkAttemptNumber = taskContext.attemptNumber(),
            iterator = iter)
-        })
+        }).flatten.distinct

       committer.commitJob(job)
       logInfo(s"Job ${job.getJobID} committed.")
-      refreshFunction()
+      refreshFunction(updatedPartitions.map(PartitioningUtils.parsePathFragment))
     } catch { case cause: Throwable =>
       logError(s"Aborting job ${job.getJobID}.", cause)
       committer.abortJob(job, JobStatus.State.FAILED)
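
To make the driver-side change concrete, a small sketch of how the per-task results collapse into the argument passed to `refreshFunction` (the partition paths below are made up): `runJob` now yields one `Set[String]` of partition path fragments per task, and `flatten.distinct` reduces them to the unique partitions touched by the whole job.

```scala
// Illustrative only: two tasks that each wrote some partitions.
val perTaskResults: Array[Set[String]] = Array(
  Set("date=2016-10-25", "date=2016-10-26"),
  Set("date=2016-10-26"))

val updatedPartitions: Array[String] = perTaskResults.flatten.distinct
// Array("date=2016-10-25", "date=2016-10-26")

// These fragments are then parsed into partition specs before the callback runs,
// e.g. "date=2016-10-25" becomes Map("date" -> "2016-10-25").
```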
@@ -147,7 +150,7 @@
       sparkStageId: Int,
       sparkPartitionId: Int,
       sparkAttemptNumber: Int,
-      iterator: Iterator[InternalRow]): Unit = {
+      iterator: Iterator[InternalRow]): Set[String] = {

     val jobId = SparkHadoopWriter.createJobID(new Date, sparkStageId)
     val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId)
@@ -187,11 +190,12 @@
     try {
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Execute the task to write rows out
-        writeTask.execute(iterator)
+        val outputPaths = writeTask.execute(iterator)
         writeTask.releaseResources()

         // Commit the task
         SparkHadoopMapRedUtil.commitTask(committer, taskAttemptContext, jobId.getId, taskId.getId)
+        outputPaths
       })(catchBlock = {
         // If there is an error, release resource and then abort the task
         try {
@@ -213,7 +217,7 @@
    * automatically trigger task aborts.
    */
   private trait ExecuteWriteTask {
-    def execute(iterator: Iterator[InternalRow]): Unit
+    def execute(iterator: Iterator[InternalRow]): Set[String]
     def releaseResources(): Unit

     final def filePrefix(split: Int, uuid: String, bucketId: Option[Int]): String = {
@@ -240,11 +244,12 @@
       outputWriter
     }

-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       while (iter.hasNext) {
         val internalRow = iter.next()
         outputWriter.writeInternal(internalRow)
       }
+      Set.empty
     }

     override def releaseResources(): Unit = {
@@ -327,7 +332,7 @@
       newWriter
     }

-    override def execute(iter: Iterator[InternalRow]): Unit = {
+    override def execute(iter: Iterator[InternalRow]): Set[String] = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
         description.partitionColumns ++ bucketIdExpression ++ sortColumns
@@ -375,6 +380,7 @@

       // If anything below fails, we should abort the task.
       var currentKey: UnsafeRow = null
+      val updatedPartitions = mutable.Set[String]()
       while (sortedIterator.next()) {
         val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
         if (currentKey != nextKey) {
@@ -386,6 +392,10 @@
           logDebug(s"Writing partition: $currentKey")
           currentWriter = newOutputWriter(currentKey, getPartitionString)
+          val partitionPath = getPartitionString(currentKey).getString(0)
+          if (partitionPath.nonEmpty) {
+            updatedPartitions.add(partitionPath)
+          }
         }
         currentWriter.writeInternal(sortedIterator.getValue)
       }
@@ -393,6 +403,7 @@
         currentWriter.close()
         currentWriter = null
       }
+      updatedPartitions.toSet
     }

     override def releaseResources(): Unit = {
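
Finally, a sketch of what the dynamic-partition write task now returns (paths invented for illustration): as the sorted rows stream by, each newly opened partition directory's `col=value` fragment is recorded once, and the task hands the distinct set back to the driver. The non-partitioned write task simply returns `Set.empty`.

```scala
import scala.collection.mutable

// Illustrative only: fragments observed while switching output writers.
val updatedPartitions = mutable.Set[String]()
Seq("date=2016-10-25", "date=2016-10-25", "date=2016-10-26")
  .foreach(updatedPartitions.add) // duplicates collapse in the Set

val taskResult: Set[String] = updatedPartitions.toSet
// Set("date=2016-10-25", "date=2016-10-26")
```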