Fix expression edge cases in Catalyst (type coercion, codegen, date arithmetic).

- HiveTypeCoercion: when NaNvl has one DoubleType and one FloatType operand,
  wrap the FloatType operand in Cast(_, DoubleType).
- UnaryMinus codegen (NumericType): assign the operand to a typed local before
  negating it; per the in-code comment, emitting the negation directly can
  produce uncompilable source such as --9223372036854775808L.
- BinaryComparison codegen: keep BooleanType off the primitive fast path
  (per the in-code comment, Java boolean has no > or < operator).
- DateTimeUtils.dateAddMonths: clamp a negative absolute month to 0 before
  deriving the year and the month index.
- Tests: cover each of the above; PredicateSuite builds the comparison
  expressions explicitly and names its tests after the expression classes.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 603afc4032a374feb10747a179461cdeed2d200b..422d423747026c13468c731716f0fc0986c46a97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -562,6 +562,11 @@ object HiveTypeCoercion {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None => c
         }
+
+      case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
+        NaNvl(l, Cast(r, DoubleType))
+      case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
+        NaNvl(Cast(l, DoubleType), r)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 6f8f4dd230f121b31a52df493046c1dd395e64c0..0891b554947104c009f462d01f30c1ed83902a2c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,14 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
     case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
-    case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
+    case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
+      val originValue = ctx.freshName("origin")
+      // codegen would fail to compile if we just write (-($c))
+      // for example, we could not write --9223372036854775808L in code
+      s"""
+        ${ctx.javaType(dt)} $originValue = (${ctx.javaType(dt)})($eval);
+        ${ev.primitive} = (${ctx.javaType(dt)})(-($originValue));
+      """})
     case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()")
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ab7d3afce8f2ec4a211012da75a8aef3eae10ec2..b69bbabee7e810d212097b4f7c356ee8bb597ba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -227,6 +227,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     if (ctx.isPrimitiveType(left.dataType)
+        && left.dataType != BooleanType // java boolean doesn't support > or < operator
         && left.dataType != FloatType
         && left.dataType != DoubleType) {
       // faster version
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 6a98f4d9c54bc59c9bc1ba510d3946e3cd5eeea9..f645eb5f7bb01b5ae446a02b6fb215a55453a106 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -614,8 +614,9 @@ object DateTimeUtils {
    */
   def dateAddMonths(days: Int, months: Int): Int = {
     val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
-    val currentMonthInYear = absoluteMonth % 12
-    val currentYear = absoluteMonth / 12
+    val nonNegativeMonth = if (absoluteMonth >= 0) absoluteMonth else 0
+    val currentMonthInYear = nonNegativeMonth % 12
+    val currentYear = nonNegativeMonth / 12
     val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
     val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
 
@@ -626,7 +627,7 @@ object DateTimeUtils {
     } else {
       dayOfMonth
     }
-    firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
+    firstDayOfMonth(nonNegativeMonth) + currentDayInMonth - 1
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 70608771dd1101f7bd5c97a095d3b588dc48199c..cbdf453f600abb4f278c3dbafe25c1dba87265c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -251,6 +251,18 @@ class HiveTypeCoercionSuite extends PlanTest {
         :: Nil))
   }
 
+  test("nanvl casts") {
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),
+      NaNvl(Cast(Literal.create(1.0, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, FloatType)),
+      NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0, FloatType), DoubleType)))
+    ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
+      NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
+  }
+
   test("type coercion for If") {
     val rule = HiveTypeCoercion.IfCoercion
     ruleTest(rule,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index d03b0fbbfb2b2ef09a24aca74eaa20538b6eee0a..0bae8fe2fd8aa5343f3b3464db830f4f80b6cbdc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
 
@@ -56,6 +56,10 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       checkEvaluation(UnaryMinus(input), convert(-1))
       checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null)
     }
+    checkEvaluation(UnaryMinus(Literal(Long.MinValue)), Long.MinValue)
+    checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
+    checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
+    checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
   }
 
   test("- (Minus)") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 3bff8e012a76399729c51c8eeb372b589df10d6d..e6e8790e90926c3519bb1905c7c7ed1bff633061 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -280,6 +280,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null)
     checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)),
       null)
+    checkEvaluation(
+      AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(Int.MinValue)), -7293498)
   }
 
   test("months_between") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 0bc2812a5dc83e094f7ad364a51c05bfc63025e6..d7eb13c50b134691b226d68e8f7bd8ed5cd8bcf0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,60 +136,60 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
   private val largeValues =
-    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
 
   private val equalValues1 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
   private val equalValues2 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
 
