diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2624f5586fd5d8e3b63f13513514db06dc4d4c05..2ac11598e63d1cd32dda2294a7a707aa210cd96b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest { } test("coalesce casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1.0) - :: Literal(1) - :: Literal.create(1.0, FloatType) - :: Nil), - Coalesce(Cast(Literal(1.0), DoubleType) - :: Cast(Literal(1), DoubleType) - :: Cast(Literal.create(1.0, FloatType), DoubleType) - :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1L) - :: Literal(1) - :: Literal(new java.math.BigDecimal("1000000000000000000000")) - :: Nil), - Coalesce(Cast(Literal(1L), DecimalType(22, 0)) - :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) - :: Nil)) + val rule = TypeCoercion.FunctionArgumentConversion + + val intLit = Literal(1) + val longLit = Literal.create(1L) + val doubleLit = Literal(1.0) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + + ruleTest(rule, + Coalesce(Seq(doubleLit, intLit, floatLit)), + Coalesce(Seq(Cast(doubleLit, DoubleType), + Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(longLit, intLit, decimalLit)), + Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), + Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit)), + Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, intLit)), + Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), + Cast(intLit, FloatType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), + Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), + Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), + Cast(doubleLit, StringType), Cast(stringLit, StringType)))) } test("CreateArray casts") { @@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = TypeCoercion.IfCoercion + val intLit = Literal(1) + val doubleLit = Literal(1.0) + val trueLit = Literal.create(true, BooleanType) + val falseLit = Literal.create(false, BooleanType) + val stringLit = Literal.create("c", StringType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), @@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest { If(Literal.create(null, BooleanType), Literal(1), Literal(1))) ruleTest(rule, - If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(AssertTrue(trueLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(falseLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(trueLit, intLit, doubleLit), + If(trueLit, Cast(intLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, doubleLit), + If(trueLit, Cast(floatLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, decimalLit), + If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType))) + + ruleTest(rule, + If(falseLit, stringLit, doubleLit), + If(falseLit, stringLit, Cast(doubleLit, StringType))) ruleTest(rule, - If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(trueLit, timestampLit, stringLit), + If(trueLit, Cast(timestampLit, StringType), stringLit)) } test("type coercion for CaseKeyWhen") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 5064a1f63f83dc26c141f2168259762b7639244d..394c0a091e3902859cd5e0038c55c6ddc6b51cef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -97,14 +97,30 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doubleLit = Literal.create(2.2, DoubleType) val stringLit = Literal.create("c", StringType) val nullLit = Literal.create(null, NullType) - + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.01f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal.create(10.2, DecimalType(20, 2)) + + assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType) + + assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) + + assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) } test("AtLeastNNonNulls") {