From 6e89d574058bc2b96b14a691a07580be67f63707 Mon Sep 17 00:00:00 2001 From: liuxian <liu.xian3@zte.com.cn> Date: Fri, 12 May 2017 11:38:50 +0800 Subject: [PATCH] [SPARK-20665][SQL] Bround" and "Round" function return NULL spark-sql>select bround(12.3, 2); spark-sql>NULL For this case, the expected result is 12.3, but it is null. So ,when the second parameter is bigger than "decimal.scala", the result is not we expected. "round" function has the same problem. This PR can solve the problem for both of them. unit test cases in MathExpressionsSuite and MathFunctionsSuite Author: liuxian <liu.xian3@zte.com.cn> Closes #17906 from 10110346/wip_lx_0509. (cherry picked from commit 2b36eb696f6c738e1328582630755aaac4293460) Signed-off-by: Wenchen Fan <wenchen@databricks.com> --- .../sql/catalyst/expressions/mathExpressions.scala | 12 ++++++------ .../catalyst/expressions/MathExpressionsSuite.scala | 7 +++---- .../org/apache/spark/sql/MathFunctionsSuite.scala | 13 +++++++++++++ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 65273a77b1..54b8457403 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1021,10 +1021,10 @@ abstract class RoundBase(child: Expression, scale: Expression, // not overriding since _scale is a constant int at runtime def nullSafeEval(input1: Any): Any = { - child.dataType match { - case _: DecimalType => + dataType match { + case DecimalType.Fixed(_, s) => val decimal = input1.asInstanceOf[Decimal] - if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null + if (decimal.changePrecision(decimal.precision, s, mode)) decimal else null case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => @@ -1053,10 +1053,10 @@ abstract class RoundBase(child: Expression, scale: Expression, override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val ce = child.genCode(ctx) - val evaluationCode = child.dataType match { - case _: DecimalType => + val evaluationCode = dataType match { + case DecimalType.Fixed(_, s) => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, + if (${ce.value}.changePrecision(${ce.value}.precision(), ${s}, java.math.BigDecimal.${modeStr})) { ${ev.value} = ${ce.value}; } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 6b5bfac946..1555dd1cf5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -546,15 +546,14 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), BigDecimal(3.141593), BigDecimal(3.1415927)) - // round_scale > current_scale would result in precision increase - // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null + (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) checkEvaluation(BRound(bdPi, i), bdResults(i), EmptyRow) } (8 to 10).foreach { scale => - checkEvaluation(Round(bdPi, scale), null, EmptyRow) - checkEvaluation(BRound(bdPi, scale), null, EmptyRow) + checkEvaluation(Round(bdPi, scale), bdPi, EmptyRow) + checkEvaluation(BRound(bdPi, scale), bdPi, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 37443d0342..0284f8311e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -231,6 +231,19 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + + val bdPi: BigDecimal = BigDecimal(31415925L, 7) + checkAnswer( + sql(s"SELECT round($bdPi, 7), round($bdPi, 8), round($bdPi, 9), round($bdPi, 10), " + + s"round($bdPi, 100), round($bdPi, 6), round(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141593"), null)) + ) + + checkAnswer( + sql(s"SELECT bround($bdPi, 7), bround($bdPi, 8), bround($bdPi, 9), bround($bdPi, 10), " + + s"bround($bdPi, 100), bround($bdPi, 6), bround(null, 8)"), + Seq(Row(bdPi, bdPi, bdPi, bdPi, bdPi, BigDecimal("3.141592"), null)) + ) } test("exp") { -- GitLab