Skip to content
Snippets Groups Projects
Commit 3b22291b authored by Dongjoon Hyun's avatar Dongjoon Hyun Committed by Reynold Xin
Browse files

[SPARK-16387][SQL] JDBC Writer should use dialect to quote field names.

## What changes were proposed in this pull request?

Currently, JDBC Writer uses dialects to get datatypes, but doesn't to quote field names. This PR uses dialects to quote the field names, too.

**Reported Error Scenario (MySQL case)**
```scala
scala> val url="jdbc:mysql://localhost:3306/temp"
scala> val prop = new java.util.Properties
scala> prop.setProperty("user","root")
scala> spark.createDataset(Seq("a","b","c")).toDF("order")
scala> df.write.mode("overwrite").jdbc(url, "temptable", prop)
...MySQLSyntaxErrorException: ... near 'order TEXT )
```

## How was this patch tested?

Pass the Jenkins tests and manually do the above case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #14107 from dongjoon-hyun/SPARK-16387.
parent 60ba436b
No related branches found
No related tags found
No related merge requests found
......@@ -100,8 +100,9 @@ object JdbcUtils extends Logging {
/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
val columns = rddSchema.fields.map(_.name).mkString(",")
def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
: PreparedStatement = {
val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
conn.prepareStatement(sql)
......@@ -177,7 +178,7 @@ object JdbcUtils extends Logging {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
}
val stmt = insertStatement(conn, table, rddSchema)
val stmt = insertStatement(conn, table, rddSchema, dialect)
try {
var rowCount = 0
while (iterator.hasNext) {
......@@ -260,7 +261,7 @@ object JdbcUtils extends Logging {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field =>
val name = field.name
val name = dialect.quoteIdentifier(field.name)
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
......
......@@ -764,4 +764,10 @@ class JDBCSuite extends SparkFunSuite
assertEmptyQuery(s"SELECT * FROM tempFrame where $FALSE2")
}
}
test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") {
val df = spark.createDataset(Seq("a", "b", "c")).toDF("order")
val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp")
assert(schema.contains("`order` TEXT"))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment