Skip to content
Snippets Groups Projects
Commit 2b36eb69 authored by liuxian's avatar liuxian Committed by Wenchen Fan
Browse files

[SPARK-20665][SQL] Bround" and "Round" function return NULL

## What changes were proposed in this pull request?
   spark-sql>select bround(12.3, 2);
   spark-sql>NULL
For this case,  the expected result is 12.3, but it is null.
So ,when the second parameter is bigger than "decimal.scala", the result is not we expected.
"round" function  has the same problem. This PR can solve the problem for both of them.

## How was this patch tested?
unit test cases in MathExpressionsSuite and MathFunctionsSuite

Author: liuxian <liu.xian3@zte.com.cn>

Closes #17906 from 10110346/wip_lx_0509.
parent 609ba5f2
No related branches found
No related tags found
No related merge requests found
......@@ -1023,10 +1023,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
// not overriding since _scale is a constant int at runtime
def nullSafeEval(input1: Any): Any = {
child.dataType match {
case _: DecimalType =>
dataType match {
case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal]
decimal.toPrecision(decimal.precision, _scale, mode).orNull
decimal.toPrecision(decimal.precision, s, mode).orNull
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
......@@ -1055,10 +1055,10 @@ abstract class RoundBase(child: Expression, scale: Expression,
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val ce = child.genCode(ctx)
val evaluationCode = child.dataType match {
case _: DecimalType =>
val evaluationCode = dataType match {
case DecimalType.Fixed(_, s) =>
s"""
if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale},
if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
java.math.BigDecimal.${modeStr})) {
${ev.value} = ${ce.value};
} else {
......
......@@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14),
BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159),
BigDecimal(3.141593), BigDecimal(3.1415927))
// round_scale > current_scale would result in precision increase
// and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null
(0 to 7).foreach { i =>
checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow)
checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow)
}
(8 to 10).foreach { scale =>
checkEvaluation(Round(bdPi, scale), null, EmptyRow)
checkEvaluation(BRound(bdPi, scale), null, EmptyRow)
checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow)
checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow)
}
DataTypeTestUtils.numericTypes.foreach { dataType =>
......
......@@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
)
val bdPi: BigDecimal = BigDecimal(31415925L, 7)
checkAnswer(
sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " +
s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"),
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null))
)
checkAnswer(
sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " +
s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"),
Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null))
)
}
test("round/bround with data frame from a local Seq of Product") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment