From 9808052b5adfed7dafd6c1b3971b998e45b2799a Mon Sep 17 00:00:00 2001 From: Cheng Hao <hao.cheng@intel.com> Date: Wed, 14 Oct 2015 20:56:08 -0700 Subject: [PATCH] [SPARK-11076] [SQL] Add decimal support for floor and ceil Actually, none of the `UnaryMathExpression` subclasses supports Decimal; follow-ups will be created for supporting it. This is the first PR, so it is a good place to review the approach I am taking. Author: Cheng Hao <hao.cheng@intel.com> Closes #9086 from chenghao-intel/ceiling. --- .../expressions/mathExpressions.scala | 48 +++++++++++++++---- .../org/apache/spark/sql/types/Decimal.scala | 32 +++++++++++-- .../expressions/LiteralGenerator.scala | 14 +++++- .../expressions/MathFunctionsSuite.scala | 10 ++++ 4 files changed, 91 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a8164e9e29..28f616fbb9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -55,7 +55,7 @@ abstract class LeafMathExpression(c: Double, name: String) abstract class UnaryMathExpression(val f: Double => Double, name: String) extends UnaryExpression with Serializable with ImplicitCastInputTypes { - override def inputTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -153,13 +153,28 @@ case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN" case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { - override def dataType: DataType = 
LongType - protected override def nullSafeEval(input: Any): Any = { - f(input.asInstanceOf[Double]).toLong + override def dataType: DataType = child.dataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType)) + + protected override def nullSafeEval(input: Any): Any = child.dataType match { + case DoubleType => f(input.asInstanceOf[Double]).toLong + case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + child.dataType match { + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(precision, scale) => + defineCodeGen(ctx, ev, c => s"$c.ceil()") + case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } } } @@ -205,13 +220,28 @@ case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { - override def dataType: DataType = LongType - protected override def nullSafeEval(input: Any): Any = { - f(input.asInstanceOf[Double]).toLong + override def dataType: DataType = child.dataType match { + case dt @ DecimalType.Fixed(_, 0) => dt + case DecimalType.Fixed(precision, scale) => + DecimalType.bounded(precision - scale + 1, 0) + case _ => LongType + } + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType)) + + protected override def nullSafeEval(input: Any): Any = child.dataType match { + case DoubleType => f(input.asInstanceOf[Double]).toLong + case DecimalType.Fixed(precision, scale) 
=> input.asInstanceOf[Decimal].floor } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + child.dataType match { + case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") + case DecimalType.Fixed(precision, scale) => + defineCodeGen(ctx, ev, c => s"$c.floor()") + case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c11dab35cd..c7a1a2e746 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -107,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP) require( decimalVal.precision <= precision, s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") @@ -198,6 +198,16 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + changePrecision(precision, scale, ROUND_HALF_UP) + } + + /** + * Update precision and scale while keeping our value the same, and return true if successful. 
+ * + * @return true if successful, false if overflow would occur + */ + private[sql] def changePrecision(precision: Int, scale: Int, + roundMode: BigDecimal.RoundingMode.Value): Boolean = { // fast path for UnsafeProjection if (precision == this.precision && scale == this.scale) { return true @@ -231,7 +241,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + val newVal = decimalVal.setScale(scale, roundMode) if (newVal.precision > precision) { return false } @@ -309,10 +319,26 @@ final class Decimal extends Ordered[Decimal] with Serializable { } def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this + + def floor: Decimal = if (scale == 0) this else { + val value = this.clone() + value.changePrecision( + DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR) + value + } + + def ceil: Decimal = if (scale == 0) this else { + val value = this.clone() + value.changePrecision( + DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING) + value + } } object Decimal { - private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP + val ROUND_CEILING = BigDecimal.RoundingMode.CEILING + val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala index ee6d25157f..d9c91415e2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralGenerator.scala @@ -78,7 +78,18 @@ object LiteralGenerator { Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) } yield Literal.create(f, DoubleType) - // TODO: decimal type + // TODO cache the generated data + def decimalLiteralGen(precision: Int, scale: Int): Gen[Literal] = { + assert(scale >= 0) + assert(precision >= scale) + Arbitrary.arbBigInt.arbitrary.map { s => + val a = (s % BigInt(10).pow(precision - scale)).toString() + val b = (s % BigInt(10).pow(scale)).abs.toString() + Literal.create( + Decimal(BigDecimal(s"$a.$b"), precision, scale), + DecimalType(precision, scale)) + } + } lazy val stringLiteralGen: Gen[Literal] = for { s <- Arbitrary.arbString.arbitrary } yield Literal.create(s, StringType) @@ -122,6 +133,7 @@ object LiteralGenerator { case StringType => stringLiteralGen case BinaryType => binaryLiteralGen case CalendarIntervalType => calendarIntervalLiterGen + case DecimalType.Fixed(precision, scale) => decimalLiteralGen(precision, scale) case dt => throw new IllegalArgumentException(s"not supported type $dt") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 1b2a9163a3..88ed9fdd64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -246,11 +246,21 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("ceil") { testUnary(Ceil, (d: Double) => math.ceil(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DoubleType) + + testUnary(Ceil, (d: Decimal) => d.ceil, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) + 
checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) } test("floor") { testUnary(Floor, (d: Double) => math.floor(d).toLong) checkConsistencyBetweenInterpretedAndCodegen(Floor, DoubleType) + + testUnary(Floor, (d: Decimal) => d.floor, (-20 to 20).map(x => Decimal(x * 0.1))) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) + checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) } test("factorial") { -- GitLab