diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 69b8b059fde1c1a3c2c3a1e300e3f05980ef979e..ff3dcbc957ac191baf842baa32925661424329e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -369,6 +369,8 @@ case class InsertIntoTable( if (table.output.isEmpty) { None } else { + // Note: The parser (visitPartitionSpec in AstBuilder) already converts + // partition spec keys to their lowercase forms. val staticPartCols = partition.filter(_._2.isDefined).keySet Some(table.output.filterNot(a => staticPartCols.contains(a.name))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e6fc9749c72670d0be7fedc8a00678bb8d7ea0eb..ca3972d62dfb5767eccaa126538c49ebbba10e25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -245,29 +245,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { if (partitioningColumns.isDefined) { throw new AnalysisException( "insertInto() can't be used together with partitionBy(). " + - "Partition columns are defined by the table into which is being inserted." + "Partition columns have already been defined for the table. " + + "It is not necessary to use partitionBy()." ) } - val partitions = normalizedParCols.map(_.map(col => col -> Option.empty[String]).toMap) - val overwrite = mode == SaveMode.Overwrite - - // A partitioned relation's schema can be different from the input logicalPlan, since - // partition columns are all moved after data columns. We Project to adjust the ordering. - // TODO: this belongs to the analyzer. - val input = normalizedParCols.map { parCols => - val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => - parCols.contains(attr.name) - } - Project(inputDataCols ++ inputPartCols, df.logicalPlan) - }.getOrElse(df.logicalPlan) - df.sparkSession.sessionState.executePlan( InsertIntoTable( - UnresolvedRelation(tableIdent), - partitions.getOrElse(Map.empty[String, Option[String]]), - input, - overwrite, + table = UnresolvedRelation(tableIdent), + partition = Map.empty[String, Option[String]], + child = df.logicalPlan, + overwrite = mode == SaveMode.Overwrite, ifNotExists = false)).toRdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f274fc77daef08d841691745e64f3f665909033c..557445c2bc91f1d7961bae7cf8d0f709f51b8e97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -435,7 +435,7 @@ case class DataSource( // If we are appending to a table that already exists, make sure the partitioning matches // up. If we fail to load the table for whatever reason, ignore the check.
if (mode == SaveMode.Append) { - val existingColumns = Try { + val existingPartitionColumns = Try { resolveRelation() .asInstanceOf[HadoopFsRelation] .location .partitionSpec() .partitionColumns .fieldNames .toSeq }.getOrElse(Seq.empty[String]) + // TODO: Case sensitivity. val sameColumns = - existingColumns.map(_.toLowerCase) == partitionColumns.map(_.toLowerCase) - if (existingColumns.size > 0 && !sameColumns) { + existingPartitionColumns.map(_.toLowerCase()) == partitionColumns.map(_.toLowerCase()) + if (existingPartitionColumns.size > 0 && !sameColumns) { throw new AnalysisException( s"""Requested partitioning does not match existing partitioning. |Existing partitioning columns: - | ${existingColumns.mkString(", ")} + | ${existingPartitionColumns.mkString(", ")} |Requested partitioning columns: | ${partitionColumns.mkString(", ")} |""".stripMargin) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 5963c53a1b1cb6ccbcd3a75c62b161533615c1f8..10425af3e1f186e9bef083198eff35a15cfa15bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -67,7 +67,7 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -private[sql] object PreprocessTableInsertion extends Rule[LogicalPlan] { +private[sql] case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -84,7 +84,13 @@ private[sql] object PreprocessTableInsertion extends Rule[LogicalPlan] { if (insert.partition.nonEmpty) { // the query's partitioning must match the table's partitioning // this is set for queries like: insert into ... partition (one = "a", two = <expr>) - if (insert.partition.keySet != partColNames.toSet) { + val samePartitionColumns = + if (conf.caseSensitiveAnalysis) { + insert.partition.keySet == partColNames.toSet + } else { + insert.partition.keySet.map(_.toLowerCase) == partColNames.map(_.toLowerCase).toSet + } + if (!samePartitionColumns) { throw new AnalysisException( s""" |Requested partitioning does not match the table $tblName: @@ -94,7 +100,8 @@ private[sql] object PreprocessTableInsertion extends Rule[LogicalPlan] { } expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) } else { - // All partition columns are dynamic because this InsertIntoTable had no partitioning + // All partition columns are dynamic because the InsertIntoTable command does + // not explicitly specify partitioning columns.
expectedColumns.map(castAndRenameChildOutput(insert, _)).getOrElse(insert) .copy(partition = partColNames.map(_ -> None).toMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index b033e19ddf06c692e5f41a775dd8b761fdcf1135..5300cfa8a7b06ce963dd7a7340a841346d214dc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -111,7 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreprocessTableInsertion :: + PreprocessTableInsertion(conf) :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 8827649d0a0d9fbf6ff9f3a87f4d874d47039c94..f40ddcc95affd3c6812aee1294dda51f2e361ac4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1337,8 +1337,24 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(sql("select * from partitionedTable").collect().size == 1) // Inserts new data successfully when partition columns are correctly specified in // partitionBy(...). - df.write.mode("append").partitionBy("a", "b").saveAsTable("partitionedTable") - assert(sql("select * from partitionedTable").collect().size == 2) + // TODO: Right now, partition columns are always treated in a case-insensitive way. + // See the write method in DataSource.scala. 
+ Seq((4, 5, 6)).toDF("a", "B", "c") + .write + .mode("append") + .partitionBy("a", "B") + .saveAsTable("partitionedTable") + + Seq((7, 8, 9)).toDF("a", "b", "c") + .write + .mode("append") + .partitionBy("a", "b") + .saveAsTable("partitionedTable") + + checkAnswer( + sql("select a, b, c from partitionedTable"), + Row(1, 2, 3) :: Row(4, 5, 6) :: Row(7, 8, 9) :: Nil + ) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 2d286715b57b67deaca5e100c4c07871f38424a9..f6675f09041d3abe34523166511651834b574911 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -65,7 +65,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) catalog.ParquetConversions :: catalog.OrcConversions :: catalog.CreateTables :: - PreprocessTableInsertion :: + PreprocessTableInsertion(conf) :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index c48735142dd0056e2434a7cd333e5e4706137663..d4ebd051d2eed91890586ed31f5708e57e6fd554 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -372,4 +372,24 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef assert(!logical.resolved, "Should not resolve: missing partition data") } } + + testPartitionedTable( + "SPARK-16036: better error message when insert into a table with mismatch schema") { + tableName => + val e = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $tableName PARTITION(b=1, c=2) SELECT 1, 2, 3") + } + assert(e.message.contains("the number of columns are different")) + } + + testPartitionedTable( + "SPARK-16037: INSERT statement should match columns by position") { + tableName => + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + sql(s"INSERT INTO TABLE $tableName SELECT 1, 2 AS c, 3 AS b") + checkAnswer(sql(s"SELECT a, b, c FROM $tableName"), Row(1, 2, 3)) + sql(s"INSERT OVERWRITE TABLE $tableName SELECT 1, 2, 3") + checkAnswer(sql(s"SELECT a, b, c FROM $tableName"), Row(1, 2, 3)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index a16b5b2e23c3d74faf33a6625c26d028a3f7f0a3..85b159e2a5dedd731295ca8599b7f9aa46a43150 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1033,6 +1033,41 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SELECT * FROM boom").queryExecution.analyzed } + test("SPARK-3810: PreprocessTableInsertion static partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds='1', hr='2') SELECT key, value FROM src") + .queryExecution.analyzed + } + + assertResult(1, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => 
() + }.size + } + } + + test("SPARK-3810: PreprocessTableInsertion dynamic partitioning support") { + val analyzedPlan = { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value, '1', '2' FROM src") + .queryExecution.analyzed + } + + assertResult(2, "Duplicated project detected\n" + analyzedPlan) { + analyzedPlan.collect { + case _: Project => () + }.size + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 46a77dd917fb360782c4abbe1489ffbf1ed52a6a..9c1f21825315b3db1589eea533424764f2f44b2b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1684,36 +1684,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ) } } - - test("SPARK-16036: better error message when insert into a table with mismatch schema") { - withTable("hive_table", "datasource_table") { - sql("CREATE TABLE hive_table(a INT) PARTITIONED BY (b INT, c INT)") - sql("CREATE TABLE datasource_table(a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)") - val e1 = intercept[AnalysisException] { - sql("INSERT INTO TABLE hive_table PARTITION(b=1, c=2) SELECT 1, 2, 3") - } - assert(e1.message.contains("the number of columns are different")) - val e2 = intercept[AnalysisException] { - sql("INSERT INTO TABLE datasource_table PARTITION(b=1, c=2) SELECT 1, 2, 3") - } - assert(e2.message.contains("the number of columns are different")) - } - } - - test("SPARK-16037: INSERT statement should match columns by position") { - withTable("hive_table", "datasource_table") { - sql("CREATE TABLE hive_table(a INT) PARTITIONED BY (b INT, c INT)") - sql("CREATE TABLE datasource_table(a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)") - - withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { - sql("INSERT INTO TABLE hive_table SELECT 1, 2 AS c, 3 AS b") - checkAnswer(sql("SELECT a, b, c FROM hive_table"), Row(1, 2, 3)) - sql("INSERT OVERWRITE TABLE hive_table SELECT 1, 2, 3") - checkAnswer(sql("SELECT a, b, c FROM hive_table"), Row(1, 2, 3)) - } - - sql("INSERT INTO TABLE datasource_table SELECT 1, 2 AS c, 3 AS b") - checkAnswer(sql("SELECT a, b, c FROM datasource_table"), Row(1, 2, 3)) - } - } }
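Reviewer note (not part of the patch): the standalone sketch below illustrates the partition-spec comparison that PreprocessTableInsertion(conf) now performs for static INSERT ... PARTITION specs. The object and method names (PartitionSpecCheck, matches) are hypothetical and exist only for illustration; `caseSensitive` stands in for conf.caseSensitiveAnalysis. The separate append check in DataSource.scala remains always case-insensitive for now, per the TODO added above.

// Hypothetical, illustrative sketch; only the comparison mirrors the rule's new logic.
object PartitionSpecCheck {
  def matches(
      specKeys: Set[String],       // keys from INSERT ... PARTITION (k1 = v1, ...)
      tablePartCols: Seq[String],  // partition columns declared by the target table
      caseSensitive: Boolean): Boolean = {
    if (caseSensitive) {
      specKeys == tablePartCols.toSet
    } else {
      // Mirrors: insert.partition.keySet.map(_.toLowerCase) == partColNames.map(_.toLowerCase).toSet
      specKeys.map(_.toLowerCase) == tablePartCols.map(_.toLowerCase).toSet
    }
  }

  def main(args: Array[String]): Unit = {
    // Under the default case-insensitive analysis, PARTITION(B = ...) matches a table
    // partitioned by `b`; under case-sensitive analysis it does not.
    assert(matches(Set("a", "B"), Seq("a", "b"), caseSensitive = false))
    assert(!matches(Set("a", "B"), Seq("a", "b"), caseSensitive = true))
    println("partition-spec checks behave as expected")
  }
}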