diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 6dbed26b0dec4a03eda3e7d0122a25505c1a8892..44a9f312bd76c19fe13176d53c6eeaabf1ddbb7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, HadoopFsRelation}
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -415,39 +415,49 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     assertNotPartitioned("jdbc")
     assertNotBucketed("jdbc")
 
+    // add required options like URL and dbtable
+    val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
+    val jdbcOptions = new JDBCOptions(params)
+    val jdbcUrl = jdbcOptions.url
+    val jdbcTable = jdbcOptions.table
+
     val props = new Properties()
     extraOptions.foreach { case (key, value) =>
       props.put(key, value)
     }
     // connectionProperties should override settings in extraOptions
     props.putAll(connectionProperties)
-    val conn = JdbcUtils.createConnectionFactory(url, props)()
+    val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()
 
     try {
-      var tableExists = JdbcUtils.tableExists(conn, url, table)
+      var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)
 
       if (mode == SaveMode.Ignore && tableExists) {
         return
       }
 
       if (mode == SaveMode.ErrorIfExists && tableExists) {
-        sys.error(s"Table $table already exists.")
+        sys.error(s"Table $jdbcTable already exists.")
       }
 
       if (mode == SaveMode.Overwrite && tableExists) {
-        if (extraOptions.getOrElse("truncate", "false").toBoolean &&
-          JdbcUtils.isCascadingTruncateTable(url) == Some(false)) {
-          JdbcUtils.truncateTable(conn, table)
+        if (jdbcOptions.isTruncate &&
+          JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
+          JdbcUtils.truncateTable(conn, jdbcTable)
         } else {
-          JdbcUtils.dropTable(conn, table)
+          JdbcUtils.dropTable(conn, jdbcTable)
           tableExists = false
         }
       }
 
       // Create the table if the table didn't exist.
       if (!tableExists) {
-        val schema = JdbcUtils.schemaString(df, url)
-        val sql = s"CREATE TABLE $table ($schema)"
+        val schema = JdbcUtils.schemaString(df, jdbcUrl)
+        // To allow certain options to be appended when creating a new table, which can be
+        // table_options or partition_options.
+        // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+        val createtblOptions = jdbcOptions.createTableOptions
+        val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
         val statement = conn.createStatement
         try {
           statement.executeUpdate(sql)
@@ -459,7 +469,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       conn.close()
     }
 
-    JdbcUtils.saveTable(df, url, table, props)
+    JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 6c6ec89746ee162ec3fde9e0f84e5480c4daa75f..1db090eaf9c9e4231071b99520648b6aec71afee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -20,14 +20,21 @@ package org.apache.spark.sql.execution.datasources.jdbc
 /**
  * Options for the JDBC data source.
  */
-private[jdbc] class JDBCOptions(
+class JDBCOptions(
     @transient private val parameters: Map[String, String])
   extends Serializable {
 
+  // ------------------------------------------------------------
+  // Required parameters
+  // ------------------------------------------------------------
   // a JDBC URL
   val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
   // name of table
   val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+
+  // ------------------------------------------------------------
+  // Optional parameter list
+  // ------------------------------------------------------------
   // the column used to partition
   val partitionColumn = parameters.getOrElse("partitionColumn", null)
   // the lower bound of partition column
@@ -36,4 +43,14 @@ private[jdbc] class JDBCOptions(
   val upperBound = parameters.getOrElse("upperBound", null)
   // the number of partitions
   val numPartitions = parameters.getOrElse("numPartitions", null)
+
+  // ------------------------------------------------------------
+  // The options for DataFrameWriter
+  // ------------------------------------------------------------
+  // whether to truncate the table in the JDBC database
+  val isTruncate = parameters.getOrElse("truncate", "false").toBoolean
+  // the create table option, which can be table_options or partition_options.
+  // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+  // TODO: reuse the existing partition parameters for those partition-specific options
+  val createTableOptions = parameters.getOrElse("createTableOptions", "")
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index d99b3cf975f4fc275139d21abf715e51e4399899..ff3309874f2e1a44dfd74a6ce9a36e48b2767ae8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -174,6 +174,18 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     JdbcDialects.unregisterDialect(testH2Dialect)
   }
 
+  test("createTableOptions") {
+    JdbcDialects.registerDialect(testH2Dialect)
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val m = intercept[org.h2.jdbc.JdbcSQLException] {
+      df.write.option("createTableOptions", "ENGINE tableEngineName")
+        .jdbc(url1, "TEST.CREATETBLOPTS", properties)
+    }.getMessage
+    assert(m.contains("Class \"TABLEENGINENAME\" not found"))
+    JdbcDialects.unregisterDialect(testH2Dialect)
+  }
+
   test("Incompatible INSERT to append") {
     val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
     val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
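
For reference, a minimal usage sketch of the two writer options this patch wires up (truncate and createTableOptions), as seen from the public DataFrameWriter.jdbc path touched above. The MySQL URL, database, table name, and credentials below are placeholders, not part of the patch; any JDBC endpoint whose dialect accepts the supplied table_options would work.

import java.util.Properties

import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcWriteOptionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("jdbc-write-options")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "alice"), (2, "bob")).toDF("id", "name")

    val props = new Properties()
    props.put("user", "spark")       // placeholder credentials
    props.put("password", "secret")

    df.write
      .mode(SaveMode.Overwrite)
      // Appended verbatim to the generated CREATE TABLE statement when the
      // table has to be created (MySQL-style table_options in this sketch).
      .option("createTableOptions", "ENGINE=InnoDB DEFAULT CHARSET=utf8")
      // With SaveMode.Overwrite, request TRUNCATE TABLE instead of DROP/CREATE
      // when the dialect reports that truncation does not cascade.
      .option("truncate", "true")
      .jdbc("jdbc:mysql://localhost:3306/test", "people", props)  // placeholder endpoint

    spark.stop()
  }
}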