Skip to content
Snippets Groups Projects
Commit 70d495dc authored by jiangxingbo's avatar jiangxingbo Committed by Herman van Hovell
Browse files

[SPARK-18624][SQL] Implicit cast ArrayType(InternalType)

## What changes were proposed in this pull request?

Currently `ImplicitTypeCasts` doesn't handle casts between `ArrayType`s, which is inconvenient. We should add a rule to enable casting from `ArrayType(InternalType)` to `ArrayType(newInternalType)`.

Goals:
1. Add a rule to `ImplicitTypeCasts` to enable casting between `ArrayType`s;
2. Simplify `Percentile` and `ApproximatePercentile`.

## How was this patch tested?

Updated test cases in `TypeCoercionSuite`.

Author: jiangxingbo <jiangxb1987@gmail.com>

Closes #16057 from jiangxb1987/implicit-cast-complex-types.
parent 7a75ee1c
No related branches found
No related tags found
No related merge requests found
......@@ -673,48 +673,69 @@ object TypeCoercion {
* If the expression has an incompatible type that cannot be implicitly cast, return None.
*/
def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = {
  // Resolve the target type from the data types alone, then wrap `e` in a Cast only when
  // the resolved type differs from the expression's current type; otherwise reuse `e` as is.
  implicitCast(e.dataType, expectedType).map { dt =>
    if (dt == e.dataType) e else Cast(e, dt)
  }
}
private def implicitCast(inType: DataType, expectedType: AbstractDataType): Option[DataType] = {
// Note that ret is nullable to avoid typing a lot of Some(...) in this local scope.
// It is wrapped in an Option immediately after the match.
@Nullable val ret: Expression = (inType, expectedType) match {
@Nullable val ret: DataType = (inType, expectedType) match {
// If the expected type is already a parent of the input type, no need to cast.
case _ if expectedType.acceptsType(inType) => e
case _ if expectedType.acceptsType(inType) => inType
// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)
case (NullType, target) => target.defaultConcreteType
// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
case (StringType, NumericType) => NumericType.defaultConcreteType
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to decimal.
case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))
case (d: NumericType, DecimalType) => DecimalType.forType(d)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) => Cast(e, target)
case (_: NumericType, target: NumericType) => target
// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
case (TimestampType, DateType) => Cast(e, DateType)
case (DateType, TimestampType) => TimestampType
case (TimestampType, DateType) => DateType
// Implicit cast from/to string
case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT)
case (StringType, target: NumericType) => Cast(e, target)
case (StringType, DateType) => Cast(e, DateType)
case (StringType, TimestampType) => Cast(e, TimestampType)
case (StringType, BinaryType) => Cast(e, BinaryType)
case (StringType, DecimalType) => DecimalType.SYSTEM_DEFAULT
case (StringType, target: NumericType) => target
case (StringType, DateType) => DateType
case (StringType, TimestampType) => TimestampType
case (StringType, BinaryType) => BinaryType
// Cast any atomic type to string.
case (any: AtomicType, StringType) if any != StringType => Cast(e, StringType)
case (any: AtomicType, StringType) if any != StringType => StringType
// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
case (_, TypeCollection(types)) =>
types.flatMap(implicitCast(inType, _)).headOption.orNull
// Implicit cast between array types.
//
// Compare the nullabilities of the from type and the to type, check whether the cast of
// the nullability is resolvable by the following rules:
// 1. If the nullability of the to type is true, the cast is always allowed;
// 2. If the nullability of the to type is false, and the nullability of the from type is
// true, the cast is never allowed;
// 3. If the nullabilities of both the from type and the to type are false, the cast is
// allowed only when Cast.forceNullable(fromType, toType) is false.
case (ArrayType(fromType, fn), ArrayType(toType: DataType, true)) =>
implicitCast(fromType, toType).map(ArrayType(_, true)).orNull
case (ArrayType(fromType, true), ArrayType(toType: DataType, false)) => null
case (ArrayType(fromType, false), ArrayType(toType: DataType, false))
if !Cast.forceNullable(fromType, toType) =>
implicitCast(fromType, toType).map(ArrayType(_, false)).orNull
// Otherwise no implicit cast applies: return null, which becomes None once wrapped below.
case _ => null
}
Option(ret)
......
......@@ -89,9 +89,7 @@ object Cast {
case _ => false
}
private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
private def forceNullable(from: DataType, to: DataType) = (from, to) match {
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
case (NullType, _) => true
case (_, _) if from == to => false
......@@ -110,6 +108,8 @@ object Cast {
case (_: FractionalType, _: IntegralType) => true // NaN, infinity
case _ => false
}
// A nullable source can only feed a nullable target; a non-nullable source fits either.
private def resolvableNullability(from: Boolean, to: Boolean) = if (from) to else true
}
/** Cast the child expression to the target data type. */
......
......@@ -86,23 +86,16 @@ case class ApproximatePercentile(
private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
override def inputTypes: Seq[AbstractDataType] = {
Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
Seq(DoubleType, TypeCollection(DoubleType, ArrayType(DoubleType)), IntegerType)
}
// Mark as lazy so that percentageExpression is not evaluated during tree transformation.
private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = {
(percentageExpression.dataType, percentageExpression.eval()) match {
private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) =
percentageExpression.eval() match {
// Rule ImplicitTypeCasts can cast other numeric types to double
case (_, num: Double) => (false, Array(num))
case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
val numericArray = arrayData.toObjectArray(baseType)
(true, numericArray.map { x =>
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
})
case other =>
throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
case num: Double => (false, Array(num))
case arrayData: ArrayData => (true, arrayData.toDoubleArray())
}
}
override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
......@@ -162,7 +155,7 @@ case class ApproximatePercentile(
override def nullable: Boolean = true
override def dataType: DataType = {
if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
if (returnPercentileArray) ArrayType(DoubleType, false) else DoubleType
}
override def prettyName: String = "percentile_approx"
......
......@@ -77,15 +77,9 @@ case class Percentile(
private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
@transient
private lazy val percentages =
(percentageExpression.dataType, percentageExpression.eval()) match {
case (_, num: Double) => Seq(num)
case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
val numericArray = arrayData.toObjectArray(baseType)
numericArray.map { x =>
baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq
case other =>
throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
private lazy val percentages = percentageExpression.eval() match {
case num: Double => Seq(num)
case arrayData: ArrayData => arrayData.toDoubleArray().toSeq
}
override def children: Seq[Expression] = child :: percentageExpression :: Nil
......@@ -99,7 +93,7 @@ case class Percentile(
}
override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
case _: ArrayType => Seq(NumericType, ArrayType)
case _: ArrayType => Seq(NumericType, ArrayType(DoubleType))
case _ => Seq(NumericType, DoubleType)
}
......
......@@ -57,14 +57,43 @@ class TypeCoercionSuite extends PlanTest {
// scalastyle:on line.size.limit
private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
assert(got.map(_.dataType) == Option(expected),
// Check default value
val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
assert(DataType.equalsIgnoreCompatibleNullability(
castDefault.map(_.dataType).getOrElse(null), expected),
s"Failed to cast $from to $to")
// Check null value
val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
assert(DataType.equalsIgnoreCaseAndNullability(
castNull.map(_.dataType).getOrElse(null), expected),
s"Failed to cast $from to $to")
}
// Asserts that implicitCast yields no result when casting from `from` to `to`, for each of
// three inputs: a null literal typed as `from`, the expression built by `default(from)`,
// and the expression built by `createNull(from)`.
private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
// Check a null literal typed as `from`
val got = TypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
// Check default value
val castDefault = TypeCoercion.ImplicitTypeCasts.implicitCast(default(from), to)
assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault")
// Check null value
val castNull = TypeCoercion.ImplicitTypeCasts.implicitCast(createNull(from), to)
assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull")
}
// Builds a sample expression of the given type out of `Literal.default` values:
// a one-element array for array types, a single key/value map for map types,
// and a plain default literal for everything else.
private def default(dataType: DataType): Expression = dataType match {
  case MapType(keyType: DataType, valueType: DataType, _) =>
    val key = Literal.default(keyType)
    val value = Literal.default(valueType)
    CreateMap(Seq(key, value))
  case ArrayType(elementType: DataType, _) =>
    val element = Literal.default(elementType)
    CreateArray(Seq(element))
  case _ =>
    Literal.default(dataType)
}
// Builds a sample expression of the given type out of null literals:
// a one-element array holding a null for array types, a map with a null key and a null
// value for map types, and a null literal of the type itself for everything else.
private def createNull(dataType: DataType): Expression = dataType match {
  case MapType(keyType: DataType, valueType: DataType, _) =>
    val nullKey = Literal.create(null, keyType)
    val nullValue = Literal.create(null, valueType)
    CreateMap(Seq(nullKey, nullValue))
  case ArrayType(elementType: DataType, _) =>
    val nullElement = Literal.create(null, elementType)
    CreateArray(Seq(nullElement))
  case _ =>
    Literal.create(null, dataType)
}
val integralTypes: Seq[DataType] =
......@@ -196,7 +225,13 @@ class TypeCoercionSuite extends PlanTest {
test("implicit type cast - ArrayType(StringType)") {
val checkedType = ArrayType(StringType)
checkTypeCasting(checkedType, castableTypes = Seq(checkedType))
val nonCastableTypes =
complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType)
checkTypeCasting(checkedType,
castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_)))
nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _))
shouldNotCast(ArrayType(DoubleType, containsNull = false),
ArrayType(LongType, containsNull = false))
shouldNotCast(checkedType, DecimalType)
shouldNotCast(checkedType, NumericType)
shouldNotCast(checkedType, IntegralType)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment