Commit bfd3ee9f authored by Volodymyr Lyubinets, committed by Michael Armbrust

[SPARK-6124] Support jdbc connection properties in OPTIONS part of the query

One more thing: if this PR is considered OK, it might make sense to add extra .jdbc() APIs that take Properties to SQLContext.

Author: Volodymyr Lyubinets <vlyubin@gmail.com>

Closes #4859 from vlyubin/jdbcProperties and squashes the following commits:

7a8cfda [Volodymyr Lyubinets] Support jdbc connection properties in OPTIONS part of the query
Parent: 6cd7058b
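For illustration: after this change, every key/value pair written in the OPTIONS clause is copied into a java.util.Properties object and handed to DriverManager.getConnection, so driver settings such as credentials can be supplied inline. A minimal usage sketch in the style of the patch's own test suite (the url, table, and credentials below are the H2 test fixtures from JDBCSuite, not a real deployment):

    val url = "jdbc:h2:mem:testdb0"  // H2 in-memory database, as in JDBCSuite
    sql(
      s"""
        |CREATE TEMPORARY TABLE people
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
    // user and password are not interpreted by Spark; they ride along to the JDBC driver.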
JDBCRDD.scala (package org.apache.spark.sql.jdbc)

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.jdbc

 import java.sql.{Connection, DriverManager, ResultSet, ResultSetMetaData, SQLException}
+import java.util.Properties

 import org.apache.commons.lang.StringEscapeUtils.escapeSql
 import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}

@@ -90,9 +91,9 @@ private[sql] object JDBCRDD extends Logging {
   * @throws SQLException if the table specification is garbage.
   * @throws SQLException if the table contains an unsupported type.
   */
-  def resolveTable(url: String, table: String): StructType = {
+  def resolveTable(url: String, table: String, properties: Properties): StructType = {
     val quirks = DriverQuirks.get(url)
-    val conn: Connection = DriverManager.getConnection(url)
+    val conn: Connection = DriverManager.getConnection(url, properties)
     try {
       val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
       try {

@@ -147,7 +148,7 @@ private[sql] object JDBCRDD extends Logging {
    *
    * @return A function that loads the driver and connects to the url.
    */
-  def getConnector(driver: String, url: String): () => Connection = {
+  def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
     () => {
       try {
         if (driver != null) Class.forName(driver)

@@ -156,7 +157,7 @@ private[sql] object JDBCRDD extends Logging {
         logWarning(s"Couldn't find class $driver", e);
       }
     }
-    DriverManager.getConnection(url)
+    DriverManager.getConnection(url, properties)
   }
 }

 /**

@@ -179,6 +180,7 @@ private[sql] object JDBCRDD extends Logging {
       schema: StructType,
       driver: String,
       url: String,
+      properties: Properties,
       fqTable: String,
       requiredColumns: Array[String],
       filters: Array[Filter],

@@ -189,7 +191,7 @@ private[sql] object JDBCRDD extends Logging {
     return new
       JDBCRDD(
         sc,
-        getConnector(driver, url),
+        getConnector(driver, url, properties),
         prunedSchema,
         fqTable,
         requiredColumns,

@@ -361,7 +363,7 @@ private[sql] class JDBCRDD(
         var ans = 0L
         var j = 0
         while (j < bytes.size) {
-          ans = 256*ans + (255 & bytes(j))
+          ans = 256 * ans + (255 & bytes(j))
           j = j + 1;
         }
         mutableRow.setLong(i, ans)
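The getConnector change above deserves a gloss: the method returns a () => Connection closure rather than a live connection, so the closure can be shipped to executors and each task dials the database itself, now carrying the extra properties. A minimal sketch of that pattern, assuming hypothetical names (makeConnector is illustrative, not part of the patch):

    import java.sql.{Connection, DriverManager}
    import java.util.Properties

    // Illustrative connection factory mirroring JDBCRDD.getConnector after this change:
    // the closure captures driver, url, and the new properties argument.
    def makeConnector(driver: String, url: String, properties: Properties): () => Connection =
      () => {
        if (driver != null) Class.forName(driver) // ensure the driver registers itself
        DriverManager.getConnection(url, properties)
      }

    // Each task invokes the closure to open its own connection:
    // val conn = makeConnector("org.h2.Driver", "jdbc:h2:mem:testdb0", new Properties())()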
JDBCRelation.scala (package org.apache.spark.sql.jdbc)

@@ -17,16 +17,17 @@
 package org.apache.spark.sql.jdbc

-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.Row
-import org.apache.spark.sql.types.StructType
+import java.sql.DriverManager
+import java.util.Properties
 import scala.collection.mutable.ArrayBuffer
-import java.sql.DriverManager
 import org.apache.spark.Partition
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.Row
 import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.StructType

 /**
  * Data corresponding to one partition of a JDBCRDD.

@@ -115,18 +116,21 @@ private[sql] class DefaultSource extends RelationProvider {
         numPartitions.toInt)
     }
     val parts = JDBCRelation.columnPartition(partitionInfo)
-    JDBCRelation(url, table, parts)(sqlContext)
+    val properties = new Properties() // Additional properties that we will pass to getConnection
+    parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
+    JDBCRelation(url, table, parts, properties)(sqlContext)
   }
 }

 private[sql] case class JDBCRelation(
     url: String,
     table: String,
-    parts: Array[Partition])(@transient val sqlContext: SQLContext)
+    parts: Array[Partition],
+    properties: Properties = new Properties())(@transient val sqlContext: SQLContext)
   extends BaseRelation
   with PrunedFilteredScan {

-  override val schema: StructType = JDBCRDD.resolveTable(url, table)
+  override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName

@@ -135,6 +139,7 @@ private[sql] case class JDBCRelation(
       schema,
       driver,
       url,
+      properties,
       table,
       requiredColumns,
       filters,
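The DefaultSource hunk is where the feature actually lands: the parameters map holds every key/value pair from the OPTIONS clause, and all of them are copied wholesale into a java.util.Properties. A standalone sketch of that conversion, with placeholder values borrowed from the test fixtures:

    import java.util.Properties

    // Placeholder OPTIONS map; in DefaultSource this comes from the query's OPTIONS clause.
    val parameters = Map(
      "url" -> "jdbc:h2:mem:testdb0",
      "dbtable" -> "TEST.PEOPLE",
      "user" -> "testUser",
      "password" -> "testPass")

    val properties = new Properties()
    parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
    // Every key is forwarded, including url and dbtable themselves,
    // so the JDBC driver also sees those entries.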
JDBCSuite.scala (package org.apache.spark.sql.jdbc)

@@ -19,22 +19,31 @@ package org.apache.spark.sql.jdbc
 import java.math.BigDecimal
 import java.sql.DriverManager
-import java.util.{Calendar, GregorianCalendar}
+import java.util.{Calendar, GregorianCalendar, Properties}

 import org.apache.spark.sql.test._
+import org.h2.jdbc.JdbcSQLException
 import org.scalatest.{FunSuite, BeforeAndAfter}

 import TestSQLContext._
 import TestSQLContext.implicits._

 class JDBCSuite extends FunSuite with BeforeAndAfter {
   val url = "jdbc:h2:mem:testdb0"
+  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
   var conn: java.sql.Connection = null

   val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)

   before {
     Class.forName("org.h2.Driver")
-    conn = DriverManager.getConnection(url)
+    // Extra properties that will be specified for our database. We need these to test
+    // usage of parameters from OPTIONS clause in queries.
+    val properties = new Properties()
+    properties.setProperty("user", "testUser")
+    properties.setProperty("password", "testPass")
+    properties.setProperty("rowId", "false")
+    conn = DriverManager.getConnection(url, properties)
     conn.prepareStatement("create schema test").executeUpdate()
     conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
     conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()

@@ -46,15 +55,15 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
       s"""
         |CREATE TEMPORARY TABLE foobar
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable 'TEST.PEOPLE')
+        |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
     sql(
       s"""
         |CREATE TEMPORARY TABLE parts
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable 'TEST.PEOPLE',
-        |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
+        |OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
+        |  partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
       """.stripMargin.replaceAll("\n", " "))
     conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, "

@@ -68,12 +77,12 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
       s"""
         |CREATE TEMPORARY TABLE inttypes
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES')
+        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
     conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), "
       + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate()
-    var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
+    val stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
     stmt.setBytes(1, testBytes)
     stmt.setString(2, "Sensitive")
     stmt.setString(3, "Insensitive")

@@ -85,7 +94,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
       s"""
         |CREATE TEMPORARY TABLE strtypes
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable 'TEST.STRTYPES')
+        |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
     conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)"

@@ -97,7 +106,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
       s"""
         |CREATE TEMPORARY TABLE timetypes
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES')
+        |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))

@@ -112,7 +121,7 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
       s"""
         |CREATE TEMPORARY TABLE flttypes
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES')
+        |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
     // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.

@@ -174,16 +183,17 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
   }

   test("Basic API") {
-    assert(TestSQLContext.jdbc(url, "TEST.PEOPLE").collect.size == 3)
+    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE").collect.size == 3)
   }

   test("Partitioning via JDBCPartitioningInfo API") {
-    assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", "THEID", 0, 4, 3).collect.size == 3)
+    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3)
+      .collect.size == 3)
   }

   test("Partitioning via list-of-where-clauses API") {
     val parts = Array[String]("THEID < 2", "THEID >= 2")
-    assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", parts).collect.size == 3)
+    assert(TestSQLContext.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts).collect.size == 3)
   }

   test("H2 integral types") {

@@ -216,7 +226,6 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
     assert(rows(0).getString(5).equals("I am a clob!"))
   }
-
   test("H2 time types") {
     val rows = sql("SELECT * FROM timetypes").collect()
     val cal = new GregorianCalendar(java.util.Locale.ROOT)

@@ -246,17 +255,31 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
       .equals(new BigDecimal("123456789012345.54321543215432100000")))
   }

   test("SQL query as table name") {
     sql(
       s"""
         |CREATE TEMPORARY TABLE hack
         |USING org.apache.spark.sql.jdbc
-        |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)')
+        |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)',
+        |  user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
     val rows = sql("SELECT * FROM hack").collect()
     assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==.
     // For some reason, H2 computes this square incorrectly...
     assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
   }
+
+  test("Pass extra properties via OPTIONS") {
+    // We set rowId to false during setup, which means that _ROWID_ column should be absent from
+    // all tables. If rowId is true (default), the query below doesn't throw an exception.
+    intercept[JdbcSQLException] {
+      sql(
+        s"""
+          |CREATE TEMPORARY TABLE abc
+          |USING org.apache.spark.sql.jdbc
+          |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)',
+          |  user 'testUser', password 'testPass')
+        """.stripMargin.replaceAll("\n", " "))
+    }
+  }
 }