diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 817c1ab688471641621dc6f676430418f88ccdd8..4331841fbffb45142181a469390bdb71e414a337 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.catalog
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.util.Shell
 
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 
 object ExternalCatalogUtils {
@@ -133,4 +135,39 @@ object CatalogUtils {
       case o => o
     }
   }
+
+  def normalizePartCols(
+      tableName: String,
+      tableCols: Seq[String],
+      partCols: Seq[String],
+      resolver: Resolver): Seq[String] = {
+    partCols.map(normalizeColumnName(tableName, tableCols, _, "partition", resolver))
+  }
+
+  def normalizeBucketSpec(
+      tableName: String,
+      tableCols: Seq[String],
+      bucketSpec: BucketSpec,
+      resolver: Resolver): BucketSpec = {
+    val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec
+    val normalizedBucketCols = bucketColumnNames.map { colName =>
+      normalizeColumnName(tableName, tableCols, colName, "bucket", resolver)
+    }
+    val normalizedSortCols = sortColumnNames.map { colName =>
+      normalizeColumnName(tableName, tableCols, colName, "sort", resolver)
+    }
+    BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols)
+  }
+
+  private def normalizeColumnName(
+      tableName: String,
+      tableCols: Seq[String],
+      colName: String,
+      colType: String,
+      resolver: Resolver): String = {
+    tableCols.find(resolver(_, colName)).getOrElse {
+      throw new AnalysisException(s"$colType column $colName is not defined in table $tableName, " +
+        s"defined table columns are: ${tableCols.mkString(", ")}")
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index d2a1af08009144bcd7fb2fbc0ccbd595dede8698..5b5378c09e540ad0186e1761abd6b340fd82b4ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -133,6 +133,16 @@ case class BucketSpec(
   if (numBuckets <= 0) {
     throw new AnalysisException(s"Expected positive number of buckets, but got `$numBuckets`.")
   }
+
+  override def toString: String = {
+    val bucketString = s"bucket columns: [${bucketColumnNames.mkString(", ")}]"
+    val sortString = if (sortColumnNames.nonEmpty) {
+      s", sort columns: [${sortColumnNames.mkString(", ")}]"
+    } else {
+      ""
+    }
+    s"$numBuckets buckets, $bucketString$sortString"
+  }
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
index 630adb0d994ec785e123677cafd527261019abfc..182d182faa2110861bd3d476aace7551be85f6a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -18,13 +18,11 @@
 package org.apache.spark.sql.execution.command
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.sources.BaseRelation
 
 /**
  * A command used to create a data source table.
@@ -143,8 +141,9 @@ case class CreateDataSourceTableAsSelectCommand(
     val tableName = tableIdentWithDB.unquotedString
 
     var createMetastoreTable = false
-    var existingSchema = Option.empty[StructType]
-    if (sparkSession.sessionState.catalog.tableExists(tableIdentWithDB)) {
+    // We may need to reorder the columns of the query to match the existing table.
+    var reorderedColumns = Option.empty[Seq[NamedExpression]]
+    if (sessionState.catalog.tableExists(tableIdentWithDB)) {
       // Check if we need to throw an exception or just return.
       mode match {
         case SaveMode.ErrorIfExists =>
@@ -157,39 +156,76 @@ case class CreateDataSourceTableAsSelectCommand(
           // Since the table already exists and the save mode is Ignore, we will just return.
           return Seq.empty[Row]
         case SaveMode.Append =>
+          val existingTable = sessionState.catalog.getTableMetadata(tableIdentWithDB)
+
+          if (existingTable.provider.get == DDLUtils.HIVE_PROVIDER) {
+            throw new AnalysisException(s"Saving data in the Hive serde table $tableName is " +
+              "not supported yet. Please use the insertInto() API as an alternative.")
+          }
+
           // Check if the specified data source match the data source of the existing table.
-          val existingProvider = DataSource.lookupDataSource(provider)
+          val existingProvider = DataSource.lookupDataSource(existingTable.provider.get)
+          val specifiedProvider = DataSource.lookupDataSource(table.provider.get)
           // TODO: Check that options from the resolved relation match the relation that we are
           // inserting into (i.e. using the same compression).
 
+          if (existingProvider != specifiedProvider) {
+            throw new AnalysisException(s"The format of the existing table $tableName is " +
+              s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " +
+              s"`${specifiedProvider.getSimpleName}`.")
+          }
-          // Pass a table identifier with database part, so that `lookupRelation` won't get temp
-          // views unexpectedly.
-          EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdentWithDB)) match {
-            case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) =>
-              // check if the file formats match
-              l.relation match {
-                case r: HadoopFsRelation if r.fileFormat.getClass != existingProvider =>
-                  throw new AnalysisException(
-                    s"The file format of the existing table $tableName is " +
-                      s"`${r.fileFormat.getClass.getName}`. It doesn't match the specified " +
-                      s"format `$provider`")
-                case _ =>
-              }
-              if (query.schema.size != l.schema.size) {
-                throw new AnalysisException(
-                  s"The column number of the existing schema[${l.schema}] " +
-                    s"doesn't match the data schema[${query.schema}]'s")
-              }
-              existingSchema = Some(l.schema)
-            case s: SimpleCatalogRelation if DDLUtils.isDatasourceTable(s.metadata) =>
-              existingSchema = Some(s.metadata.schema)
-            case c: CatalogRelation if c.catalogTable.provider == Some(DDLUtils.HIVE_PROVIDER) =>
-              throw new AnalysisException("Saving data in the Hive serde table " +
-                s"${c.catalogTable.identifier} is not supported yet. Please use the " +
-                "insertInto() API as an alternative..")
-            case o =>
-              throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
+          if (query.schema.length != existingTable.schema.length) {
+            throw new AnalysisException(
+              s"The column number of the existing table $tableName" +
+                s"(${existingTable.schema.catalogString}) doesn't match the data schema" +
+                s"(${query.schema.catalogString})")
           }
+
+          val resolver = sessionState.conf.resolver
+          val tableCols = existingTable.schema.map(_.name)
+
+          reorderedColumns = Some(existingTable.schema.map { f =>
+            query.resolve(Seq(f.name), resolver).getOrElse {
+              val inputColumns = query.schema.map(_.name).mkString(", ")
+              throw new AnalysisException(
+                s"cannot resolve '${f.name}' given input columns: [$inputColumns]")
+            }
+          })
+
+          // In `AnalyzeCreateTable`, we verified the consistency between the user-specified table
+          // definition (partition columns, bucketing) and the SELECT query, here we also need to
+          // verify the consistency between the user-specified table definition and the existing
+          // table definition.
+
+          // Check if the specified partition columns match the existing table.
+          val specifiedPartCols = CatalogUtils.normalizePartCols(
+            tableName, tableCols, table.partitionColumnNames, resolver)
+          if (specifiedPartCols != existingTable.partitionColumnNames) {
+            throw new AnalysisException(
+              s"""
+                |Specified partitioning does not match that of the existing table $tableName.
+                |Specified partition columns: [${specifiedPartCols.mkString(", ")}]
+                |Existing partition columns: [${existingTable.partitionColumnNames.mkString(", ")}]
+              """.stripMargin)
+          }
+
+          // Check if the specified bucketing matches the existing table.
+          val specifiedBucketSpec = table.bucketSpec.map { bucketSpec =>
+            CatalogUtils.normalizeBucketSpec(tableName, tableCols, bucketSpec, resolver)
+          }
+          if (specifiedBucketSpec != existingTable.bucketSpec) {
+            val specifiedBucketString =
+              specifiedBucketSpec.map(_.toString).getOrElse("not bucketed")
+            val existingBucketString =
+              existingTable.bucketSpec.map(_.toString).getOrElse("not bucketed")
+            throw new AnalysisException(
+              s"""
+                |Specified bucketing does not match that of the existing table $tableName.
+                |Specified bucketing: $specifiedBucketString
+                |Existing bucketing: $existingBucketString
+              """.stripMargin)
+          }
+
         case SaveMode.Overwrite =>
           sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false)
           // Need to create the table again.
@@ -201,9 +237,9 @@ case class CreateDataSourceTableAsSelectCommand(
     }
 
     val data = Dataset.ofRows(sparkSession, query)
-    val df = existingSchema match {
-      // If we are inserting into an existing table, just use the existing schema.
-      case Some(s) => data.selectExpr(s.fieldNames: _*)
+    val df = reorderedColumns match {
+      // Reorder the columns of the query to match the existing table.
+      case Some(cols) => data.select(cols.map(Column(_)): _*)
       case None => data
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 7154e3e41c93bf4dcc3a02bf24ef2a499f5949f7..2b2fbddd12e45fa446c8bb091ebc4c777a5563b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -17,14 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources
 
-import java.util.regex.Pattern
-
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
-import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogRelation, CatalogTable, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTable, CatalogUtils, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering}
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -122,9 +119,12 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
   }
 
   private def checkPartitionColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
-    val normalizedPartitionCols = tableDesc.partitionColumnNames.map { colName =>
-      normalizeColumnName(tableDesc.identifier, schema, colName, "partition")
-    }
+    val normalizedPartitionCols = CatalogUtils.normalizePartCols(
+      tableName = tableDesc.identifier.unquotedString,
+      tableCols = schema.map(_.name),
+      partCols = tableDesc.partitionColumnNames,
+      resolver = sparkSession.sessionState.conf.resolver)
+
     checkDuplication(normalizedPartitionCols, "partition")
 
     if (schema.nonEmpty && normalizedPartitionCols.length == schema.length) {
@@ -149,25 +149,21 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
 
   private def checkBucketColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
     tableDesc.bucketSpec match {
-      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
-        val normalizedBucketCols = bucketColumnNames.map { colName =>
-          normalizeColumnName(tableDesc.identifier, schema, colName, "bucket")
-        }
-        checkDuplication(normalizedBucketCols, "bucket")
-
-        val normalizedSortCols = sortColumnNames.map { colName =>
-          normalizeColumnName(tableDesc.identifier, schema, colName, "sort")
-        }
-        checkDuplication(normalizedSortCols, "sort")
-
-        schema.filter(f => normalizedSortCols.contains(f.name)).map(_.dataType).foreach {
+      case Some(bucketSpec) =>
+        val normalizedBucketing = CatalogUtils.normalizeBucketSpec(
+          tableName = tableDesc.identifier.unquotedString,
+          tableCols = schema.map(_.name),
+          bucketSpec = bucketSpec,
+          resolver = sparkSession.sessionState.conf.resolver)
+        checkDuplication(normalizedBucketing.bucketColumnNames, "bucket")
+        checkDuplication(normalizedBucketing.sortColumnNames, "sort")
+
+        normalizedBucketing.sortColumnNames.map(schema(_)).map(_.dataType).foreach {
           case dt if RowOrdering.isOrderable(dt) => // OK
           case other => failAnalysis(s"Cannot use ${other.simpleString} for sorting column")
         }
 
-        tableDesc.copy(
-          bucketSpec = Some(BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols))
-        )
+        tableDesc.copy(bucketSpec = Some(normalizedBucketing))
 
       case None => tableDesc
     }
@@ -182,19 +178,6 @@ case class AnalyzeCreateTable(sparkSession: SparkSession) extends Rule[LogicalPl
     }
   }
 
-  private def normalizeColumnName(
-      tableIdent: TableIdentifier,
-      schema: StructType,
-      colName: String,
-      colType: String): String = {
-    val tableCols = schema.map(_.name)
-    val resolver = sparkSession.sessionState.conf.resolver
-    tableCols.find(resolver(_, colName)).getOrElse {
-      failAnalysis(s"$colType column $colName is not defined in table $tableIdent, " +
-        s"defined table columns are: ${tableCols.mkString(", ")}")
-    }
-  }
-
   private def failAnalysis(msg: String) = throw new AnalysisException(msg)
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 6593fa479d66b45e25c732bb3b54203262f7bcbb..c0f583e5f7072397795f329fd2d1a95240f0bda2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -342,7 +342,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
     val e = intercept[AnalysisException] {
       sql("CREATE TABLE tbl(a int, b string) USING json PARTITIONED BY (c)")
     }
-    assert(e.message == "partition column c is not defined in table `tbl`, " +
+    assert(e.message == "partition column c is not defined in table tbl, " +
       "defined table columns are: a, b")
   }
 
@@ -350,7 +350,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
     val e = intercept[AnalysisException] {
      sql("CREATE TABLE tbl(a int, b string) USING json CLUSTERED BY (c) INTO 4 BUCKETS")
     }
-    assert(e.message == "bucket column c is not defined in table `tbl`, " +
+    assert(e.message == "bucket column c is not defined in table tbl, " +
       "defined table columns are: a, b")
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index e0887e0f1c7dee1ef401817119685c18703924fa..4bec2e3fdb9d3a368e7d8bf262720e7120fcc4a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -108,16 +108,14 @@ class DefaultSourceWithoutUserSpecifiedSchema
 }
 
 class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
-
+  import testImplicits._
 
   private val userSchema = new StructType().add("s", StringType)
   private val textSchema = new StructType().add("value", StringType)
   private val data = Seq("1", "2", "3")
   private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
-  private implicit var enc: Encoder[String] = _
 
   before {
-    enc = spark.implicits.newStringEncoder
     Utils.deleteRecursively(new File(dir))
   }
 
@@ -459,8 +457,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
   }
 
   test("column nullability and comment - write and then read") {
-    import testImplicits._
-
     Seq("json", "parquet", "csv").foreach { format =>
       val schema = StructType(
         StructField("cl1", IntegerType, nullable = false).withComment("test") ::
@@ -576,7 +572,6 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
 
   test("SPARK-18510: use user specified types for partition columns in file sources") {
     import org.apache.spark.sql.functions.udf
-    import testImplicits._
     withTempDir { src =>
       val createArray = udf { (length: Long) =>
         for (i <- 1 to length.toInt) yield i.toString
@@ -609,4 +604,35 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
       )
     }
   }
+
+  test("SPARK-18899: append to a bucketed table using DataFrameWriter with mismatched bucketing") {
+    withTable("t") {
+      Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.bucketBy(2, "i").saveAsTable("t")
+      val e = intercept[AnalysisException] {
+        Seq(3 -> "c").toDF("i", "j").write.bucketBy(3, "i").mode("append").saveAsTable("t")
+      }
+      assert(e.message.contains("Specified bucketing does not match that of the existing table"))
+    }
+  }
+
+  test("SPARK-18912: number of columns mismatch for non-file-based data source table") {
+    withTable("t") {
+      sql("CREATE TABLE t USING org.apache.spark.sql.test.DefaultSource")
+
+      val e = intercept[AnalysisException] {
+        Seq(1 -> "a").toDF("a", "b").write
+          .format("org.apache.spark.sql.test.DefaultSource")
+          .mode("append").saveAsTable("t")
+      }
+      assert(e.message.contains("The column number of the existing table"))
+    }
+  }
+
+  test("SPARK-18913: append to a table with special column names") {
+    withTable("t") {
+      Seq(1 -> "a").toDF("x.x", "y.y").write.saveAsTable("t")
+      Seq(2 -> "b").toDF("x.x", "y.y").write.mode("append").saveAsTable("t")
+      checkAnswer(spark.table("t"), Row(1, "a") :: Row(2, "b") :: Nil)
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index a45f4b5d6376cd45f1d635239a539a39e3be80b2..deb40f0464016b11b2727955069d6ccd5873f395 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -422,7 +422,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
       val e = intercept[AnalysisException] {
         df.write.mode(SaveMode.Append).saveAsTable(tableName)
       }.getMessage
-      assert(e.contains("Saving data in the Hive serde table `default`.`tab1` is not supported " +
+      assert(e.contains("Saving data in the Hive serde table default.tab1 is not supported " +
        "yet. Please use the insertInto() API as an alternative."))
 
      df.write.insertInto(tableName)
@@ -928,9 +928,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
        createDF(10, 19).write.mode(SaveMode.Append).format("orc").saveAsTable("appendOrcToParquet")
      }
      assert(e.getMessage.contains(
-        "The file format of the existing table default.appendOrcToParquet " +
-        "is `org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat`. " +
-        "It doesn't match the specified format `orc`"))
+        "The format of the existing table default.appendOrcToParquet is `ParquetFileFormat`. " +
+        "It doesn't match the specified format `OrcFileFormat`"))
     }
 
     withTable("appendParquetToJson") {
@@ -940,9 +939,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
          .saveAsTable("appendParquetToJson")
      }
      assert(e.getMessage.contains(
-        "The file format of the existing table default.appendParquetToJson " +
-        "is `org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " +
-        "It doesn't match the specified format `parquet`"))
+        "The format of the existing table default.appendParquetToJson is `JsonFileFormat`. " +
+        "It doesn't match the specified format `ParquetFileFormat`"))
     }
 
     withTable("appendTextToJson") {
@@ -952,9 +950,8 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
          .saveAsTable("appendTextToJson")
      }
      assert(e.getMessage.contains(
-        "The file format of the existing table default.appendTextToJson is " +
-        "`org.apache.spark.sql.execution.datasources.json.JsonFileFormat`. " +
-        "It doesn't match the specified format `text`"))
+        "The format of the existing table default.appendTextToJson is `JsonFileFormat`. " +
+        "It doesn't match the specified format `TextFileFormat`"))
     }
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index 22f13a494cd4c818ef6ff9729c335aeb53bd48c4..224b2c6c6f79434d4e09abd22b8af033703ac11e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -446,7 +446,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
         .saveAsTable("t")
 
       // Using only a subset of all partition columns
-      intercept[Throwable] {
+      intercept[AnalysisException] {
        partitionedTestDF2.write
          .format(dataSourceName)
          .mode(SaveMode.Append)