diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b2a66dd417b4c24dc7750456a5ec3b87e7f7e43b..745bb4ec9cf1cf38e943b25359541c87be490549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { val conn = JdbcUtils.createConnection(url, props) try { - var tableExists = JdbcUtils.tableExists(conn, table) + var tableExists = JdbcUtils.tableExists(conn, url, table) if (mode == SaveMode.Ignore && tableExists) { return diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 26788b2a4fd69ad174a002b612dffc4325163a04..f89d55b20e212680706a7dcdab0f9d06f196d727 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -42,10 +42,13 @@ object JdbcUtils extends Logging { /** * Returns true if the table already exists in the JDBC database. */ - def tableExists(conn: Connection, table: String): Boolean = { + def tableExists(conn: Connection, url: String, table: String): Boolean = { + val dialect = JdbcDialects.get(url) + // Somewhat hacky, but there isn't a good way to identify whether a table exists for all - // SQL database systems, considering "table" could also include the database name. - Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess + // SQL database systems using JDBC meta data calls, considering "table" could also include + // the database name. Query used to find table exists can be overriden by the dialects. + Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index c6d05c9b83b98540b5473243937e50c96208443e..68ebaaca6c53d0832272a00442cb66558a6f686a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -88,6 +88,17 @@ abstract class JdbcDialect { def quoteIdentifier(colName: String): String = { s""""$colName"""" } + + /** + * Get the SQL query that should be used to find if the given table exists. Dialects can + * override this method to return a query that works best in a particular database. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + */ + def getTableExistsQuery(table: String): String = { + s"SELECT * FROM $table WHERE 1=0" + } + } /** @@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) case _ => None } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } + } /** @@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect { override def quoteIdentifier(colName: String): String = { s"`$colName`" } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ed710689cc6704dd9fe3463d3d1939edf7eb6208..5ab9381de4d669b7b56a0e0dbc26e9f6aca46e1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") } + + test("table exists query by jdbc dialect") { + val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") + val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db") + val h2 = JdbcDialects.get(url) + val table = "weblogs" + val defaultQuery = s"SELECT * FROM $table WHERE 1=0" + val limitQuery = s"SELECT 1 FROM $table LIMIT 1" + assert(MySQL.getTableExistsQuery(table) == limitQuery) + assert(Postgres.getTableExistsQuery(table) == limitQuery) + assert(db2.getTableExistsQuery(table) == defaultQuery) + assert(h2.getTableExistsQuery(table) == defaultQuery) + } }