diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a7bf81e98be8e315a582c307539664c6334fcc03..bf46a398621316a3d9c28c0e3be07041a9730b97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -232,9 +232,10 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } @@ -347,9 +348,10 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 31047f688600be314193f87310306fe4dfb1ad32..0896caeab8d7a0df9ee087ca96e5b5f7e85cb517 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val plan = testRelation2.select('c).orderBy(Floor('a).asc) val expected = testRelation2.select(c, a) - .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c) + .orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c) checkAnalysis(plan, expected) } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 1920a108c658419dbf5b7908b569f2a44a6110d1..f7167472b05c67af2016036db95ab07db2c87c6a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -59,3 +59,17 @@ select cot(1); select cot(null); select cot(0); select cot(-1); + +-- ceil and ceiling +select ceiling(0); +select ceiling(1); +select ceil(1234567890123456); +select ceil(12345678901234567); +select ceiling(1234567890123456); +select ceiling(12345678901234567); + +-- floor +select floor(0); +select floor(1); +select floor(1234567890123456); +select floor(12345678901234567); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index abd18211c70d8a025f4090c35849c603c3cfe089..fe52005aa91da62d1204c05689d161d92d0b7df3 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -316,3 +316,83 @@ select cot(-1) struct<COT(CAST(-1 AS DOUBLE)):double> -- !query 37 output -0.6420926159343306 + + +-- !query 38 +select ceiling(0) +-- !query 38 schema +struct<CEIL(CAST(0 AS BIGINT)):bigint> +-- !query 38 output +0 + + +-- !query 39 +select ceiling(1) +-- !query 39 schema +struct<CEIL(CAST(1 AS BIGINT)):bigint> +-- !query 39 output +1 + + +-- !query 40 +select ceil(1234567890123456) +-- !query 40 schema +struct<CEIL(1234567890123456):bigint> +-- !query 40 output +1234567890123456 + + +-- !query 41 +select ceil(12345678901234567) +-- !query 41 schema +struct<CEIL(12345678901234567):bigint> +-- !query 41 output +12345678901234567 + + +-- !query 42 +select ceiling(1234567890123456) +-- !query 42 schema +struct<CEIL(1234567890123456):bigint> +-- !query 42 output +1234567890123456 + + +-- !query 43 +select ceiling(12345678901234567) +-- !query 43 schema +struct<CEIL(12345678901234567):bigint> +-- !query 43 output +12345678901234567 + + +-- !query 44 +select floor(0) +-- !query 44 schema +struct<FLOOR(CAST(0 AS BIGINT)):bigint> +-- !query 44 output +0 + + +-- !query 45 +select floor(1) +-- !query 45 schema +struct<FLOOR(CAST(1 AS BIGINT)):bigint> +-- !query 45 output +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct<FLOOR(1234567890123456):bigint> +-- !query 46 output +1234567890123456 + + +-- !query 47 +select floor(12345678901234567) +-- !query 47 schema +struct<FLOOR(12345678901234567):bigint> +-- !query 47 output +12345678901234567