Skip to content
Snippets Groups Projects
Commit d8935db5 authored by Wenchen Fan's avatar Wenchen Fan Committed by Davies Liu
Browse files

[SPARK-15241] [SPARK-15242] [SQL] fix 2 decimal-related issues in RowEncoder

## What changes were proposed in this pull request?

SPARK-15241: We now support Java decimal and Catalyst decimal in external rows, so it makes sense to also support Scala decimal.

SPARK-15242: This is a long-standing bug that was exposed by https://github.com/apache/spark/pull/12364, which eliminates the `If` expression when the field is not nullable:
```
val fieldValue = serializerFor(
  GetExternalRowField(inputObject, i, externalDataTypeForInput(f.dataType)),
  f.dataType)
if (f.nullable) {
  If(
    Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
    Literal.create(null, f.dataType),
    fieldValue)
} else {
  fieldValue
}
```

Previously, we always used `DecimalType.SYSTEM_DEFAULT` as the output type of the converted decimal field, which is wrong because it doesn't match the real decimal type. However, it worked because we always put the converted field into an `If` expression to do the null check, and `If` uses its `trueValue`'s data type as its output type.
Now, if we have a non-nullable decimal field, the converted field's output type will be `DecimalType.SYSTEM_DEFAULT`, and we will write wrong data into the unsafe row.

The fix is simple: just use the given decimal type as the output type of the converted decimal field.

These two issues were found at https://github.com/apache/spark/pull/13008

## How was this patch tested?

new tests in RowEncoderSuite

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13019 from cloud-fan/encoder-decimal.
parent e1576478
No related branches found
No related tags found
No related merge requests found
......@@ -84,10 +84,10 @@ object RowEncoder {
"fromJavaDate",
inputObject :: Nil)
case _: DecimalType =>
case d: DecimalType =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
d,
"fromDecimal",
inputObject :: Nil)
......@@ -162,7 +162,7 @@ object RowEncoder {
* `org.apache.spark.sql.types.Decimal`.
*/
private def externalDataTypeForInput(dt: DataType): DataType = dt match {
// In order to support both Decimal and java BigDecimal in external row, we make this
// In order to support both Decimal and java/scala BigDecimal in external row, we make this
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
case _ => externalDataTypeFor(dt)
......
......@@ -386,6 +386,7 @@ object Decimal {
def fromDecimal(value: Any): Decimal = value match {
  // Already a catalyst Decimal: hand it back unchanged.
  case catalystDecimal: Decimal => catalystDecimal
  // Scala BigDecimal: delegate to the matching `apply` overload.
  case scalaDecimal: BigDecimal => apply(scalaDecimal)
  // Java BigDecimal: delegate to the matching `apply` overload.
  case javaDecimal: java.math.BigDecimal => apply(javaDecimal)
  // There is deliberately no fallback case: any other input (including null)
  // fails the match with a scala.MatchError.
}
......
......@@ -108,7 +108,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
// encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
......@@ -336,6 +336,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case (b1: Array[_], b2: Array[_]) =>
Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
case (left: Comparable[Any], right: Comparable[Any]) => left.compareTo(right) == 0
case _ => input == convertedBack
}
......
......@@ -143,21 +143,38 @@ class RowEncoderSuite extends SparkFunSuite {
assert(input.getStruct(0) == convertedBack.getStruct(0))
}
test("encode/decode Decimal") {
test("encode/decode decimal type") {
val schema = new StructType()
.add("int", IntegerType)
.add("string", StringType)
.add("double", DoubleType)
.add("decimal", DecimalType.SYSTEM_DEFAULT)
.add("java_decimal", DecimalType.SYSTEM_DEFAULT)
.add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
.add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
val encoder = RowEncoder(schema)
val input: Row = Row(100, "test", 0.123, Decimal(1234.5678))
val javaDecimal = new java.math.BigDecimal("1234.5678")
val scalaDecimal = BigDecimal("1234.5678")
val catalystDecimal = Decimal("1234.5678")
val input = Row(100, "test", 0.123, javaDecimal, scalaDecimal, catalystDecimal)
val row = encoder.toRow(input)
val convertedBack = encoder.fromRow(row)
// Decimal inside external row will be converted back to Java BigDecimal when decoding.
assert(input.get(3).asInstanceOf[Decimal].toJavaBigDecimal
.compareTo(convertedBack.getDecimal(3)) == 0)
// Decimal will be converted back to Java BigDecimal when decoding.
assert(convertedBack.getDecimal(3).compareTo(javaDecimal) == 0)
assert(convertedBack.getDecimal(4).compareTo(scalaDecimal.bigDecimal) == 0)
assert(convertedBack.getDecimal(5).compareTo(catalystDecimal.toJavaBigDecimal) == 0)
}
test("RowEncoder should preserve decimal precision and scale") {
  // Single decimal column declared with an explicit precision and scale; the trailing
  // `false` is passed through to StructType.add unchanged from the original test.
  val decimalType = DecimalType(10, 5)
  val schema = new StructType().add("decimal", decimalType, false)
  val expected = Decimal("67123.45")

  // Serialize the external row and read the field back using the same schema.
  val encoded = RowEncoder(schema).toRow(Row(expected))
  val firstField = encoded.toSeq(schema).head

  assert(firstField == expected)
}
test("RowEncoder should preserve schema nullability") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment