diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 74f397c01e2f582210c2945669082dfe15f9a594..e39d936f3933fdce30e23c27b70a6b175d0b5da3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -57,6 +57,7 @@ class JdbcRelationProvider extends CreatableRelationProvider val table = jdbcOptions.table val createTableOptions = jdbcOptions.createTableOptions val isTruncate = jdbcOptions.isTruncate + val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis val conn = JdbcUtils.createConnectionFactory(jdbcOptions)() try { @@ -67,16 +68,18 @@ class JdbcRelationProvider extends CreatableRelationProvider if (isTruncate && isCascadingTruncateTable(url) == Some(false)) { // In this case, we should truncate table and then load. truncateTable(conn, table) - saveTable(df, url, table, jdbcOptions) + val tableSchema = JdbcUtils.getSchemaOption(conn, url, table) + saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions) } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, table) createTable(df.schema, url, table, createTableOptions, conn) - saveTable(df, url, table, jdbcOptions) + saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions) } case SaveMode.Append => - saveTable(df, url, table, jdbcOptions) + val tableSchema = JdbcUtils.getSchemaOption(conn, url, table) + saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions) case SaveMode.ErrorIfExists => throw new AnalysisException( @@ -89,7 +92,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } } else { createTable(df.schema, url, table, createTableOptions, conn) - saveTable(df, url, table, jdbcOptions) + saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions) } } finally { conn.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index ff29a15960c5702965e7b13d2392884135fcc840..b13849475811f8f06555a73036651a771262522d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -26,7 +26,7 @@ import scala.util.control.NonFatal import org.apache.spark.TaskContext import org.apache.spark.executor.InputMetrics import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow @@ -108,14 +108,36 @@ object JdbcUtils extends Logging { } /** - * Returns a PreparedStatement that inserts a row into table via conn. + * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn. */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect) - : PreparedStatement = { - val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",") + def getInsertStatement( + table: String, + rddSchema: StructType, + tableSchema: Option[StructType], + isCaseSensitive: Boolean, + dialect: JdbcDialect): String = { + val columns = if (tableSchema.isEmpty) { + rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",") + } else { + val columnNameEquality = if (isCaseSensitive) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + // The generated insert statement needs to follow rddSchema's column sequence and + // tableSchema's column names. When appending data into some case-sensitive DBMSs like + // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of + // RDD column names for user convenience. + val tableColumnNames = tableSchema.get.fieldNames + rddSchema.fields.map { col => + val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse { + throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""") + } + dialect.quoteIdentifier(normalizedName) + }.mkString(",") + } val placeholders = rddSchema.fields.map(_ => "?").mkString(",") - val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)" - conn.prepareStatement(sql) + s"INSERT INTO $table ($columns) VALUES ($placeholders)" } /** @@ -210,6 +232,26 @@ object JdbcUtils extends Logging { answer } + /** + * Returns the schema if the table already exists in the JDBC database. + */ + def getSchemaOption(conn: Connection, url: String, table: String): Option[StructType] = { + val dialect = JdbcDialects.get(url) + + try { + val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) + try { + Some(getSchema(statement.executeQuery(), dialect)) + } catch { + case _: SQLException => None + } finally { + statement.close() + } + } catch { + case _: SQLException => None + } + } + /** * Takes a [[ResultSet]] and returns its Catalyst schema. * @@ -531,7 +573,7 @@ object JdbcUtils extends Logging { table: String, iterator: Iterator[Row], rddSchema: StructType, - nullTypes: Array[Int], + insertStmt: String, batchSize: Int, dialect: JdbcDialect, isolationLevel: Int): Iterator[Byte] = { @@ -568,9 +610,9 @@ object JdbcUtils extends Logging { conn.setAutoCommit(false) // Everything in the same db transaction. conn.setTransactionIsolation(finalIsolationLevel) } - val stmt = insertStatement(conn, table, rddSchema, dialect) - val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType) - .map(makeSetter(conn, dialect, _)).toArray + val stmt = conn.prepareStatement(insertStmt) + val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType)) + val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType) val numFields = rddSchema.fields.length try { @@ -657,16 +699,16 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, + tableSchema: Option[StructType], + isCaseSensitive: Boolean, options: JDBCOptions): Unit = { val dialect = JdbcDialects.get(url) - val nullTypes: Array[Int] = df.schema.fields.map { field => - getJdbcType(field.dataType, dialect).jdbcNullType - } - val rddSchema = df.schema val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel + + val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect) val repartitionedDF = options.numPartitions match { case Some(n) if n <= 0 => throw new IllegalArgumentException( s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " + @@ -675,7 +717,7 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index f49ac23149709ed00e84fad6e10c8b0587398b10..354af29d4237b1d825e92ee819a98a15b153eb53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -24,9 +24,9 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter import org.scalatest.BeforeAndAfter -import org.apache.spark.SparkException -import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.{AnalysisException, Row, SaveMode} import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -96,6 +96,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { StructField("id", IntegerType) :: StructField("seq", IntegerType) :: Nil) + private lazy val schema4 = StructType( + StructField("NAME", StringType) :: + StructField("ID", IntegerType) :: Nil) + test("Basic CREATE") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) @@ -165,6 +169,26 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length) } + test("SPARK-18123 Append with column names with different cases") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4) + + df.write.jdbc(url, "TEST.APPENDTEST", new Properties()) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val m = intercept[AnalysisException] { + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) + }.getMessage + assert(m.contains("Column \"NAME\" not found")) + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties()) + assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count()) + assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length) + } + } + test("Truncate") { JdbcDialects.registerDialect(testH2Dialect) val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) @@ -177,7 +201,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count()) assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) - val m = intercept[SparkException] { + val m = intercept[AnalysisException] { df3.write.mode(SaveMode.Overwrite).option("truncate", true) .jdbc(url1, "TEST.TRUNCATETEST", properties) }.getMessage @@ -203,9 +227,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) - intercept[org.apache.spark.SparkException] { + val m = intercept[AnalysisException] { df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties()) - } + }.getMessage + assert(m.contains("Column \"seq\" not found")) } test("INSERT to JDBC Datasource") {