-  test("BinaryComparison: <") {
+  test("BinaryComparison: LessThan") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) < largeValues(i), true)
-      checkEvaluation(equalValues1(i) < equalValues2(i), false)
-      checkEvaluation(largeValues(i) < smallValues(i), false)
+      checkEvaluation(LessThan(smallValues(i), largeValues(i)), true)
+      checkEvaluation(LessThan(equalValues1(i), equalValues2(i)), false)
+      checkEvaluation(LessThan(largeValues(i), smallValues(i)), false)
     }
   }
 
-  test("BinaryComparison: <=") {
+  test("BinaryComparison: LessThanOrEqual") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) <= largeValues(i), true)
-      checkEvaluation(equalValues1(i) <= equalValues2(i), true)
-      checkEvaluation(largeValues(i) <= smallValues(i), false)
+      checkEvaluation(LessThanOrEqual(smallValues(i), largeValues(i)), true)
+      checkEvaluation(LessThanOrEqual(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(LessThanOrEqual(largeValues(i), smallValues(i)), false)
     }
   }
 
-  test("BinaryComparison: >") {
+  test("BinaryComparison: GreaterThan") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) > largeValues(i), false)
-      checkEvaluation(equalValues1(i) > equalValues2(i), false)
-      checkEvaluation(largeValues(i) > smallValues(i), true)
+      checkEvaluation(GreaterThan(smallValues(i), largeValues(i)), false)
+      checkEvaluation(GreaterThan(equalValues1(i), equalValues2(i)), false)
+      checkEvaluation(GreaterThan(largeValues(i), smallValues(i)), true)
     }
   }
 
-  test("BinaryComparison: >=") {
+  test("BinaryComparison: GreaterThanOrEqual") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) >= largeValues(i), false)
-      checkEvaluation(equalValues1(i) >= equalValues2(i), true)
-      checkEvaluation(largeValues(i) >= smallValues(i), true)
+      checkEvaluation(GreaterThanOrEqual(smallValues(i), largeValues(i)), false)
+      checkEvaluation(GreaterThanOrEqual(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(GreaterThanOrEqual(largeValues(i), smallValues(i)), true)
     }
   }
 
-  test("BinaryComparison: ===") {
+  test("BinaryComparison: EqualTo") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) === largeValues(i), false)
-      checkEvaluation(equalValues1(i) === equalValues2(i), true)
-      checkEvaluation(largeValues(i) === smallValues(i), false)
+      checkEvaluation(EqualTo(smallValues(i), largeValues(i)), false)
+      checkEvaluation(EqualTo(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(EqualTo(largeValues(i), smallValues(i)), false)
     }
   }
 
-  test("BinaryComparison: <=>") {
+  test("BinaryComparison: EqualNullSafe") {
     for (i <- 0 until smallValues.length) {
-      checkEvaluation(smallValues(i) <=> largeValues(i), false)
-      checkEvaluation(equalValues1(i) <=> equalValues2(i), true)
-      checkEvaluation(largeValues(i) <=> smallValues(i), false)
+      checkEvaluation(EqualNullSafe(smallValues(i), largeValues(i)), false)
+      checkEvaluation(EqualNullSafe(equalValues1(i), equalValues2(i)), true)
+      checkEvaluation(EqualNullSafe(largeValues(i), smallValues(i)), false)
     }
   }
 
@@ -209,8 +209,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     nullTest(GreaterThanOrEqual)
     nullTest(EqualTo)
 
-    checkEvaluation(normalInt <=> nullInt, false)
-    checkEvaluation(nullInt <=> normalInt, false)
-    checkEvaluation(nullInt <=> nullInt, true)
+    checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
+    checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
+    checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index eb64684ae0fd96c5d374b8c4a0ce13372116c157..35ca0b4c7cc214df42f8e80ef222e008cc1636f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -227,20 +227,24 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
 
   test("nanvl") {
     val testData = ctx.createDataFrame(ctx.sparkContext.parallelize(
-      Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil),
+      Row(null, 3.0, Double.NaN, Double.PositiveInfinity, 1.0f, 4) :: Nil),
       StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType),
-        StructField("c", DoubleType), StructField("d", DoubleType))))
+        StructField("c", DoubleType), StructField("d", DoubleType),
+        StructField("e", FloatType), StructField("f", IntegerType))))
 
     checkAnswer(
       testData.select(
-        nanvl($"a", lit(5)), nanvl($"b", lit(10)),
-        nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))),
-      Row(null, 3.0, null, Double.PositiveInfinity)
+        nanvl($"a", lit(5)), nanvl($"b", lit(10)), nanvl(lit(10), $"b"),
+        nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10)),
+        nanvl($"b", $"e"), nanvl($"e", $"f")),
+      Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
     )
     testData.registerTempTable("t")
     checkAnswer(
-      ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"),
-      Row(null, 3.0, null, Double.PositiveInfinity)
+      ctx.sql(
+        "select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
+        " nanvl(b, e), nanvl(e, f) from t"),
+      Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
     )
   }
 