Skip to content
Snippets Groups Projects
Commit 4d8c7c6d authored by Cheng Hao's avatar Cheng Hao Committed by Davies Liu
Browse files

[SPARK-10865] [SPARK-10866] [SQL] Fix bug of ceil/floor, which should returns...

[SPARK-10865] [SPARK-10866] [SQL] Fix bug of ceil/floor, which should returns long instead of the Double type

Floor & Ceiling function should returns Long type, rather than Double.

Verified with MySQL & Hive.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #8933 from chenghao-intel/ceiling.
parent 9b3e7768
No related branches found
No related tags found
No related merge requests found
...@@ -52,7 +52,7 @@ abstract class LeafMathExpression(c: Double, name: String) ...@@ -52,7 +52,7 @@ abstract class LeafMathExpression(c: Double, name: String)
* @param f The math function. * @param f The math function.
* @param name The short name of the function * @param name The short name of the function
*/ */
abstract class UnaryMathExpression(f: Double => Double, name: String) abstract class UnaryMathExpression(val f: Double => Double, name: String)
extends UnaryExpression with Serializable with ImplicitCastInputTypes { extends UnaryExpression with Serializable with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(DoubleType) override def inputTypes: Seq[DataType] = Seq(DoubleType)
...@@ -152,7 +152,16 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN" ...@@ -152,7 +152,16 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN"
case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
override def dataType: DataType = LongType
protected override def nullSafeEval(input: Any): Any = {
f(input.asInstanceOf[Double]).toLong
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
...@@ -195,7 +204,16 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") ...@@ -195,7 +204,16 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
override def dataType: DataType = LongType
protected override def nullSafeEval(input: Any): Any = {
f(input.asInstanceOf[Double]).toLong
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
}
}
object Factorial { object Factorial {
......
...@@ -244,12 +244,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { ...@@ -244,12 +244,12 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
} }
test("ceil") { test("ceil") {
testUnary(Ceil, math.ceil) testUnary(Ceil, (d: Double) => math.ceil(d).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType)
} }
test("floor") { test("floor") {
testUnary(Floor, math.floor) testUnary(Floor, (d: Double) => math.floor(d).toLong)
checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType)
} }
......
...@@ -37,9 +37,11 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { ...@@ -37,9 +37,11 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
private lazy val nullDoubles = private lazy val nullDoubles =
Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF()
private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( private def testOneToOneMathFunction[
@specialized(Int, Long, Float, Double) T,
@specialized(Int, Long, Float, Double) U](
c: Column => Column, c: Column => Column,
f: T => T): Unit = { f: T => U): Unit = {
checkAnswer( checkAnswer(
doubleData.select(c('a)), doubleData.select(c('a)),
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T])))
...@@ -165,10 +167,10 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { ...@@ -165,10 +167,10 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
} }
test("ceil and ceiling") { test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil) testOneToOneMathFunction(ceil, (d: Double) => math.ceil(d).toLong)
checkAnswer( checkAnswer(
sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0.0, 1.0, 2.0)) Row(0L, 1L, 2L))
} }
test("conv") { test("conv") {
...@@ -184,7 +186,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { ...@@ -184,7 +186,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
} }
test("floor") { test("floor") {
testOneToOneMathFunction(floor, math.floor) testOneToOneMathFunction(floor, (d: Double) => math.floor(d).toLong)
} }
test("factorial") { test("factorial") {
...@@ -228,7 +230,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { ...@@ -228,7 +230,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
} }
test("signum / sign") { test("signum / sign") {
testOneToOneMathFunction[Double](signum, math.signum) testOneToOneMathFunction[Double, Double](signum, math.signum)
checkAnswer( checkAnswer(
sql("SELECT sign(10), signum(-11)"), sql("SELECT sign(10), signum(-11)"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment