diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a5755616329ab3a1b4fd17ddac6c3a2a7bf7e3fb..96f2e38946f1cad412f913fb6f127cbbba0ec050 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -452,42 +452,7 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
-        val table = lookupTableFromCatalog(u)
-        // adding the table's partitions or validate the query's partition info
-        table match {
-          case relation: CatalogRelation if relation.catalogTable.partitionColumns.nonEmpty =>
-            val tablePartitionNames = relation.catalogTable.partitionColumns.map(_.name)
-            if (parts.keys.nonEmpty) {
-              // the query's partitioning must match the table's partitioning
-              // this is set for queries like: insert into ... partition (one = "a", two = <expr>)
-              // TODO: add better checking to pre-inserts to avoid needing this here
-              if (tablePartitionNames.size != parts.keySet.size) {
-                throw new AnalysisException(
-                  s"""Requested partitioning does not match the ${u.tableIdentifier} table:
-                     |Requested partitions: ${parts.keys.mkString(",")}
-                     |Table partitions: ${tablePartitionNames.mkString(",")}""".stripMargin)
-              }
-              // Assume partition columns are correctly placed at the end of the child's output
-              i.copy(table = EliminateSubqueryAliases(table))
-            } else {
-              // Set up the table's partition scheme with all dynamic partitions by moving partition
-              // columns to the end of the column list, in partition order.
-              val (inputPartCols, columns) = child.output.partition { attr =>
-                tablePartitionNames.contains(attr.name)
-              }
-              // All partition columns are dynamic because this InsertIntoTable had no partitioning
-              val partColumns = tablePartitionNames.map { name =>
-                inputPartCols.find(_.name == name).getOrElse(
-                  throw new AnalysisException(s"Cannot find partition column $name"))
-              }
-              i.copy(
-                table = EliminateSubqueryAliases(table),
-                partition = tablePartitionNames.map(_ -> None).toMap,
-                child = Project(columns ++ partColumns, child))
-            }
-          case _ =>
-            i.copy(table = EliminateSubqueryAliases(table))
-        }
+        i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
       case u: UnresolvedRelation =>
         val table = u.tableIdentifier
         if (table.database.isDefined && conf.runSQLonFile &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 6c3eb3a5a28ab3541235633055d41a2129aba50d..69b8b059fde1c1a3c2c3a1e300e3f05980ef979e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -369,10 +369,8 @@ case class InsertIntoTable(
     if (table.output.isEmpty) {
       None
     } else {
-      val numDynamicPartitions = partition.values.count(_.isEmpty)
-      val (partitionColumns, dataColumns) = table.output
-        .partition(a => partition.keySet.contains(a.name))
-      Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions))
+      val staticPartCols = partition.filter(_._2.isDefined).keySet
+      Some(table.output.filterNot(a => staticPartCols.contains(a.name)))
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 543389efd5b956ddbc44b1a326289b7fdf5f21b2..5963c53a1b1cb6ccbcd3a75c62b161533615c1f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.SessionCatalog
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -62,53 +62,79 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo
 }
 
 /**
- * A rule to do pre-insert data type casting and field renaming. Before we insert into
- * an [[InsertableRelation]], we will use this rule to make sure that
- * the columns to be inserted have the correct data type and fields have the correct names.
+ * Preprocess the [[InsertIntoTable]] plan. Throws exception if the number of columns mismatch, or
+ * specified partition columns are different from the existing partition columns in the target
+ * table. It also does data type casting and field renaming, to make sure that the columns to be
+ * inserted have the correct data type and fields have the correct names.
  */
-private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Wait until children are resolved.
-    case p: LogicalPlan if !p.childrenResolved => p
-
-    // We are inserting into an InsertableRelation or HadoopFsRelation.
-    case i @ InsertIntoTable(
-      l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) =>
-      // First, make sure the data to be inserted have the same number of fields with the
-      // schema of the relation.
-      if (l.output.size != child.output.size) {
-        sys.error(
-          s"$l requires that the data to be inserted have the same number of columns as the " +
-            s"target table: target table has ${l.output.size} column(s) but " +
-            s"the inserted data has ${child.output.size} column(s).")
-      }
-      castAndRenameChildOutput(i, l.output, child)
+private[sql] object PreprocessTableInsertion extends Rule[LogicalPlan] {
+  private def preprocess(
+      insert: InsertIntoTable,
+      tblName: String,
+      partColNames: Seq[String]): InsertIntoTable = {
+
+    val expectedColumns = insert.expectedColumns
+    if (expectedColumns.isDefined && expectedColumns.get.length != insert.child.schema.length) {
+      throw new AnalysisException(
+        s"Cannot insert into table $tblName because the number of columns are different: " +
+          s"need ${expectedColumns.get.length} columns, " +
+          s"but query has ${insert.child.schema.length} columns.")
+    }
+
+    if (insert.partition.nonEmpty) {
+      // the query's partitioning must match the table's partitioning
+      // this is set for queries like: insert into ... partition (one = "a", two = <expr>)
+      if (insert.partition.keySet != partColNames.toSet) {
+        throw new AnalysisException(
+          s"""
+             |Requested partitioning does not match the table $tblName:
+             |Requested partitions: ${insert.partition.keys.mkString(",")}
+             |Table partitions: ${partColNames.mkString(",")}
+           """.stripMargin)
+      }
+      expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert)
+    } else {
+      // All partition columns are dynamic because this InsertIntoTable had no partitioning
+      expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert)
+        .copy(partition = partColNames.map(_ -> None).toMap)
+    }
   }
 
-  /** If necessary, cast data types and rename fields to the expected types and names. */
+  // TODO: do we really need to rename?
   def castAndRenameChildOutput(
-      insertInto: InsertIntoTable,
-      expectedOutput: Seq[Attribute],
-      child: LogicalPlan): InsertIntoTable = {
-    val newChildOutput = expectedOutput.zip(child.output).map {
+      insert: InsertIntoTable,
+      expectedOutput: Seq[Attribute]): InsertIntoTable = {
+    val newChildOutput = expectedOutput.zip(insert.child.output).map {
       case (expected, actual) =>
-        val needCast = !expected.dataType.sameType(actual.dataType)
-        // We want to make sure the filed names in the data to be inserted exactly match
-        // names in the schema.
-        val needRename = expected.name != actual.name
-        (needCast, needRename) match {
-          case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)()
-          case (false, true) => Alias(actual, expected.name)()
-          case (_, _) => actual
+        if (expected.dataType.sameType(actual.dataType) && expected.name == actual.name) {
+          actual
+        } else {
+          Alias(Cast(actual, expected.dataType), expected.name)()
         }
     }
 
-    if (newChildOutput == child.output) {
-      insertInto
+    if (newChildOutput == insert.child.output) {
+      insert
     } else {
-      insertInto.copy(child = Project(newChildOutput, child))
+      insert.copy(child = Project(newChildOutput, insert.child))
     }
   }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case i @ InsertIntoTable(table, partition, child, _, _) if table.resolved && child.resolved =>
+      table match {
+        case relation: CatalogRelation =>
+          val metadata = relation.catalogTable
+          preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames)
+        case LogicalRelation(h: HadoopFsRelation, _, identifier) =>
+          val tblName = identifier.map(_.quotedString).getOrElse("unknown")
+          preprocess(i, tblName, h.partitionSchema.map(_.name))
+        case LogicalRelation(_: InsertableRelation, _, identifier) =>
+          val tblName = identifier.map(_.quotedString).getOrElse("unknown")
+          preprocess(i, tblName, Nil)
+        case other => i
+      }
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index dc95123d0088b7d55dee498b9b675c9a44f690d3..b033e19ddf06c692e5f41a775dd8b761fdcf1135 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.AnalyzeTableCommand
-import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource}
+import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreprocessTableInsertion, ResolveDataSource}
 import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager}
 import org.apache.spark.sql.util.ExecutionListenerManager
@@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) {
   lazy val analyzer: Analyzer = {
     new Analyzer(catalog, conf) {
       override val extendedResolutionRules =
-        PreInsertCastAndRename ::
+        PreprocessTableInsertion ::
         new FindDataSourceTable(sparkSession) ::
         DataSourceAnalysis ::
         (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index d7179551d62dfe4110c06d491ee53ffd759d55fa..6454d716ec0dbc0a7a454f7101195d7b61dde7db 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -88,15 +88,13 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
   }
 
   test("SELECT clause generating a different number of columns is not allowed.") {
-    val message = intercept[RuntimeException] {
+    val message = intercept[AnalysisException] {
       sql(
         s"""
         |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt
""".stripMargin) }.getMessage - assert( - message.contains("requires that the data to be inserted have the same number of columns"), - "SELECT clause generating a different number of columns should not be not allowed." + assert(message.contains("the number of columns are different") ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 224ff3823b59427039807cd4443598e186ddc922..2e0b5d59b5783180e76d1b4d69671985d5e4766d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -457,49 +457,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log allowExisting) } } - - /** - * Casts input data to correct data types according to table definition before inserting into - * that table. - */ - object PreInsertionCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => - castChildOutput(p, table, child) - } - - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) - : LogicalPlan = { - val childOutputDataTypes = child.output.map(_.dataType) - val numDynamicPartitions = p.partition.values.count(_.isEmpty) - val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) - .take(child.output.length).map(_.dataType) - - if (childOutputDataTypes == tableOutputDataTypes) { - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else if (childOutputDataTypes.size == tableOutputDataTypes.size && - childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => left.sameType(right) }) { - // If both types ignoring nullability of ArrayType, MapType, StructType are the same, - // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else { - // Only do the casting when child output data types differ from table output data types. 
-        val castedChildOutput = child.output.zip(table.output).map {
-          case (input, output) if input.dataType != output.dataType =>
-            Alias(Cast(input, output.dataType), input.name)()
-          case (input, _) => input
-        }
-
-        p.copy(child = logical.Project(castedChildOutput, child))
-      }
-    }
-  }
-
 }
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 4f8aac8c2fcdd62e855257447057b572e6850409..2f6a2207855ecdd9a66cef0174bee2dc2396e322 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -87,7 +87,6 @@ private[sql] class HiveSessionCatalog(
   val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions
   val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions
   val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables
-  val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts
 
   override def refreshTable(name: TableIdentifier): Unit = {
     metastoreCatalog.refreshTable(name)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index ca8e5f822396840b9d0bda27321f1e02ddd70c35..2d286715b57b67deaca5e100c4c07871f38424a9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -65,8 +65,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
         catalog.ParquetConversions ::
         catalog.OrcConversions ::
         catalog.CreateTables ::
-        catalog.PreInsertionCasts ::
-        PreInsertCastAndRename ::
+        PreprocessTableInsertion ::
         DataSourceAnalysis ::
         (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index b890b4bffdcfe51b1c9748d17e2de3e1325cf404..c48735142dd0056e2434a7cd333e5e4706137663 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -325,27 +325,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
     }
   }
 
-  test("Detect table partitioning with correct partition order") {
-    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
-      sql("CREATE TABLE source (id bigint, part2 string, part1 string, data string)")
-      val data = (1 to 10).map(i => (i, if ((i % 2) == 0) "even" else "odd", "p", s"data-$i"))
-        .toDF("id", "part2", "part1", "data")
-
-      data.write.insertInto("source")
-      checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
-
-      // the original data with part1 and part2 at the end
-      val expected = data.select("id", "data", "part1", "part2")
-
-      sql(
-        """CREATE TABLE partitioned (id bigint, data string)
-          |PARTITIONED BY (part1 string, part2 string)""".stripMargin)
-      spark.table("source").write.insertInto("partitioned")
-
-      checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq)
-    }
-  }
-
   private def testPartitionedHiveSerDeTable(testName: String)(f: String => Unit): Unit = {
     test(s"Hive SerDe table - $testName") {
       val hiveTable = "hive_table"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index a846711b84ec93fe84e33719ceb98a7f1d0c2696..f5d2f02d512be5368cd3affc29b9f44789dd9b8d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -348,6 +348,7 @@ abstract class HiveComparisonTest
           queryString.replace("../../data", testDataPath))
         val containsCommands = originalQuery.analyzed.collectFirst {
           case _: Command => ()
+          case _: InsertIntoTable => ()
           case _: LogicalInsertIntoHiveTable => ()
         }.nonEmpty
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index e0f6ccf04dd3355b567a8105409019d85031efc5..a16b5b2e23c3d74faf33a6625c26d028a3f7f0a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -1033,41 +1033,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     sql("SELECT * FROM boom").queryExecution.analyzed
   }
 
-  test("SPARK-3810: PreInsertionCasts static partitioning support") {
-    val analyzedPlan = {
-      loadTestTable("srcpart")
-      sql("DROP TABLE IF EXISTS withparts")
-      sql("CREATE TABLE withparts LIKE srcpart")
-      sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src")
-        .queryExecution.analyzed
-    }
-
-    assertResult(1, "Duplicated project detected\n" + analyzedPlan) {
-      analyzedPlan.collect {
-        case _: Project => ()
-      }.size
-    }
-  }
-
-  test("SPARK-3810: PreInsertionCasts dynamic partitioning support") {
-    val analyzedPlan = {
-      loadTestTable("srcpart")
-      sql("DROP TABLE IF EXISTS withparts")
-      sql("CREATE TABLE withparts LIKE srcpart")
-      sql("SET hive.exec.dynamic.partition.mode=nonstrict")
-
-      sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart")
-      sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src")
-        .queryExecution.analyzed
-    }
-
-    assertResult(1, "Duplicated project detected\n" + analyzedPlan) {
-      analyzedPlan.collect {
-        case _: Project => ()
-      }.size
-    }
-  }
-
   test("parse HQL set commands") {
     // Adapted from its SQL counterpart.
    val testKey = "spark.sql.key.usedfortestonly"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 9c1f21825315b3db1589eea533424764f2f44b2b..46a77dd917fb360782c4abbe1489ffbf1ed52a6a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1684,4 +1684,36 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       )
     }
   }
+
+  test("SPARK-16036: better error message when insert into a table with mismatch schema") {
+    withTable("hive_table", "datasource_table") {
+      sql("CREATE TABLE hive_table(a INT) PARTITIONED BY (b INT, c INT)")
+      sql("CREATE TABLE datasource_table(a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)")
+      val e1 = intercept[AnalysisException] {
+        sql("INSERT INTO TABLE hive_table PARTITION(b=1, c=2) SELECT 1, 2, 3")
+      }
+      assert(e1.message.contains("the number of columns are different"))
+      val e2 = intercept[AnalysisException] {
+        sql("INSERT INTO TABLE datasource_table PARTITION(b=1, c=2) SELECT 1, 2, 3")
+      }
+      assert(e2.message.contains("the number of columns are different"))
+    }
+  }
+
+  test("SPARK-16037: INSERT statement should match columns by position") {
+    withTable("hive_table", "datasource_table") {
+      sql("CREATE TABLE hive_table(a INT) PARTITIONED BY (b INT, c INT)")
+      sql("CREATE TABLE datasource_table(a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)")
+
+      withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+        sql("INSERT INTO TABLE hive_table SELECT 1, 2 AS c, 3 AS b")
+        checkAnswer(sql("SELECT a, b, c FROM hive_table"), Row(1, 2, 3))
+        sql("INSERT OVERWRITE TABLE hive_table SELECT 1, 2, 3")
+        checkAnswer(sql("SELECT a, b, c FROM hive_table"), Row(1, 2, 3))
+      }
+
+      sql("INSERT INTO TABLE datasource_table SELECT 1, 2 AS c, 3 AS b")
+      checkAnswer(sql("SELECT a, b, c FROM datasource_table"), Row(1, 2, 3))
+    }
+  }
 }
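Note: for readers of this patch, the snippet below is a minimal usage sketch of the by-position matching that PreprocessTableInsertion enforces. It mirrors the SPARK-16037 test above; the SparkSession setup and the table name `t` are illustrative and not part of the change itself.

import org.apache.spark.sql.SparkSession

// Assumes a Hive-enabled session; in the test suites above this is provided by TestHiveSingleton.
val spark = SparkSession.builder()
  .appName("insert-by-position-sketch")
  .enableHiveSupport()
  .getOrCreate()

spark.sql("CREATE TABLE IF NOT EXISTS t(a INT) PARTITIONED BY (b INT, c INT)")
spark.sql("SET hive.exec.dynamic.partition.mode=nonstrict")

// Columns are matched by position, not by alias: `2 AS c` still lands in partition column b,
// so the inserted row is a = 1, b = 2, c = 3.
spark.sql("INSERT INTO TABLE t SELECT 1, 2 AS c, 3 AS b")
spark.sql("SELECT a, b, c FROM t").show()  // expected row: 1, 2, 3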