diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2737fe32cd086ad5ef458c2cc91d556e7bd6f5eb..7df3787e6d2d3e8c9b48a774c9f09b6eab8497ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -27,10 +27,20 @@
 import org.apache.spark.sql.types._

 /**
- * A collection of [[Rule Rules]] that can be used to coerce differing types that
- * participate in operations into compatible ones. Most of these rules are based on Hive semantics,
- * but they do not introduce any dependencies on the hive codebase. For this reason they remain in
- * Catalyst until we have a more standard set of coercions.
+ * A collection of [[Rule Rules]] that can be used to coerce differing types that participate in
+ * operations into compatible ones.
+ *
+ * Most of these rules are based on Hive semantics, but they do not introduce any dependencies on
+ * the hive codebase.
+ *
+ * Notes about type widening / tightest common types: Broadly, there are two cases when we need
+ * to widen data types (e.g. union, binary comparison). In case 1, we are looking for a common
+ * data type for two or more data types, and in this case no loss of precision is allowed. Examples
+ * include type inference in JSON (e.g. what's the column's data type if one row is an integer
+ * while the other row is a long?). In case 2, we are looking for a widened data type with
+ * some acceptable loss of precision (e.g. there is no common type for double and decimal because
+ * double's range is larger than decimal, and yet decimal is more precise than double, but in
+ * union we would cast the decimal into double).
  */
 object HiveTypeCoercion {

@@ -63,6 +73,8 @@ object HiveTypeCoercion {
     DoubleType)

   /**
+   * Case 1 type widening (see the classdoc comment above for HiveTypeCoercion).
+   *
    * Find the tightest common type of two types that might be used in a binary expression.
    * This handles all numeric types except fixed-precision decimals interacting with each other or
    * with primitive types, because in that case the precision and scale of the result depends on
@@ -118,6 +130,12 @@ object HiveTypeCoercion {
     })
   }

+  /**
+   * Case 2 type widening (see the classdoc comment above for HiveTypeCoercion).
+   *
+   * i.e. the main difference with [[findTightestCommonTypeOfTwo]] is that here we allow some
+   * loss of precision when widening decimal and double.
+   */
   private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
     case (t1: DecimalType, t2: DecimalType) =>
       Some(DecimalPrecision.widerDecimalType(t1, t2))
@@ -125,9 +143,7 @@
       Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
     case (d: DecimalType, t: IntegralType) =>
       Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
-    case (t: FractionalType, d: DecimalType) =>
-      Some(DoubleType)
-    case (d: DecimalType, t: FractionalType) =>
+    case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
       Some(DoubleType)
     case _ =>
       findTightestCommonTypeToString(t1, t2)
@@ -200,41 +216,37 @@
    */
   object WidenSetOperationTypes extends Rule[LogicalPlan] {

-    private[this] def widenOutputTypes(
-        planName: String,
-        left: LogicalPlan,
-        right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
-      require(left.output.length == right.output.length)
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case p if p.analyzed => p

-      val castedTypes = left.output.zip(right.output).map {
-        case (lhs, rhs) if lhs.dataType != rhs.dataType =>
-          findWiderTypeForTwo(lhs.dataType, rhs.dataType)
-        case other => None
-      }
+      case s @ SetOperation(left, right) if s.childrenResolved
+          && left.output.length == right.output.length && !s.resolved =>

-      def castOutput(plan: LogicalPlan): LogicalPlan = {
-        val casted = plan.output.zip(castedTypes).map {
-          case (e, Some(dt)) if e.dataType != dt =>
-            Alias(Cast(e, dt), e.name)()
-          case (e, _) => e
+        // Tracks the list of data types to widen.
+        // Some(dataType) means the right-hand side and the left-hand side have different types,
+        // and there is a target type to widen both sides to.
+        val targetTypes: Seq[Option[DataType]] = left.output.zip(right.output).map {
+          case (lhs, rhs) if lhs.dataType != rhs.dataType =>
+            findWiderTypeForTwo(lhs.dataType, rhs.dataType)
+          case other => None
         }
-        Project(casted, plan)
-      }

-      if (castedTypes.exists(_.isDefined)) {
-        (castOutput(left), castOutput(right))
-      } else {
-        (left, right)
-      }
+        if (targetTypes.exists(_.isDefined)) {
+          // There is at least one column to widen.
+          s.makeCopy(Array(widenTypes(left, targetTypes), widenTypes(right, targetTypes)))
+        } else {
+          // If we cannot find any column to widen, then just return the original set.
+          s
+        }
     }

-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case p if p.analyzed => p
-
-      case s @ SetOperation(left, right) if s.childrenResolved
-        && left.output.length == right.output.length && !s.resolved =>
-        val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right)
-        s.makeCopy(Array(newLeft, newRight))
+    /** Given a plan, add an extra project on top to widen some columns' data types. */
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[Option[DataType]]): LogicalPlan = {
+      val casted = plan.output.zip(targetTypes).map {
+        case (e, Some(dt)) if e.dataType != dt => Alias(Cast(e, dt), e.name)()
+        case (e, _) => e
+      }
+      Project(casted, plan)
     }
   }

@@ -372,8 +384,6 @@
    * - INT gets turned into DECIMAL(10, 0)
    * - LONG gets turned into DECIMAL(20, 0)
    * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
-   *
-   * Note: Union/Except/Interact is handled by WidenTypes
    */
   // scalastyle:on
   object DecimalPrecision extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index b1f6c0b802d8e958993be75c371c9f6bcd7286f5..b326aa9c55992d14348caf75eba065a247fd01b5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -387,7 +387,7 @@
     )
   }

-  test("WidenSetOperationTypes for union except and intersect") {
+  test("WidenSetOperationTypes for union, except, and intersect") {
    def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
      logical.output.zip(expectTypes).foreach { case (attr, dt) =>
        assert(attr.dataType === dt)
@@ -499,7 +499,6 @@
     ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))
   }

-
   /**
    * There are rules that need to not fire before child expressions get resolved.
    * We use this test to make sure those rules do not fire early.
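
A note on the case 1 / case 2 distinction documented in the new classdoc: the two lookups differ exactly on the decimal-vs-fractional corner. The sketch below is a hypothetical, self-contained approximation (WideningSketch, tightest, and wider are made-up names; the real logic is the private findTightestCommonTypeOfTwo and findWiderTypeForTwo above, which cover many more type pairs):

    import org.apache.spark.sql.types._

    object WideningSketch {
      // Case 1: tightest common type; no loss of precision is allowed, so
      // double vs. decimal has no answer at all.
      def tightest(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
        case _ if t1 == t2 => Some(t1)
        case (IntegerType, LongType) | (LongType, IntegerType) => Some(LongType)
        case _ => None // in particular, double vs. decimal yields None
      }

      // Case 2: wider type; some loss of precision is acceptable, so
      // double vs. decimal falls back to double.
      def wider(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
        case (_: DecimalType, DoubleType) | (DoubleType, _: DecimalType) => Some(DoubleType)
        case _ => tightest(t1, t2)
      }
    }

    WideningSketch.tightest(DoubleType, DecimalType(10, 2)) // None
    WideningSketch.wider(DoubleType, DecimalType(10, 2))    // Some(DoubleType)
    WideningSketch.tightest(IntegerType, LongType)          // Some(LongType), lossless

Returning None in case 1 is what lets a caller such as JSON schema inference pick a safe fallback instead of silently dropping decimal digits, while a case 2 caller such as union deliberately accepts the double approximation.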
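On the WidenSetOperationTypes restructuring: instead of a private helper returning a (left, right) tuple, the rule now matches the set operation directly and, when any column pair needs widening, wraps each side in a Project of casts via widenTypes. A rough sketch of exercising the rule through the catalyst test DSL, in the spirit of the updated suite (the relation and attribute names are made up, and the resulting plan is paraphrased in the comment):

    import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Union}

    val intSide = LocalRelation('i.int)       // output: i: int
    val doubleSide = LocalRelation('d.double) // output: d: double

    // findWiderTypeForTwo(IntegerType, DoubleType) picks DoubleType, so the
    // rule rewrites the union to, roughly:
    //   Union(Project [cast(i as double) AS i], Project [d])
    val widened = HiveTypeCoercion.WidenSetOperationTypes(Union(intSide, doubleSide))

Keeping the casts in an explicit Project on top of each child means the widened output attributes are visible to the rest of analysis, which is what the renamed test checks with checkOutput.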
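On the DecimalPrecision doc deletion: the removed note still referred to the rule by its old name, WidenTypes, and set operations are handled by WidenSetOperationTypes above, so the note was stale. The promotions the surviving doc lines describe match DecimalType.forType, the same helper this patch uses in findWiderTypeForTwo:

    import org.apache.spark.sql.types._

    DecimalType.forType(IntegerType) // DecimalType(10, 0): wide enough for any Int
    DecimalType.forType(LongType)    // DecimalType(20, 0): wide enough for any Long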