Commit d9251496 authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-10186][SQL] support postgre array type in JDBCRDD

Add ARRAY support to `PostgresDialect`.

Nested ARRAY is not allowed for now because it's hard to get the array dimension info. See http://stackoverflow.com/questions/16619113/how-to-get-array-base-type-in-postgres-via-jdbc
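(For context, a sketch of why the dimension info is hard to recover through JDBC metadata. The connection details and the `dims` table below are hypothetical; the point is that PostgreSQL reports every array column as `java.sql.Types.ARRAY` with an element type name prefixed by `_`, e.g. `_int4`, and uses the same name regardless of how many dimensions the column declares, so `integer[]` and `integer[][]` look identical to the driver.)

import java.sql.DriverManager

// Hypothetical local database with a table `dims(one_dim integer[], two_dim integer[][])`.
val conn = DriverManager.getConnection("jdbc:postgresql://localhost/foo", "postgres", "secret")
try {
  val meta = conn.prepareStatement("SELECT one_dim, two_dim FROM dims").executeQuery().getMetaData
  (1 to meta.getColumnCount).foreach { i =>
    // Both columns report java.sql.Types.ARRAY (2003) and the same type name "_int4",
    // so the number of dimensions cannot be recovered here.
    println(s"${meta.getColumnName(i)}: ${meta.getColumnType(i)} / ${meta.getColumnTypeName(i)}")
  }
} finally {
  conn.close()
}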

Thanks to mariusvniekerk for the initial work!

Closes https://github.com/apache/spark/pull/9137

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9662 from cloud-fan/postgre.
parent 0158ff77
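A minimal usage sketch of what this change enables, mirroring the integration test below (the Spark setup, connection URL, credentials, and table names here are hypothetical):

import java.util.Properties
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("pg-array-demo").setMaster("local[2]"))
val sqlContext = new SQLContext(sc)

// Hypothetical Postgres instance; `bar` has integer[] and text[] columns like the test table below.
val url = "jdbc:postgresql://localhost/foo?user=postgres&password=secret"
val props = new Properties()
val df = sqlContext.read.jdbc(url, "bar", props)

// Array columns now map to Catalyst ArrayType and are exposed as Seq values on each Row.
df.printSchema()                          // e.g. c10: array<int>, c11: array<string>
val ints = df.collect().head.getSeq[Int](10)

// Writing back also works: ArrayType(et) is emitted as "<element type>[]", e.g. INTEGER[] or TEXT[].
df.write.jdbc(url, "public.bar_copy", props)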
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.Connection
import java.util.Properties
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Literal, If}
import org.apache.spark.tags.DockerTest
@DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+ "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+ "c10 integer[], c11 text[])").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+ """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
}
test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 10)
assert(types(0).equals("class java.lang.String"))
assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Double"))
assert(types(3).equals("class java.lang.Long"))
assert(types(4).equals("class java.lang.Boolean"))
assert(types(5).equals("class [B"))
assert(types(6).equals("class [B"))
assert(types(7).equals("class java.lang.Boolean"))
assert(types(8).equals("class java.lang.String"))
assert(types(9).equals("class java.lang.String"))
val types = rows(0).toSeq.map(x => x.getClass)
assert(types.length == 12)
assert(classOf[String].isAssignableFrom(types(0)))
assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
assert(classOf[String].isAssignableFrom(types(8)))
assert(classOf[String].isAssignableFrom(types(9)))
assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getBoolean(7) == true)
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
assert(rows(0).getSeq(10) == Seq(1, 2))
assert(rows(0).getSeq(11) == Seq("a", null, "b"))
}
test("Basic write test") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test only that it doesn't crash.
df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test write null values.
df.select(df.queryExecution.analyzed.output.map { a =>
Column(If(Literal(true), Literal(null), a)).as(a.name)
}: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
}
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -324,25 +324,27 @@ private[sql] class JDBCRDD(
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
/**
* Maps a StructType to a type tag list.
*/
def getConversions(schema: StructType): Array[JDBCConversion] = {
schema.fields.map(sf => sf.dataType match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType =>
if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
}).toArray
def getConversions(schema: StructType): Array[JDBCConversion] =
schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
case BooleanType => BooleanConversion
case DateType => DateConversion
case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
case DoubleType => DoubleConversion
case FloatType => FloatConversion
case IntegerType => IntegerConversion
case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
case StringType => StringConversion
case TimestampType => TimestampConversion
case BinaryType => BinaryConversion
case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
/**
@@ -420,16 +422,44 @@ private[sql] class JDBCRDD(
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
case BinaryLongConversion => {
case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
j = j + 1;
j = j + 1
}
mutableRow.setLong(i, ans)
}
case ArrayConversion(elementConversion) =>
val array = rs.getArray(pos).getArray
if (array != null) {
val data = elementConversion match {
case TimestampConversion =>
array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
}
case StringConversion =>
array.asInstanceOf[Array[java.lang.String]]
.map(UTF8String.fromString)
case DateConversion =>
array.asInstanceOf[Array[java.sql.Date]].map { date =>
nullSafeConvert(date, DateTimeUtils.fromJavaDate)
}
case DecimalConversion(p, s) =>
array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
}
case BinaryLongConversion =>
throw new IllegalArgumentException(s"Unsupported array element conversion $i")
case _: ArrayConversion =>
throw new IllegalArgumentException("Nested arrays unsupported")
case _ => array.asInstanceOf[Array[Any]]
}
mutableRow.update(i, new GenericArrayData(data))
} else {
mutableRow.update(i, null)
}
}
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
@@ -488,4 +518,12 @@ private[sql] class JDBCRDD(
nextValue
}
}
private def nullSafeConvert[T](input: T, f: T => Any): Any = {
if (input == null) {
null
} else {
f(input)
}
}
}
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try
import org.apache.spark.Logging
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
@@ -72,6 +72,35 @@ object JdbcUtils extends Logging {
conn.prepareStatement(sql.toString())
}
/**
* Retrieve standard jdbc types.
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The default JdbcType for this DataType
*/
def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
dt match {
case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
case t: DecimalType => Option(
JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
case _ => None
}
}
private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}
/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction in order to avoid repeatedly inserting
@@ -92,7 +121,8 @@ object JdbcUtils extends Logging {
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
batchSize: Int): Iterator[Byte] = {
batchSize: Int,
dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
var committed = false
try {
@@ -121,6 +151,11 @@ object JdbcUtils extends Logging {
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case ArrayType(et, _) =>
val array = conn.createArrayOf(
getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
row.getSeq[AnyRef](i).toArray)
stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
@@ -169,23 +204,7 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
val typ: String =
dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
field.dataType match {
case IntegerType => "INTEGER"
case LongType => "BIGINT"
case DoubleType => "DOUBLE PRECISION"
case FloatType => "REAL"
case ShortType => "INTEGER"
case ByteType => "BYTE"
case BooleanType => "BIT(1)"
case StringType => "TEXT"
case BinaryType => "BLOB"
case TimestampType => "TIMESTAMP"
case DateType => "DATE"
case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
})
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -202,23 +221,7 @@ object JdbcUtils extends Logging {
properties: Properties = new Properties()) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
case DoubleType => java.sql.Types.DOUBLE
case FloatType => java.sql.Types.REAL
case ShortType => java.sql.Types.INTEGER
case ByteType => java.sql.Types.INTEGER
case BooleanType => java.sql.Types.BIT
case StringType => java.sql.Types.CLOB
case BinaryType => java.sql.Types.BLOB
case TimestampType => java.sql.Types.TIMESTAMP
case DateType => java.sql.Types.DATE
case t: DecimalType => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
})
getJdbcType(field.dataType, dialect).jdbcNullType
}
val rddSchema = df.schema
@@ -226,7 +229,7 @@ object JdbcUtils extends Logging {
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
}
@@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
* for the given Catalyst type.
*/
@DeveloperApi
abstract class JdbcDialect {
abstract class JdbcDialect extends Serializable {
/**
* Check if this dialect instance can handle a certain jdbc url.
* @param url the jdbc url.
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
import java.sql.Types
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._
@@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect {
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
Option(BinaryType)
} else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("inet")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("json")) {
Option(StringType)
} else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
Option(StringType)
Some(BinaryType)
} else if (sqlType == Types.OTHER) {
toCatalystType(typeName).filter(_ == StringType)
} else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
toCatalystType(typeName.drop(1)).map(ArrayType(_))
} else None
}
// TODO: support more type names.
private def toCatalystType(typeName: String): Option[DataType] = typeName match {
case "bool" => Some(BooleanType)
case "bit" => Some(BinaryType)
case "int2" => Some(ShortType)
case "int4" => Some(IntegerType)
case "int8" | "oid" => Some(LongType)
case "float4" => Some(FloatType)
case "money" | "float8" => Some(DoubleType)
case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
Some(StringType)
case "bytea" => Some(BinaryType)
case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
case "date" => Some(DateType)
case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
case _ => None
}
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
case StringType => Some(JdbcType("TEXT", Types.CHAR))
case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
getJDBCType(et).map(_.databaseTypeDefinition)
.orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
.map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
case _ => None
}