Commit b85e2943 authored by Dongjoon Hyun, committed by gatorsmile

[SPARK-18123][SQL] Use db column names instead of RDD column ones during JDBC Writing

## What changes were proposed in this pull request?

Apache Spark supports the following cases **by quoting RDD column names** when saving through JDBC.
- Allow a reserved keyword as a column name, e.g., 'order'.
- Allow mixed-case column names such as `[a: int, A: int]`, as in the following:

  ``` scala
  scala> val df = sql("select 1 a, 1 A")
  df: org.apache.spark.sql.DataFrame = [a: int, A: int]
  ...
  scala> df.write.mode("overwrite").format("jdbc").options(option).save()
  scala> df.write.mode("append").format("jdbc").options(option).save()
  ```

This PR aims to use **database column names** instead of RDD column names in order to additionally support the following case.
Note that this case previously succeeded with `MySQL` but failed on `Postgres`/`Oracle`.

``` scala
val df1 = sql("select 1 a")
val df2 = sql("select 1 A")
...
df1.write.mode("overwrite").format("jdbc").options(option).save()
df2.write.mode("append").format("jdbc").options(option).save()
```
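
When appending to an existing table, the writer now looks up the table schema through JDBC metadata and resolves each RDD column name against it, quoting the table's own spelling in the generated `INSERT` statement. Below is a minimal Scala sketch of that normalization, modeled on the new `getInsertStatement` in `JdbcUtils`; the `normalizedColumns` helper and the inline quoting function are illustrative stand-ins, not the actual Spark API.

``` scala
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object InsertColumnResolution {
  // Resolve the RDD's column names against the existing table schema (if any)
  // and return the quoted column list used in the INSERT statement.
  def normalizedColumns(
      rddSchema: StructType,
      tableSchema: Option[StructType],
      caseSensitive: Boolean,
      quote: String => String): Seq[String] = tableSchema match {
    case None =>
      // No existing table schema (e.g., the table is being created):
      // quote the RDD column names directly.
      rddSchema.fieldNames.map(quote).toSeq
    case Some(schema) =>
      val nameEquality: (String, String) => Boolean =
        if (caseSensitive) (a: String, b: String) => a == b
        else (a: String, b: String) => a.equalsIgnoreCase(b)
      rddSchema.fieldNames.map { name =>
        // Keep the RDD's column order but use the table's column spelling.
        val resolved = schema.fieldNames
          .find(nameEquality(_, name))
          .getOrElse(sys.error(s"""Column "$name" not found"""))
        quote(resolved)
      }.toSeq
  }

  def main(args: Array[String]): Unit = {
    val rddSchema   = StructType(StructField("A", IntegerType) :: Nil)
    val tableSchema = StructType(StructField("a", IntegerType) :: Nil)
    val cols = normalizedColumns(rddSchema, Some(tableSchema),
      caseSensitive = false, n => "\"" + n + "\"")
    // Prints: INSERT INTO tbl ("a") VALUES (?) -- the table's "a", not the RDD's "A"
    println(s"""INSERT INTO tbl (${cols.mkString(",")}) VALUES (?)""")
  }
}
```

With the default case-insensitive resolution, appending a DataFrame with column `A` to a table created with column `a` now produces `INSERT INTO ... ("a") ...`, which Postgres/Oracle accept; with `spark.sql.caseSensitive=true`, an unresolved name raises an `AnalysisException`, as the new test case in `JDBCWriteSuite` exercises.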
## How was this patch tested?

Pass the Jenkins tests with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>
Author: gatorsmile <gatorsmile@gmail.com>

Closes #15664 from dongjoon-hyun/SPARK-18123.
parent 852782b8
@@ -57,6 +57,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
val table = jdbcOptions.table
val createTableOptions = jdbcOptions.createTableOptions
val isTruncate = jdbcOptions.isTruncate
val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
try {
@@ -67,16 +68,18 @@ class JdbcRelationProvider extends CreatableRelationProvider
if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, table)
saveTable(df, url, table, jdbcOptions)
val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, table)
createTable(df.schema, url, table, createTableOptions, conn)
saveTable(df, url, table, jdbcOptions)
saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
}
case SaveMode.Append =>
saveTable(df, url, table, jdbcOptions)
val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
case SaveMode.ErrorIfExists =>
throw new AnalysisException(
@@ -89,7 +92,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
}
} else {
createTable(df.schema, url, table, createTableOptions, conn)
saveTable(df, url, table, jdbcOptions)
saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
}
} finally {
conn.close()
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.spark.TaskContext
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
@@ -108,14 +108,36 @@ object JdbcUtils extends Logging {
}
/**
* Returns a PreparedStatement that inserts a row into table via conn.
* Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
: PreparedStatement = {
val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
def getInsertStatement(
table: String,
rddSchema: StructType,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
dialect: JdbcDialect): String = {
val columns = if (tableSchema.isEmpty) {
rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
} else {
val columnNameEquality = if (isCaseSensitive) {
org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
} else {
org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
}
// The generated insert statement needs to follow rddSchema's column sequence and
// tableSchema's column names. When appending data into some case-sensitive DBMSs like
// PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
// RDD column names for user convenience.
val tableColumnNames = tableSchema.get.fieldNames
rddSchema.fields.map { col =>
val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
}
dialect.quoteIdentifier(normalizedName)
}.mkString(",")
}
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
conn.prepareStatement(sql)
s"INSERT INTO $table ($columns) VALUES ($placeholders)"
}
/**
@@ -210,6 +232,26 @@ object JdbcUtils extends Logging {
answer
}
/**
* Returns the schema if the table already exists in the JDBC database.
*/
def getSchemaOption(conn: Connection, url: String, table: String): Option[StructType] = {
val dialect = JdbcDialects.get(url)
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
try {
Some(getSchema(statement.executeQuery(), dialect))
} catch {
case _: SQLException => None
} finally {
statement.close()
}
} catch {
case _: SQLException => None
}
}
/**
* Takes a [[ResultSet]] and returns its Catalyst schema.
*
@@ -531,7 +573,7 @@ object JdbcUtils extends Logging {
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
@@ -568,9 +610,9 @@ object JdbcUtils extends Logging {
conn.setAutoCommit(false) // Everything in the same db transaction.
conn.setTransactionIsolation(finalIsolationLevel)
}
val stmt = insertStatement(conn, table, rddSchema, dialect)
val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
.map(makeSetter(conn, dialect, _)).toArray
val stmt = conn.prepareStatement(insertStmt)
val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
val numFields = rddSchema.fields.length
try {
@@ -657,16 +699,16 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
tableSchema: Option[StructType],
isCaseSensitive: Boolean,
options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
}
val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
val repartitionedDF = options.numPartitions match {
case Some(n) if n <= 0 => throw new IllegalArgumentException(
s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
@@ -675,7 +717,7 @@
case _ => df
}
repartitionedDF.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
)
}
@@ -24,9 +24,9 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -96,6 +96,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
StructField("id", IntegerType) ::
StructField("seq", IntegerType) :: Nil)
private lazy val schema4 = StructType(
StructField("NAME", StringType) ::
StructField("ID", IntegerType) :: Nil)
test("Basic CREATE") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -165,6 +169,26 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}
test("SPARK-18123 Append with column names with different cases") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4)
df.write.jdbc(url, "TEST.APPENDTEST", new Properties())
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
val m = intercept[AnalysisException] {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
}.getMessage
assert(m.contains("Column \"NAME\" not found"))
}
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count())
assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}
}
test("Truncate") {
JdbcDialects.registerDialect(testH2Dialect)
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -177,7 +201,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
val m = intercept[SparkException] {
val m = intercept[AnalysisException] {
df3.write.mode(SaveMode.Overwrite).option("truncate", true)
.jdbc(url1, "TEST.TRUNCATETEST", properties)
}.getMessage
@@ -203,9 +227,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
intercept[org.apache.spark.SparkException] {
val m = intercept[AnalysisException] {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
}
}.getMessage
assert(m.contains("Column \"seq\" not found"))
}
test("INSERT to JDBC Datasource") {