From 7467b52ed07f174d93dfc4cb544dc4b69a2c2826 Mon Sep 17 00:00:00 2001 From: Davies Liu <davies@databricks.com> Date: Tue, 25 Aug 2015 15:19:41 -0700 Subject: [PATCH] [SPARK-10215] [SQL] Fix precision of division (follow the rule in Hive) Follow the rule in Hive for decimal division. see https://github.com/apache/hive/blob/ac755ebe26361a4647d53db2a28500f71697b276/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFOPDivide.java#L113 cc chenghao-intel Author: Davies Liu <davies@databricks.com> Closes #8415 from davies/decimal_div2. --- .../catalyst/analysis/HiveTypeCoercion.scala | 10 ++++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 9 +++---- .../analysis/DecimalPrecisionSuite.scala | 8 +++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 25 +++++++++++++++++-- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a1aa2a2b2c..87c11abbad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -396,8 +396,14 @@ object HiveTypeCoercion { resultType) case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => - val resultType = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), - max(6, s1 + p2 + 1)) + var intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2) + var decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1)) + val diff = (intDig + decDig) - DecimalType.MAX_SCALE + if (diff > 0) { + decDig -= diff / 2 + 1 + intDig = DecimalType.MAX_SCALE - decDig + } + val resultType = DecimalType.bounded(intDig + decDig, decDig) val widerType = widerDecimalType(p1, s1, p2, s2) CheckOverflow(Divide(promotePrecision(e1, widerType), promotePrecision(e2, widerType)), resultType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1e0cc81dae..820b336aac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.SimpleCatalystConf -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends AnalysisTest { - import TestRelations._ + import org.apache.spark.sql.catalyst.analysis.TestRelations._ test("union project *") { val plan = (1 to 100) @@ -96,7 +95,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 29)) + assert(pl(3).dataType == DecimalType(38, 22)) assert(pl(4).dataType == DoubleType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fc11627da6..b4ad618c23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -136,10 +136,10 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Multiply(i, u), DecimalType(38, 18)) checkType(Multiply(u, u), DecimalType(38, 36)) - checkType(Divide(u, d1), DecimalType(38, 21)) - checkType(Divide(u, d2), DecimalType(38, 24)) - checkType(Divide(u, i), DecimalType(38, 29)) - checkType(Divide(u, u), DecimalType(38, 38)) + checkType(Divide(u, d1), DecimalType(38, 18)) + checkType(Divide(u, d2), DecimalType(38, 19)) + checkType(Divide(u, i), DecimalType(38, 23)) + checkType(Divide(u, u), DecimalType(38, 18)) checkType(Remainder(d1, u), DecimalType(19, 18)) checkType(Remainder(d2, u), DecimalType(21, 18)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index aa07665c6b..9e172b2c26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1622,9 +1622,30 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333"))) checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333"))) checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"), - Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38)))) + Row(BigDecimal("3.433333333333333333333333333", new MathContext(38)))) checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"), - Row(null)) + Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38)))) + } + + test("SPARK-10215 Div of Decimal returns null") { + val d = Decimal(1.12321) + val df = Seq((d, 1)).toDF("a", "b") + + checkAnswer( + df.selectExpr("b * a / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a / b / b"), + Seq(Row(d.toBigDecimal))) + checkAnswer( + df.selectExpr("b * a + b"), + Seq(Row(BigDecimal(2.12321)))) + checkAnswer( + df.selectExpr("b * a - b"), + Seq(Row(BigDecimal(0.12321)))) + checkAnswer( + df.selectExpr("b * a * b"), + Seq(Row(d.toBigDecimal))) } test("precision smaller than scale") { -- GitLab