Commit 347cab85 authored by Reynold Xin

[SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.

Author: Reynold Xin <rxin@databricks.com>

Closes #7221 from rxin/implicit-cast-tests and squashes the following commits:

64b13bd [Reynold Xin] Fixed a bug ..
489b732 [Reynold Xin] [SQL] More unit tests for implicit type cast & add simpleString to AbstractDataType.
parent 48f7aed6
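
For orientation before the diff: the commit adds an abstract simpleString member to AbstractDataType and overrides it in TypeCollection and in the ArrayType, MapType, StructType, and DecimalType companion objects. The sketch below is not part of the commit; SimpleStringDemo is a hypothetical object placed inside the org.apache.spark.sql.types package because the new members are private[sql], and the expected strings in the comments are assumptions based on the usual DataType.simpleString names such as "int" and "binary".

package org.apache.spark.sql.types

// Hypothetical demo object, not part of the commit. It lives in the types
// package because simpleString is private[sql]. The expected values in the
// comments assume the overrides added in this commit and the element names
// returned by DataType.simpleString ("int", "binary", ...).
private[sql] object SimpleStringDemo {
  def values: Seq[String] = Seq(
    ArrayType.simpleString,                               // "array"
    MapType.simpleString,                                 // "map"
    StructType.simpleString,                              // "struct"
    DecimalType.simpleString,                             // "decimal"
    TypeCollection(BinaryType, IntegerType).simpleString  // "(binary or int)"
  )
}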
@@ -40,7 +40,7 @@ trait CheckAnalysis {
   def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
     exprs.flatMap(_.collect {
       case e: Generator => true
-    }).length >= 1
+    }).nonEmpty
   }
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -85,12 +85,12 @@ trait CheckAnalysis {
       case Aggregate(groupingExprs, aggregateExprs, child) =>
         def checkValidAggregateExpression(expr: Expression): Unit = expr match {
           case _: AggregateExpression => // OK
-          case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
+          case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
             failAnalysis(
               s"expression '${e.prettyString}' is neither present in the group by, " +
                 s"nor is it an aggregate function. " +
                 "Add to group by or wrap in first() if you don't care which value you get.")
-          case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
+          case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
           case e if e.references.isEmpty => // OK
           case e => e.children.foreach(checkValidAggregateExpression)
         }
...
@@ -37,6 +37,9 @@ private[sql] abstract class AbstractDataType {
    * Returns true if this data type is a parent of the `childCandidate`.
    */
   private[sql] def isParentOf(childCandidate: DataType): Boolean
+
+  /** Readable string representation for the type. */
+  private[sql] def simpleString: String
 }
@@ -56,6 +59,10 @@ private[sql] class TypeCollection(private val types: Seq[DataType]) extends AbstractDataType {
   private[sql] override def defaultConcreteType: DataType = types.head
 
   private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+
+  private[sql] override def simpleString: String = {
+    types.map(_.simpleString).mkString("(", " or ", ")")
+  }
 }
...
@@ -31,6 +31,8 @@ object ArrayType extends AbstractDataType {
   private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
     childCandidate.isInstanceOf[ArrayType]
   }
+
+  private[sql] override def simpleString: String = "array"
 }
...
@@ -90,6 +90,8 @@ object DecimalType extends AbstractDataType {
     childCandidate.isInstanceOf[DecimalType]
   }
 
+  private[sql] override def simpleString: String = "decimal"
+
   val Unlimited: DecimalType = DecimalType(None)
 
   private[sql] object Fixed {
...
@@ -75,6 +75,8 @@ object MapType extends AbstractDataType {
     childCandidate.isInstanceOf[MapType]
   }
 
+  private[sql] override def simpleString: String = "map"
+
   /**
    * Construct a [[MapType]] object with the given key type and value type.
    * The `valueContainsNull` is true.
...
@@ -309,6 +309,8 @@ object StructType extends AbstractDataType {
     childCandidate.isInstanceOf[StructType]
   }
 
+  private[sql] override def simpleString: String = "struct"
+
   def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
 
   def apply(fields: java.util.List[StructField]): StructType = {
...
@@ -26,7 +26,7 @@ import org.apache.spark.sql.types._
 
 class HiveTypeCoercionSuite extends PlanTest {
 
-  test("implicit type cast") {
+  test("eligible implicit type cast") {
     def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = {
       val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
       assert(got.map(_.dataType) == Option(expected),
@@ -68,6 +68,29 @@ class HiveTypeCoercionSuite extends PlanTest {
     shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType)
     shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType)
     shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType)
+
+    shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType)
+    shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType)
+  }
+
+  test("ineligible implicit type cast") {
+    def shouldNotCast(from: DataType, to: AbstractDataType): Unit = {
+      val got = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(Literal.create(null, from), to)
+      assert(got.isEmpty, s"Should not be able to cast $from to $to, but got $got")
+    }
+
+    shouldNotCast(IntegerType, DateType)
+    shouldNotCast(IntegerType, TimestampType)
+    shouldNotCast(LongType, DateType)
+    shouldNotCast(LongType, TimestampType)
+    shouldNotCast(DecimalType.Unlimited, DateType)
+    shouldNotCast(DecimalType.Unlimited, TimestampType)
+
+    shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType))
+
+    shouldNotCast(IntegerType, ArrayType)
+    shouldNotCast(IntegerType, MapType)
+    shouldNotCast(IntegerType, StructType)
   }
 
   test("tightest common bound for types") {
...
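
The suite exercises the coercion rule by calling implicitCast directly on a null literal of the source type. The sketch below simply repackages those calls outside the test framework; CoercionDemo is a hypothetical object, the package and import paths are assumed from the Spark 1.4/1.5-era source layout, and it is placed in the analysis package in case HiveTypeCoercion.ImplicitTypeCasts has restricted visibility.

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types._

// Hypothetical demo, not part of the commit; it mirrors the calls made by the
// shouldCast/shouldNotCast helpers in HiveTypeCoercionSuite.
object CoercionDemo {
  def main(args: Array[String]): Unit = {
    val intExpr = Literal.create(null, IntegerType)

    // Eligible: int is implicitly cast to string when the collection accepts string.
    val eligible = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(
      intExpr, TypeCollection(StringType, BinaryType))
    println(eligible.map(_.dataType))  // expected: Some(StringType)

    // Ineligible: there is no implicit cast from int to date, so None is returned.
    val ineligible = HiveTypeCoercion.ImplicitTypeCasts.implicitCast(intExpr, DateType)
    println(ineligible)                // expected: None
  }
}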