diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 2930f7bb4cae12340b29768793cfa832bc92de31..db68b9c86db1b21f6777c429508e6e5730a6016e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -212,13 +212,13 @@ private[sql] object JDBCRDD extends Logging { filters: Array[Filter], parts: Array[Partition]): RDD[Row] = { val dialect = JdbcDialects.get(url) - val enclosedColumns = requiredColumns.map(dialect.columnEnclosing(_)) + val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, getConnector(driver, url, properties), pruneSchema(schema, requiredColumns), fqTable, - enclosedColumns, + quotedColumns, filters, parts, properties) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 04052f80f5e787254b9ce0e772485cf075f75c1a..8849fc2f1f0ef51cea57fa3580c333770bfb1953 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.jdbc +import java.sql.Types + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi -import java.sql.Types - /** * :: DeveloperApi :: * A database type definition coupled with the jdbc type needed to send null @@ -82,11 +82,10 @@ abstract class JdbcDialect { def getJDBCType(dt: DataType): Option[JdbcType] = None /** - * Enclose column name - * @param colName The coulmn name - * @return Enclosed column name + * Quotes the identifier. This is used to put quotes around the identifier in case the column + * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). */ - def columnEnclosing(colName: String): String = { + def quoteIdentifier(colName: String): String = { s""""$colName"""" } } @@ -150,18 +149,19 @@ object JdbcDialects { @DeveloperApi class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - require(!dialects.isEmpty) + require(dialects.nonEmpty) - def canHandle(url : String): Boolean = + override def canHandle(url : String): Boolean = dialects.map(_.canHandle(url)).reduce(_ && _) override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = - dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption - - override def getJDBCType(dt: DataType): Option[JdbcType] = - dialects.map(_.getJDBCType(dt)).flatten.headOption + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } } /** @@ -170,7 +170,7 @@ class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { */ @DeveloperApi case object NoopDialect extends JdbcDialect { - def canHandle(url : String): Boolean = true + override def canHandle(url : String): Boolean = true } /** @@ -179,7 +179,7 @@ case object NoopDialect extends JdbcDialect { */ @DeveloperApi case object PostgresDialect extends JdbcDialect { - def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { @@ -205,7 +205,7 @@ case object PostgresDialect extends JdbcDialect { */ @DeveloperApi case object MySQLDialect extends JdbcDialect { - def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { @@ -218,7 +218,7 @@ case object MySQLDialect extends JdbcDialect { } else None } - override def columnEnclosing(colName: String): String = { + override def quoteIdentifier(colName: String): String = { s"`$colName`" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a228543953536e72382b8a46ab99b59bf46f6370..49d348c3ed21b66aaad7b0abe693fe0bb48b280a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -410,13 +410,13 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { assert(JdbcDialects.get("test.invalid") == NoopDialect) } - test("Enclosing column names by jdbc dialect") { + test("quote column names by jdbc dialect") { val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") val columns = Seq("abc", "key") - val MySQLColumns = columns.map(MySQL.columnEnclosing(_)) - val PostgresColumns = columns.map(Postgres.columnEnclosing(_)) + val MySQLColumns = columns.map(MySQL.quoteIdentifier(_)) + val PostgresColumns = columns.map(Postgres.quoteIdentifier(_)) assert(MySQLColumns === Seq("`abc`", "`key`")) assert(PostgresColumns === Seq(""""abc"""", """"key"""")) }