diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 72bda8fe1ef10e80459378155bec666fddeb7a83..d55cdcf28b23836e3af27c198b9211b7caaa484d 100644
--- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -22,6 +22,7 @@ import java.util.Properties
 
 import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{ArrayType, DecimalType}
 import org.apache.spark.tags.DockerTest
 
 @DockerTest
@@ -42,10 +43,10 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate()
     conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
       + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
-      + "c10 integer[], c11 text[], c12 real[], c13 enum_type)").executeUpdate()
+      + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type)").executeUpdate()
     conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
       + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
-      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', 'd1')""").executeUpdate()
+      + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1')""").executeUpdate()
   }
 
   test("Type mapping for various types") {
@@ -53,7 +54,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     val rows = df.collect()
     assert(rows.length == 1)
     val types = rows(0).toSeq.map(x => x.getClass)
-    assert(types.length == 14)
+    assert(types.length == 15)
     assert(classOf[String].isAssignableFrom(types(0)))
     assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
     assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
@@ -67,7 +68,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
     assert(classOf[Seq[String]].isAssignableFrom(types(11)))
     assert(classOf[Seq[Double]].isAssignableFrom(types(12)))
-    assert(classOf[String].isAssignableFrom(types(13)))
+    assert(classOf[Seq[BigDecimal]].isAssignableFrom(types(13)))
+    assert(classOf[String].isAssignableFrom(types(14)))
     assert(rows(0).getString(0).equals("hello"))
     assert(rows(0).getInt(1) == 42)
     assert(rows(0).getDouble(2) == 1.25)
@@ -84,13 +86,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(rows(0).getSeq(10) == Seq(1, 2))
     assert(rows(0).getSeq(11) == Seq("a", null, "b"))
     assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f))
-    assert(rows(0).getString(13) == "d1")
+    assert(rows(0).getSeq(13) == Seq("0.11", "0.22").map(BigDecimal(_).bigDecimal))
+    assert(rows(0).getString(14) == "d1")
   }
 
   test("Basic write test") {
     val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
     // Test only that it doesn't crash.
     df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
+    // Test that written numeric type has same DataType as input
+    assert(sqlContext.read.jdbc(jdbcUrl, "public.barcopy", new Properties).schema(13).dataType ==
+      ArrayType(DecimalType(2, 2), true))
     // Test write null values.
     df.select(df.queryExecution.analyzed.output.map { a =>
       Column(Literal.create(null, a.dataType)).as(a.name)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index befba867bc4605bcf37b6eaed6c9088f371e8faa..ed02b3f95f841ffcb258474fe3a9357bffa87688 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -137,7 +137,9 @@ private[sql] object JDBCRDD extends Logging {
           val fieldScale = rsmd.getScale(i + 1)
           val isSigned = rsmd.isSigned(i + 1)
           val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
-          val metadata = new MetadataBuilder().putString("name", columnName)
+          val metadata = new MetadataBuilder()
+            .putString("name", columnName)
+            .putLong("scale", fieldScale)
           val columnType =
             dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
               getCatalystType(dataType, fieldSize, fieldScale, isSigned))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 69ba84646f088e5ab1a5b8f7b1456d824a3704a5..e295722cacf15ae3291227164f09425ddf758c7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -194,8 +194,11 @@ object JdbcUtils extends Logging {
               case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
               case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
               case ArrayType(et, _) =>
+                // remove type length parameters from end of type name
+                val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+                  .toLowerCase.split("\\(")(0)
                 val array = conn.createArrayOf(
-                  getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
+                  typeName,
                   row.getSeq[AnyRef](i).toArray)
                 stmt.setArray(i + 1, array)
               case _ => throw new IllegalArgumentException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 8d43966480a65050e70641d7d9987dbc93f22ccf..2d6c3974a833ea672cd45786a8143f14f0200776 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -32,14 +32,18 @@ private object PostgresDialect extends JdbcDialect {
     if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
       Some(BinaryType)
     } else if (sqlType == Types.OTHER) {
-      toCatalystType(typeName).filter(_ == StringType)
-    } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
-      toCatalystType(typeName.drop(1)).map(ArrayType(_))
+      Some(StringType)
+    } else if (sqlType == Types.ARRAY) {
+      val scale = md.build.getLong("scale").toInt
+      // postgres array type names start with underscore
+      toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
     } else None
   }
 
-  // TODO: support more type names.
-  private def toCatalystType(typeName: String): Option[DataType] = typeName match {
+  private def toCatalystType(
+      typeName: String,
+      precision: Int,
+      scale: Int): Option[DataType] = typeName match {
     case "bool" => Some(BooleanType)
     case "bit" => Some(BinaryType)
     case "int2" => Some(ShortType)
@@ -52,7 +56,7 @@ private object PostgresDialect extends JdbcDialect {
     case "bytea" => Some(BinaryType)
     case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
     case "date" => Some(DateType)
-    case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
+    case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
     case _ => None
   }
 
@@ -62,6 +66,8 @@ private object PostgresDialect extends JdbcDialect {
     case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
     case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
     case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
+    case t: DecimalType => Some(
+      JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
     case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
       getJDBCType(et).map(_.databaseTypeDefinition)
         .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
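Note on the read path: precision already reached the dialect through the size argument of getCatalystType, but scale did not, so numeric array columns previously had nowhere to take their scale from. The JDBCRDD change above forwards it through the column metadata under the key "scale", and PostgresDialect reads it back whenever it sees a java.sql.Types.ARRAY column. A minimal sketch of that hand-off, with illustrative values for a numeric(2,2)[] column (the column name and numbers are examples, not part of the patch):

import org.apache.spark.sql.types.{ArrayType, DecimalType, MetadataBuilder}

// What JDBCRDD.resolveTable now records for the column.
val md = new MetadataBuilder()
  .putString("name", "c13")
  .putLong("scale", 2L)

// What PostgresDialect.getCatalystType then does for typeName "_numeric" with size 2:
// drop the leading underscore and build a bounded decimal element type.
val scale = md.build.getLong("scale").toInt
val mapped = ArrayType(DecimalType(2, scale))  // the dialect itself uses DecimalType.bounded
assert(mapped == ArrayType(DecimalType(2, 2), true))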
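Note on the write path: java.sql.Connection.createArrayOf resolves the element type by its bare SQL name, while the dialect's databaseTypeDefinition for a bounded decimal now carries length parameters (e.g. NUMERIC(2,2)), hence the stripping added in JdbcUtils. A small illustration of that stripping, using the same split expression as the patch (the input string is just an example value):

// "NUMERIC(2,2)" is what PostgresDialect.getJDBCType now produces for DecimalType(2, 2);
// only the part before the '(' is handed to Connection.createArrayOf.
val databaseTypeDefinition = "NUMERIC(2,2)"
val elementTypeName = databaseTypeDefinition.toLowerCase.split("\\(")(0)
assert(elementTypeName == "numeric")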
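End to end, a Postgres numeric(p,s)[] column now round-trips through the JDBC data source as ArrayType(DecimalType(p, s)). A minimal read/write sketch, assuming a spark-shell session of this vintage where sqlContext is predefined, a reachable Postgres instance behind the hypothetical jdbcUrl below, and a table laid out like the bar table in the test suite above (the copy table name is also illustrative):

import java.util.Properties
import org.apache.spark.sql.types.{ArrayType, DecimalType}

val jdbcUrl = "jdbc:postgresql://localhost:5432/foo?user=postgres"  // hypothetical connection string
val props = new Properties

// Read: the dialect maps the numeric(2,2)[] column c13 to a bounded decimal array.
val df = sqlContext.read.jdbc(jdbcUrl, "bar", props)
assert(df.schema("c13").dataType == ArrayType(DecimalType(2, 2), true))

// Write: the column is created as NUMERIC(2,2)[] and the array elements are bound with the
// bare "numeric" type name, so reading the copy back yields the same DataType.
df.write.jdbc(jdbcUrl, "public.barcopy_example", props)
assert(sqlContext.read.jdbc(jdbcUrl, "public.barcopy_example", props)
  .schema("c13").dataType == ArrayType(DecimalType(2, 2), true))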