Skip to content
Snippets Groups Projects
Commit 3969a807 authored by liuxian's avatar liuxian Committed by Xiao Li
Browse files

[SPARK-20876][SQL] If the input parameter is float type for ceil or floor,the...

[SPARK-20876][SQL] If the input parameter is float type for ceil or floor,the result is not we expected

## What changes were proposed in this pull request?

spark-sql>SELECT ceil(cast(12345.1233 as float));
spark-sql>12345
For this case, the result we expected is `12346`
spark-sql>SELECT floor(cast(-12345.1233 as float));
spark-sql>-12345
For this case, the result we expected is `-12346`

Because in `Ceil` or `Floor`, `inputTypes` has no FloatType, so it is converted to LongType.
## How was this patch tested?

After the modification:
spark-sql>SELECT ceil(cast(12345.1233 as float));
spark-sql>12346
spark-sql>SELECT floor(cast(-12345.1233 as float));
spark-sql>-12346

Author: liuxian <liu.xian3@zte.com.cn>

Closes #18103 from 10110346/wip-lx-0525-1.
parent 08ede46b
No related branches found
No related tags found
No related merge requests found
...@@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" ...@@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
} }
override def inputTypes: Seq[AbstractDataType] = override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType)) Seq(TypeCollection(DoubleType, DecimalType, LongType))
protected override def nullSafeEval(input: Any): Any = child.dataType match { protected override def nullSafeEval(input: Any): Any = child.dataType match {
case LongType => input.asInstanceOf[Long] case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong case DoubleType => f(input.asInstanceOf[Double]).toLong
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
} }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.dataType match { child.dataType match {
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
case DecimalType.Fixed(precision, scale) => case DecimalType.Fixed(_, _) =>
defineCodeGen(ctx, ev, c => s"$c.ceil()") defineCodeGen(ctx, ev, c => s"$c.ceil()")
case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
} }
} }
...@@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO ...@@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO
} }
override def inputTypes: Seq[AbstractDataType] = override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType)) Seq(TypeCollection(DoubleType, DecimalType, LongType))
protected override def nullSafeEval(input: Any): Any = child.dataType match { protected override def nullSafeEval(input: Any): Any = child.dataType match {
case LongType => input.asInstanceOf[Long] case LongType => input.asInstanceOf[Long]
case DoubleType => f(input.asInstanceOf[Double]).toLong case DoubleType => f(input.asInstanceOf[Double]).toLong
case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
} }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
child.dataType match { child.dataType match {
case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
case DecimalType.Fixed(precision, scale) => case DecimalType.Fixed(_, _) =>
defineCodeGen(ctx, ev, c => s"$c.floor()") defineCodeGen(ctx, ev, c => s"$c.floor()")
case LongType => defineCodeGen(ctx, ev, c => s"$c")
case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
} }
} }
......
...@@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { ...@@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
val plan = testRelation2.select('c).orderBy(Floor('a).asc) val plan = testRelation2.select('c).orderBy(Floor('a).asc)
val expected = testRelation2.select(c, a) val expected = testRelation2.select(c, a)
.orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c) .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c)
checkAnalysis(plan, expected) checkAnalysis(plan, expected)
} }
......
...@@ -258,6 +258,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -258,6 +258,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3))
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0))
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0))
val doublePi: Double = 3.1415
val floatPi: Float = 3.1415f
val longLit: Long = 12345678901234567L
checkEvaluation(Ceil(doublePi), 4L, EmptyRow)
checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow)
checkEvaluation(Ceil(longLit), longLit, EmptyRow)
checkEvaluation(Ceil(-doublePi), -3L, EmptyRow)
checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow)
checkEvaluation(Ceil(-longLit), -longLit, EmptyRow)
} }
test("floor") { test("floor") {
...@@ -268,6 +278,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -268,6 +278,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3))
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0))
checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0))
val doublePi: Double = 3.1415
val floatPi: Float = 3.1415f
val longLit: Long = 12345678901234567L
checkEvaluation(Floor(doublePi), 3L, EmptyRow)
checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow)
checkEvaluation(Floor(longLit), longLit, EmptyRow)
checkEvaluation(Floor(-doublePi), -4L, EmptyRow)
checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow)
checkEvaluation(Floor(-longLit), -longLit, EmptyRow)
} }
test("factorial") { test("factorial") {
......
...@@ -64,12 +64,9 @@ select cot(-1); ...@@ -64,12 +64,9 @@ select cot(-1);
select ceiling(0); select ceiling(0);
select ceiling(1); select ceiling(1);
select ceil(1234567890123456); select ceil(1234567890123456);
select ceil(12345678901234567);
select ceiling(1234567890123456); select ceiling(1234567890123456);
select ceiling(12345678901234567);
-- floor -- floor
select floor(0); select floor(0);
select floor(1); select floor(1);
select floor(1234567890123456); select floor(1234567890123456);
select floor(12345678901234567);
-- Automatically generated by SQLQueryTestSuite -- Automatically generated by SQLQueryTestSuite
-- Number of queries: 38 -- Number of queries: 45
-- !query 0 -- !query 0
...@@ -321,7 +321,7 @@ struct<COT(CAST(-1 AS DOUBLE)):double> ...@@ -321,7 +321,7 @@ struct<COT(CAST(-1 AS DOUBLE)):double>
-- !query 38 -- !query 38
select ceiling(0) select ceiling(0)
-- !query 38 schema -- !query 38 schema
struct<CEIL(CAST(0 AS BIGINT)):bigint> struct<CEIL(CAST(0 AS DOUBLE)):bigint>
-- !query 38 output -- !query 38 output
0 0
...@@ -329,7 +329,7 @@ struct<CEIL(CAST(0 AS BIGINT)):bigint> ...@@ -329,7 +329,7 @@ struct<CEIL(CAST(0 AS BIGINT)):bigint>
-- !query 39 -- !query 39
select ceiling(1) select ceiling(1)
-- !query 39 schema -- !query 39 schema
struct<CEIL(CAST(1 AS BIGINT)):bigint> struct<CEIL(CAST(1 AS DOUBLE)):bigint>
-- !query 39 output -- !query 39 output
1 1
...@@ -343,56 +343,32 @@ struct<CEIL(1234567890123456):bigint> ...@@ -343,56 +343,32 @@ struct<CEIL(1234567890123456):bigint>
-- !query 41 -- !query 41
select ceil(12345678901234567) select ceiling(1234567890123456)
-- !query 41 schema -- !query 41 schema
struct<CEIL(12345678901234567):bigint> struct<CEIL(1234567890123456):bigint>
-- !query 41 output -- !query 41 output
12345678901234567 1234567890123456
-- !query 42 -- !query 42
select ceiling(1234567890123456) select floor(0)
-- !query 42 schema -- !query 42 schema
struct<CEIL(1234567890123456):bigint> struct<FLOOR(CAST(0 AS DOUBLE)):bigint>
-- !query 42 output -- !query 42 output
1234567890123456 0
-- !query 43 -- !query 43
select ceiling(12345678901234567) select floor(1)
-- !query 43 schema -- !query 43 schema
struct<CEIL(12345678901234567):bigint> struct<FLOOR(CAST(1 AS DOUBLE)):bigint>
-- !query 43 output -- !query 43 output
12345678901234567
-- !query 44
select floor(0)
-- !query 44 schema
struct<FLOOR(CAST(0 AS BIGINT)):bigint>
-- !query 44 output
0
-- !query 45
select floor(1)
-- !query 45 schema
struct<FLOOR(CAST(1 AS BIGINT)):bigint>
-- !query 45 output
1 1
-- !query 46 -- !query 44
select floor(1234567890123456) select floor(1234567890123456)
-- !query 46 schema -- !query 44 schema
struct<FLOOR(1234567890123456):bigint> struct<FLOOR(1234567890123456):bigint>
-- !query 46 output -- !query 44 output
1234567890123456 1234567890123456
-- !query 47
select floor(12345678901234567)
-- !query 47 schema
struct<FLOOR(12345678901234567):bigint>
-- !query 47 output
12345678901234567
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment