Skip to content
Snippets Groups Projects
Commit 2c3cc764 authored by Cheng Hao's avatar Cheng Hao Committed by Michael Armbrust
Browse files

[SPARK-3501] [SQL] Fix bug where Hive SimpleUDF creates an unnecessary type cast

When running a query like:
```
select datediff(cast(value as timestamp), cast('2002-03-21 00:00:00' as timestamp)) from src;
```
SparkSQL will raise an exception:
```
[info] scala.MatchError: TimestampType (of class org.apache.spark.sql.catalyst.types.TimestampType$)
[info] at org.apache.spark.sql.catalyst.expressions.Cast.castToTimestamp(Cast.scala:77)
[info] at org.apache.spark.sql.catalyst.expressions.Cast.cast$lzycompute(Cast.scala:251)
[info] at org.apache.spark.sql.catalyst.expressions.Cast.cast(Cast.scala:247)
[info] at org.apache.spark.sql.catalyst.expressions.Cast.eval(Cast.scala:263)
[info] at org.apache.spark.sql.catalyst.optimizer.ConstantFolding$$anonfun$apply$5$$anonfun$applyOrElse$2.applyOrElse(Optimizer.scala:217)
[info] at org.apache.spark.sql.catalyst.optimizer.ConstantFolding$$anonfun$apply$5$$anonfun$applyOrElse$2.applyOrElse(Optimizer.scala:210)
[info] at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:144)
[info] at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4$$anonfun$apply$2.apply(TreeNode.scala:180)
[info] at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
[info] at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
```

Author: Cheng Hao <hao.cheng@intel.com>

Closes #2368 from chenghao-intel/cast_exception and squashes the following commits:

5c9c3a5 [Cheng Hao] make more clear code
49dfc50 [Cheng Hao] Add no-op for Cast and revert the position of SimplifyCasts
b804abd [Cheng Hao] Add unit test to show the failure in identical data type casting
330a5c8 [Cheng Hao] Update Code based on comments
b834ed4 [Cheng Hao] Fix bug of HiveSimpleUDF with unnecessary type cast which cause exception in constant folding
parent fce5e251
No related branches found
No related tags found
No related merge requests found
......@@ -245,6 +245,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
private[this] lazy val cast: Any => Any = dataType match {
case dt if dt == child.dataType => identity[Any]
case StringType => castToString
case BinaryType => castToBinary
case DecimalType => castToDecimal
......
......@@ -51,12 +51,13 @@ private[hive] abstract class HiveFunctionRegistry
val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
HiveSimpleUdf(
functionClassName,
children.zip(expectedDataTypes).map {
case (e, NullType) => e
case (e, t) if (e.dataType == t) => e
case (e, t) => Cast(e, t)
}
)
......
......@@ -142,16 +142,25 @@ class HiveQuerySuite extends HiveComparisonTest {
setConf("spark.sql.dialect", "sql")
assert(sql("SELECT 1").collect() === Array(Seq(1)))
setConf("spark.sql.dialect", "hiveql")
}
// Smoke test: runs a HiveQL-style query (FROM clause before SELECT) and collects the result.
// There is no assertion on the rows; the test only fails if the query throws.
test("Query expressed in HiveQL") {
sql("FROM src SELECT key").collect()
}
// Regression test for SPARK-3501: the outer CAST targets the same type (binary) that the inner
// CAST already produces, i.e. a cast between identical data types over a constant expression.
// There is no assertion on the rows; the test only fails if the query throws.
test("Query with constant folding the CAST") {
sql("SELECT CAST(CAST('123' AS binary) AS binary) FROM src LIMIT 1").collect()
}
// Registers a query test that aggregates over constant and null arguments, grouped by key.
// NOTE(review): createQueryTest is defined outside this view; presumably it checks the query's
// output against a reference answer -- confirm in HiveComparisonTest.
createQueryTest("Constant Folding Optimization for AVG_SUM_COUNT",
"SELECT AVG(0), SUM(0), COUNT(null), COUNT(value) FROM src GROUP BY key")
// Regression test for SPARK-3501: DATEDIFF over two timestamp-typed arguments. This is the query
// shape that previously failed with `scala.MatchError: TimestampType` in Cast.castToTimestamp
// when a timestamp expression was cast to timestamp again.
// NOTE(review): createQueryTest is defined outside this view; presumably it checks the query's
// output against a reference answer -- confirm in HiveComparisonTest.
createQueryTest("Cast Timestamp to Timestamp in UDF",
"""
| SELECT DATEDIFF(CAST(value AS timestamp), CAST('2002-03-21 00:00:00' AS timestamp))
| FROM src LIMIT 1
""".stripMargin)
// Registers a query test for a plain AVG aggregate over the key column of src.
// NOTE(review): createQueryTest is defined outside this view; presumably it checks the query's
// output against a reference answer -- confirm in HiveComparisonTest.
createQueryTest("Simple Average",
"SELECT AVG(key) FROM src")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment