diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 87ffbfe791b935569a0725c0f147624078cbd4d4..e0527503442f04d2c897e9cf8684594a069ca5fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -38,7 +36,7 @@ object HiveTypeCoercion { val typeCoercionRules = PropagateTypes :: InConversion :: - WidenTypes :: + WidenSetOperationTypes :: PromoteStrings :: DecimalPrecision :: BooleanEquality :: @@ -175,7 +173,7 @@ object HiveTypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - object WidenTypes extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { private[this] def widenOutputTypes( planName: String, @@ -203,9 +201,9 @@ object HiveTypeCoercion { def castOutput(plan: LogicalPlan): LogicalPlan = { val casted = plan.output.zip(castedTypes).map { - case (hs, Some(dt)) if dt != hs.dataType => - Alias(Cast(hs, dt), hs.name)() - case (hs, _) => hs + case (e, Some(dt)) if e.dataType != dt => + Alias(Cast(e, dt), e.name)() + case (e, _) => e } Project(casted, plan) } @@ -355,20 +353,8 @@ object HiveTypeCoercion { DecimalType.bounded(range + scale, scale) } - /** - * An expression used to wrap the children when promote the precision of DecimalType to avoid - * promote multiple times. 
- */ - case class ChangePrecision(child: Expression) extends UnaryExpression { - override def dataType: DataType = child.dataType - override def eval(input: InternalRow): Any = child.eval(input) - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" - override def prettyName: String = "change_precision" - } - - def changePrecision(e: Expression, dataType: DataType): Expression = { - ChangePrecision(Cast(e, dataType)) + private def changePrecision(e: Expression, dataType: DataType): Expression = { + ChangeDecimalPrecision(Cast(e, dataType)) } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -378,7 +364,7 @@ object HiveTypeCoercion { case e if !e.childrenResolved => e // Skip nodes who is already promoted - case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e + case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index b9d4736a65e268196aa0b866c4da6f972c95dbdf..adb33e4c8d4a12af1fe7e82607e37c0269c6c3b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ @@ -60,3 +61,15 @@ case class MakeDecimal(child: Expression, precision: Int, 
scale: Int) extends Un }) } } + +/** + * An expression used to wrap the children when promoting the precision of DecimalType to avoid + * promoting it multiple times. + */ +case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression { + override def dataType: DataType = child.dataType + override def eval(input: InternalRow): Any = child.eval(input) + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" + override def prettyName: String = "change_decimal_precision" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 26b24616d98ec41031c275732fd779d8211bc670..0cd352d0fa928048a30fc6f68ded4652f0a2bfe6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -78,6 +78,10 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType { override def toString: String = s"DecimalType($precision,$scale)" + /** + * Returns whether this DecimalType is wider than `other`. If yes, it means `other` + * can be cast into `this` safely without losing any precision or range. 
+ */ private[sql] def isWiderThan(other: DataType): Boolean = other match { case dt: DecimalType => (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale @@ -109,7 +113,7 @@ object DecimalType extends AbstractDataType { @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") val Unlimited: DecimalType = SYSTEM_DEFAULT - // The decimal types compatible with other numberic types + // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) private[sql] val IntDecimal = DecimalType(10, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index f9f15e7a6608d12e5472f64e9acdb28266b4dda8..fc11627da6fd1473986c09f5ecd9423d93009178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -154,4 +154,30 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Remainder(expr, u), DoubleType) } } + + test("DecimalType.isWiderThan") { + val d0 = DecimalType(2, 0) + val d1 = DecimalType(2, 1) + val d2 = DecimalType(5, 2) + val d3 = DecimalType(15, 3) + val d4 = DecimalType(25, 4) + + assert(d0.isWiderThan(d1) === false) + assert(d1.isWiderThan(d0) === false) + assert(d1.isWiderThan(d2) === false) + assert(d2.isWiderThan(d1) === true) + assert(d2.isWiderThan(d3) === false) + assert(d3.isWiderThan(d2) === true) + assert(d4.isWiderThan(d3) === true) + + assert(d1.isWiderThan(ByteType) === false) + assert(d2.isWiderThan(ByteType) === true) + assert(d2.isWiderThan(ShortType) === false) + assert(d3.isWiderThan(ShortType) === true) + assert(d3.isWiderThan(IntegerType) === true) + 
assert(d3.isWiderThan(LongType) === false) + assert(d4.isWiderThan(LongType) === true) + assert(d4.isWiderThan(FloatType) === false) + assert(d4.isWiderThan(DoubleType) === false) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 55865bdb534b4531f2da8b65a671b6f8ff205abc..4454d51b758770425d51c7b70d89e3783af44fb5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -314,7 +314,7 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenTypes for union except and intersect") { + test("WidenSetOperationTypes for union except and intersect") { def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { logical.output.zip(expectTypes).foreach { case (attr, dt) => assert(attr.dataType === dt) @@ -332,7 +332,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = HiveTypeCoercion.WidenTypes + val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val r1 = wt(Union(left, right)).asInstanceOf[Union] @@ -353,7 +353,7 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val dp = HiveTypeCoercion.WidenTypes + val dp = HiveTypeCoercion.WidenSetOperationTypes val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))())