Skip to content
Snippets Groups Projects
Commit 6e89d574 authored by liuxian's avatar liuxian Committed by Wenchen Fan
Browse files

[SPARK-20665][SQL] Bround" and "Round" function return NULL


   spark-sql>select bround(12.3, 2);
   spark-sql>NULL
For this case,  the expected result is 12.3, but it is null.
So ,when the second parameter is bigger than "decimal.scala", the result is not we expected.
"round" function  has the same problem. This PR can solve the problem for both of them.

unit test cases in MathExpressionsSuite and MathFunctionsSuite

Author: liuxian <liu.xian3@zte.com.cn>

Closes #17906 from 10110346/wip_lx_0509.

(cherry picked from commit 2b36eb69)
Signed-off-by: default avatarWenchen Fan <wenchen@databricks.com>
parent 92a71a66
No related branches found
Tags v2.1.1
No related merge requests found
...@@ -1021,10 +1021,10 @@ abstract class RoundBase(child: Expression, scale: Expression, ...@@ -1021,10 +1021,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
// not overriding since _scale is a constant int at runtime // not overriding since _scale is a constant int at runtime
def nullSafeEval(input1: Any): Any = { def nullSafeEval(input1: Any): Any = {
child.dataType match { dataType match {
case _: DecimalType => case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal] val decimal = input1.asInstanceOf[Decimal]
if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null if (decimal.changePrecision(decimal.precision, s, mode)) decimal else null
case ByteType => case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType => case ShortType =>
...@@ -1053,10 +1053,10 @@ abstract class RoundBase(child: Expression, scale: Expression, ...@@ -1053,10 +1053,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val ce = child.genCode(ctx) val ce = child.genCode(ctx)
val evaluationCode = child.dataType match { val evaluationCode = dataType match {
case _: DecimalType => case DecimalType.Fixed(_, s) =>
s""" s"""
if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
java.math.BigDecimal.${modeStr})) { java.math.BigDecimal.${modeStr})) {
${ev.value} = ${ce.value}; ${ev.value} = ${ce.value};
} else { } else {
......
...@@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
BigDecimal(3.141593), BigDecimal(3.1415927)) BigDecimal(3.141593), BigDecimal(3.1415927))
// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
(0 to 7).foreach { i => (0 to 7).foreach { i =>
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
} }
(8 to 10).foreach { scale => (8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow) checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow)
checkEvaluation(BRound(bdPi, scale), null, EmptyRow) checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow)
} }
DataTypeTestUtils.numericTypes.foreach { dataType => DataTypeTestUtils.numericTypes.foreach { dataType =>
......
...@@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ...@@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
) )
val bdPi: BigDecimal = BigDecimal(31415925L, 7)
checkAnswer(
sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " +
s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"),
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null))
)
checkAnswer(
sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " +
s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"),
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
)
} }
test("exp") { test("exp") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment