diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala index 3da34b1b382d70edf5516a0f1b7afa5e790b4b87..f5930bc281e8c14eec2c9fa78a06f5151d6c4517 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -21,10 +21,13 @@ import java.math.BigDecimal import java.sql.{Connection, Date, Timestamp} import java.util.Properties -import org.scalatest._ +import org.scalatest.Ignore +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{BooleanType, ByteType, ShortType, StructType} import org.apache.spark.tags.DockerTest + @DockerTest @Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { @@ -47,19 +50,22 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() conn.prepareStatement("CREATE TABLE numbers ( small SMALLINT, med INTEGER, big BIGINT, " - + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE)").executeUpdate() + + "deci DECIMAL(31,20), flt FLOAT, dbl DOUBLE, real REAL, " + + "decflt DECFLOAT, decflt16 DECFLOAT(16), decflt34 DECFLOAT(34))").executeUpdate() conn.prepareStatement("INSERT INTO numbers VALUES (17, 77777, 922337203685477580, " - + "123456745.56789012345000000000, 42.75, 5.4E-70)").executeUpdate() + + "123456745.56789012345000000000, 42.75, 5.4E-70, " + + "3.4028234663852886e+38, 4.2999, DECFLOAT('9.999999999999999E19', 16), " + + "DECFLOAT('1234567891234567.123456789123456789', 34))").executeUpdate() conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, ts TIMESTAMP )").executeUpdate() conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + "'2009-02-13 23:31:30')").executeUpdate() // TODO: Test locale conversion for strings. - conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB)") - .executeUpdate() - conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox'))") + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, e XML)") .executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', BLOB('fox')," + + "'<cinfo cid=\"10\"><name>Kathy</name></cinfo>')").executeUpdate() } test("Basic test") { @@ -77,13 +83,17 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 6) + assert(types.length == 10) assert(types(0).equals("class java.lang.Integer")) assert(types(1).equals("class java.lang.Integer")) assert(types(2).equals("class java.lang.Long")) assert(types(3).equals("class java.math.BigDecimal")) assert(types(4).equals("class java.lang.Double")) assert(types(5).equals("class java.lang.Double")) + assert(types(6).equals("class java.lang.Float")) + assert(types(7).equals("class java.math.BigDecimal")) + assert(types(8).equals("class java.math.BigDecimal")) + assert(types(9).equals("class java.math.BigDecimal")) assert(rows(0).getInt(0) == 17) assert(rows(0).getInt(1) == 77777) assert(rows(0).getLong(2) == 922337203685477580L) @@ -91,6 +101,10 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getAs[BigDecimal](3).equals(bd)) assert(rows(0).getDouble(4) == 42.75) assert(rows(0).getDouble(5) == 5.4E-70) + assert(rows(0).getFloat(6) == 3.4028234663852886e+38) + assert(rows(0).getDecimal(7) == new BigDecimal("4.299900000000000000")) + assert(rows(0).getDecimal(8) == new BigDecimal("99999999999999990000.000000000000000000")) + assert(rows(0).getDecimal(9) == new BigDecimal("1234567891234567.123456789123456789")) } test("Date types") { @@ -112,7 +126,7 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass.toString) - assert(types.length == 4) + assert(types.length == 5) assert(types(0).equals("class java.lang.String")) assert(types(1).equals("class java.lang.String")) assert(types(2).equals("class java.lang.String")) @@ -121,14 +135,27 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getString(1).equals("quick")) assert(rows(0).getString(2).equals("brown")) assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](3), Array[Byte](102, 111, 120))) + assert(rows(0).getString(4).equals("""<cinfo cid="10"><name>Kathy</name></cinfo>""")) } test("Basic write test") { - // val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + // cast decflt column with precision value of 38 to DB2 max decimal precision value of 31. + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + .selectExpr("small", "med", "big", "deci", "flt", "dbl", "real", + "cast(decflt as decimal(31, 5)) as decflt") val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) - // df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) df2.write.jdbc(jdbcUrl, "datescopy", new Properties) df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + // spark types that does not have exact matching db2 table types. + val df4 = sqlContext.createDataFrame( + sparkContext.parallelize(Seq(Row("1".toShort, "20".toByte, true))), + new StructType().add("c1", ShortType).add("b", ByteType).add("c3", BooleanType)) + df4.write.jdbc(jdbcUrl, "otherscopy", new Properties) + val rows = sqlContext.read.jdbc(jdbcUrl, "otherscopy", new Properties).collect() + assert(rows(0).getInt(0) == 1) + assert(rows(0).getInt(1) == 20) + assert(rows(0).getString(2) == "1") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 190463df0d9287cd0711702994b816d689684ca8..d160ad82888a26958b5d5b515e7db47039975173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,15 +17,34 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.sql.types.{BooleanType, DataType, StringType} +import java.sql.Types + +import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + override def getCatalystType( + sqlType: Int, + typeName: String, + size: Int, + md: MetadataBuilder): Option[DataType] = sqlType match { + case Types.REAL => Option(FloatType) + case Types.OTHER => + typeName match { + case "DECFLOAT" => Option(DecimalType(38, 18)) + case "XML" => Option(StringType) + case t if (t.startsWith("TIMESTAMP")) => Option(TimestampType) // TIMESTAMP WITH TIMEZONE + case _ => None + } + case _ => None + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case ShortType | ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 70bee929b31da829ff2867c573c269a376689cf0..d1daf860fdfff51e7a9b64c8f41d47217156fac8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -713,6 +713,15 @@ class JDBCSuite extends SparkFunSuite val db2Dialect = JdbcDialects.get("jdbc:db2://127.0.0.1/db") assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB") assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)") + assert(db2Dialect.getJDBCType(ShortType).map(_.databaseTypeDefinition).get == "SMALLINT") + assert(db2Dialect.getJDBCType(ByteType).map(_.databaseTypeDefinition).get == "SMALLINT") + // test db2 dialect mappings on read + assert(db2Dialect.getCatalystType(java.sql.Types.REAL, "REAL", 1, null) == Option(FloatType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "DECFLOAT", 1, null) == + Option(DecimalType(38, 18))) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "XML", 1, null) == Option(StringType)) + assert(db2Dialect.getCatalystType(java.sql.Types.OTHER, "TIMESTAMP WITH TIME ZONE", 1, null) == + Option(TimestampType)) } test("PostgresDialect type mapping") {