Skip to content
Snippets Groups Projects
Commit ba330968 authored by Wenchen Fan, committed by Reynold Xin
Browse files

[SPARK-9068][SQL] refactor the implicit type cast code

based on https://github.com/apache/spark/pull/7348

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7420 from cloud-fan/type-check and squashes the following commits:

7633fa9 [Wenchen Fan] revert
fe169b0 [Wenchen Fan] improve test
03b70da [Wenchen Fan] enhance implicit type cast
parent 42dea3ac
No related branches found
No related tags found
No related merge requests found
Showing
with 81 additions and 126 deletions
......@@ -675,10 +675,10 @@ object HiveTypeCoercion {
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
// If the expression accepts the tighest common type, cast to that.
// If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
b.makeCopy(Array(newLeft, newRight))
b.withNewChildren(Seq(newLeft, newRight))
} else {
// Otherwise, don't do anything with the expression.
b
......@@ -697,7 +697,7 @@ object HiveTypeCoercion {
// general implicit casting.
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
Cast(in, expected.defaultConcreteType)
Literal.create(null, expected.defaultConcreteType)
} else {
in
}
......@@ -719,27 +719,22 @@ object HiveTypeCoercion {
@Nullable val ret: Expression = (inType, expectedType) match {
// If the expected type is already a parent of the input type, no need to cast.
case _ if expectedType.isSameType(inType) => e
case _ if expectedType.acceptsType(inType) => e
// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)
// If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
// already a number, leave it as is.
case (_: NumericType, NumericType) => e
// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
// Implicit cast among numeric types
// Implicit cast among numeric types. When we reach here, input type is not acceptable.
// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
Cast(e, DecimalType.Unlimited)
case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
case (_: NumericType, target: NumericType) => e
case (_: NumericType, target: NumericType) => Cast(e, target)
// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
......@@ -753,15 +748,9 @@ object HiveTypeCoercion {
case (StringType, BinaryType) => Cast(e, BinaryType)
case (any, StringType) if any != StringType => Cast(e, StringType)
// Type collection.
// First see if we can find our input type in the type collection. If we can, then just
// use the current expression; otherwise, find the first one we can implicitly cast.
case (_, TypeCollection(types)) =>
if (types.exists(_.isSameType(inType))) {
e
} else {
types.flatMap(implicitCast(e, _)).headOption.orNull
}
// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
// Else, there is no implicit cast for this (input, expected) combination: yield null
case _ => null
......
......@@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
override def checkInputDataTypes(): TypeCheckResult = {
// First call the checker for ExpectsInputTypes, and then check whether left and right have
// the same type.
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
// First check whether left and right have the same type, then check if the type is acceptable.
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else if (!inputType.acceptsType(left.dataType)) {
TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
s" not ${left.dataType.simpleString}")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
}
......
......@@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
override def symbol: String = "max"
override def prettyName: String = symbol
}
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
......@@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}
override def symbol: String = "min"
override def prettyName: String = symbol
}
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
......
......@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType
override def symbol: String = "&"
......@@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType
override def symbol: String = "|"
......@@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType
override def symbol: String = "^"
......@@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
override def dataType: DataType = child.dataType
......
......@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
TypeCheckResult.TypeCheckFailure(
s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
......
......@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
private[sql] def defaultConcreteType: DataType
/**
* Returns true if this data type is the same type as `other`. This is different that equality
* as equality will also consider data type parametrization, such as decimal precision.
* Returns true if `other` is an acceptable input type for a function that expects this,
* possibly abstract, DataType.
*
* {{{
* // this should return true
* DecimalType.isSameType(DecimalType(10, 2))
*
* // this should return false
* NumericType.isSameType(DecimalType(10, 2))
* }}}
*/
private[sql] def isSameType(other: DataType): Boolean
/**
* Returns true if `other` is an acceptable input type for a function that expectes this,
* possibly abstract, DataType.
*
* {{{
* // this should return true
* DecimalType.isSameType(DecimalType(10, 2))
* DecimalType.acceptsType(DecimalType(10, 2))
*
* // this should return true as well
* NumericType.acceptsType(DecimalType(10, 2))
* }}}
*/
private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
private[sql] def acceptsType(other: DataType): Boolean
/** Readable string representation for the type. */
private[sql] def simpleString: String
......@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType
override private[sql] def isSameType(other: DataType): Boolean = false
override private[sql] def acceptsType(other: DataType): Boolean =
types.exists(_.isSameType(other))
types.exists(_.acceptsType(other))
override private[sql] def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
......@@ -107,13 +91,6 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)
/**
* Types that can be used in bitwise operations.
*/
val Bitwise = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType)
def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
......@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {
override private[sql] def simpleString: String = "any"
override private[sql] def isSameType(other: DataType): Boolean = false
override private[sql] def acceptsType(other: DataType): Boolean = true
}
......@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {
override private[sql] def simpleString: String = "numeric"
override private[sql] def isSameType(other: DataType): Boolean = false
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
}
private[sql] object IntegralType {
private[sql] object IntegralType extends AbstractDataType {
/**
* Enables matching against IntegralType for expressions:
* {{{
......@@ -198,6 +171,12 @@ private[sql] object IntegralType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
override private[sql] def defaultConcreteType: DataType = IntegerType
override private[sql] def simpleString: String = "integral"
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
}
......
......@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType]
}
......
......@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = this
override private[sql] def isSameType(other: DataType): Boolean = this == other
override private[sql] def acceptsType(other: DataType): Boolean = this == other
}
......
......@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = Unlimited
override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
}
......
......@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)
override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType]
}
......
......@@ -307,7 +307,7 @@ object StructType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = new StructType
override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[StructType]
}
......
......@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.{TypeCollection, StringType}
class ExpressionTypeCheckingSuite extends SparkFunSuite {
......@@ -49,23 +49,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
def assertErrorForDifferingTypes(expr: Expression): Unit = {
assertError(expr,
s"differing types in '${expr.prettyString}' (int and boolean)")
}
def assertErrorWithImplicitCast(expr: Expression, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
assertSuccess(expr)
}
assert(e.getMessage.contains(errorMessage))
s"differing types in '${expr.prettyString}'")
}
test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "expected to be of type numeric")
assertError(Abs('stringField), "expected to be of type numeric")
assertError(BitwiseNot('stringField), "type (boolean or tinyint or smallint or int or bigint)")
assertError(BitwiseNot('stringField), "expected to be of type integral")
}
ignore("check types for binary arithmetic") {
test("check types for binary arithmetic") {
// We will cast String to Double for binary arithmetic
assertSuccess(Add('intField, 'stringField))
assertSuccess(Subtract('intField, 'stringField))
......@@ -85,21 +78,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))
assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type")
assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type")
assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type")
assertError(Add('booleanField, 'booleanField), "accepts numeric type")
assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type")
assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type")
assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type")
assertError(BitwiseAnd('booleanField, 'booleanField), "accepts integral type")
assertError(BitwiseOr('booleanField, 'booleanField), "accepts integral type")
assertError(BitwiseXor('booleanField, 'booleanField), "accepts integral type")
assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type")
assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type")
assertError(MaxOf('complexField, 'complexField),
s"accepts ${TypeCollection.Ordered.simpleString} type")
assertError(MinOf('complexField, 'complexField),
s"accepts ${TypeCollection.Ordered.simpleString} type")
}
ignore("check types for predicates") {
test("check types for predicates") {
// We will cast String to Double for binary comparison
assertSuccess(EqualTo('intField, 'stringField))
assertSuccess(EqualNullSafe('intField, 'stringField))
......@@ -112,25 +107,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertSuccess(EqualTo('intField, 'booleanField))
assertSuccess(EqualNullSafe('intField, 'booleanField))
assertError(EqualTo('intField, 'complexField), "differing types")
assertError(EqualNullSafe('intField, 'complexField), "differing types")
assertErrorForDifferingTypes(EqualTo('intField, 'complexField))
assertErrorForDifferingTypes(EqualNullSafe('intField, 'complexField))
assertErrorForDifferingTypes(LessThan('intField, 'booleanField))
assertErrorForDifferingTypes(LessThanOrEqual('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
assertError(
LessThan('complexField, 'complexField), "operator < accepts non-complex type")
assertError(
LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type")
assertError(
GreaterThan('complexField, 'complexField), "operator > accepts non-complex type")
assertError(
GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type")
assertError(LessThan('complexField, 'complexField),
s"accepts ${TypeCollection.Ordered.simpleString} type")
assertError(LessThanOrEqual('complexField, 'complexField),
s"accepts ${TypeCollection.Ordered.simpleString} type")
assertError(GreaterThan('complexField, 'complexField),
s"accepts ${TypeCollection.Ordered.simpleString} type")
assertError(GreaterThanOrEqual('complexField, 'complexField),
s"accepts ${TypeCollection.Ordered.simpleString} type")
assertError(
If('intField, 'stringField, 'stringField),
assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
......@@ -180,12 +173,12 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}
test("check types for ROUND") {
assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
"data type mismatch: argument 2 is expected to be of type int")
assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
"data type mismatch: argument 2 is expected to be of type int")
assertSuccess(Round(Literal(null), Literal(null)))
assertError(Round('booleanField, 'intField),
"data type mismatch: argument 1 is expected to be of type numeric")
assertSuccess(Round('intField, Literal(1)))
assertError(Round('intField, 'intField), "Only foldable Expression is allowed")
assertError(Round('intField, 'booleanField), "expected to be of type int")
assertError(Round('intField, 'complexField), "expected to be of type int")
assertError(Round('booleanField, 'intField), "expected to be of type numeric")
}
}
......@@ -203,7 +203,7 @@ class HiveTypeCoercionSuite extends PlanTest {
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
NumericTypeUnaryExpression(Literal.create(null, NullType)),
NumericTypeUnaryExpression(Cast(Literal.create(null, NullType), DoubleType)))
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}
test("cast NullType for binary operators") {
......@@ -215,9 +215,7 @@ class HiveTypeCoercionSuite extends PlanTest {
ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
NumericTypeBinaryOperator(
Cast(Literal.create(null, NullType), DoubleType),
Cast(Literal.create(null, NullType), DoubleType)))
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}
test("coalesce casts") {
......@@ -345,14 +343,14 @@ object HiveTypeCoercionSuite {
}
case class AnyTypeBinaryOperator(left: Expression, right: Expression)
extends BinaryOperator with ExpectsInputTypes {
extends BinaryOperator {
override def dataType: DataType = NullType
override def inputType: AbstractDataType = AnyDataType
override def symbol: String = "anytype"
}
case class NumericTypeBinaryOperator(left: Expression, right: Expression)
extends BinaryOperator with ExpectsInputTypes {
extends BinaryOperator {
override def dataType: DataType = NullType
override def inputType: AbstractDataType = NumericType
override def symbol: String = "numerictype"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment