Skip to content
Snippets Groups Projects
Commit db81b9d8 authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin
Browse files

[SPARK-7952][SQL] use internal Decimal instead of java.math.BigDecimal

This PR fixes a bug introduced in https://github.com/apache/spark/pull/6505.
Decimal literal's value is not `java.math.BigDecimal`, but Spark SQL internal type: `Decimal`.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6574 from cloud-fan/fix and squashes the following commits:

b0e3549 [Wenchen Fan] rename to BooleanEquality
1987b37 [Wenchen Fan] use Decimal instead of java.math.BigDecimal
f93c420 [Wenchen Fan] compare literal
parent d6d601a0
No related branches found
No related tags found
No related merge requests found
...@@ -87,7 +87,7 @@ trait HiveTypeCoercion { ...@@ -87,7 +87,7 @@ trait HiveTypeCoercion {
WidenTypes :: WidenTypes ::
PromoteStrings :: PromoteStrings ::
DecimalPrecision :: DecimalPrecision ::
BooleanEqualization :: BooleanEquality ::
StringToIntegralCasts :: StringToIntegralCasts ::
FunctionArgumentConversion :: FunctionArgumentConversion ::
CaseWhenCoercion :: CaseWhenCoercion ::
...@@ -479,9 +479,9 @@ trait HiveTypeCoercion { ...@@ -479,9 +479,9 @@ trait HiveTypeCoercion {
/** /**
* Changes numeric values to booleans so that expressions like true = 1 can be evaluated. * Changes numeric values to booleans so that expressions like true = 1 can be evaluated.
*/ */
object BooleanEqualization extends Rule[LogicalPlan] { object BooleanEquality extends Rule[LogicalPlan] {
private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, new java.math.BigDecimal(1)) private val trueValues = Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1))
private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, new java.math.BigDecimal(0)) private val falseValues = Seq(0.toByte, 0.toShort, 0, 0L, Decimal(0))
private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = { private def buildCaseKeyWhen(booleanExpr: Expression, numericExpr: Expression) = {
CaseKeyWhen(numericExpr, Seq( CaseKeyWhen(numericExpr, Seq(
...@@ -512,22 +512,22 @@ trait HiveTypeCoercion { ...@@ -512,22 +512,22 @@ trait HiveTypeCoercion {
// all other cases are considered as false. // all other cases are considered as false.
// We may simplify the expression if one side is literal numeric values // We may simplify the expression if one side is literal numeric values
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => left if trueValues.contains(value) => bool
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType)) case EqualTo(bool @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(left) if falseValues.contains(value) => Not(bool)
case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
if trueValues.contains(value) => right if trueValues.contains(value) => bool
case EqualTo(Literal(value, _: NumericType), right @ BooleanType()) case EqualTo(Literal(value, _: NumericType), bool @ BooleanType())
if falseValues.contains(value) => Not(right) if falseValues.contains(value) => Not(bool)
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(left), left) if trueValues.contains(value) => And(IsNotNull(bool), bool)
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType)) case EqualNullSafe(bool @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(left), Not(left)) if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(right), right) if trueValues.contains(value) => And(IsNotNull(bool), bool)
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType()) case EqualNullSafe(Literal(value, _: NumericType), bool @ BooleanType())
if falseValues.contains(value) => And(IsNotNull(right), Not(right)) if falseValues.contains(value) => And(IsNotNull(bool), Not(bool))
case EqualTo(left @ BooleanType(), right @ NumericType()) => case EqualTo(left @ BooleanType(), right @ NumericType()) =>
transform(left , right) transform(left , right)
......
...@@ -147,7 +147,8 @@ class HiveTypeCoercionSuite extends PlanTest { ...@@ -147,7 +147,8 @@ class HiveTypeCoercionSuite extends PlanTest {
} }
test("type coercion simplification for equal to") { test("type coercion simplification for equal to") {
val be = new HiveTypeCoercion {}.BooleanEqualization val be = new HiveTypeCoercion {}.BooleanEquality
ruleTest(be, ruleTest(be,
EqualTo(Literal(true), Literal(1)), EqualTo(Literal(true), Literal(1)),
Literal(true) Literal(true)
...@@ -164,5 +165,26 @@ class HiveTypeCoercionSuite extends PlanTest { ...@@ -164,5 +165,26 @@ class HiveTypeCoercionSuite extends PlanTest {
EqualNullSafe(Literal(true), Literal(0)), EqualNullSafe(Literal(true), Literal(0)),
And(IsNotNull(Literal(true)), Not(Literal(true))) And(IsNotNull(Literal(true)), Not(Literal(true)))
) )
ruleTest(be,
EqualTo(Literal(true), Literal(1L)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal(new java.math.BigDecimal(1)), Literal(true)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal(BigDecimal(0)), Literal(true)),
Not(Literal(true))
)
ruleTest(be,
EqualTo(Literal(Decimal(1)), Literal(true)),
Literal(true)
)
ruleTest(be,
EqualTo(Literal.create(Decimal(1), DecimalType(8, 0)), Literal(true)),
Literal(true)
)
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment