From 26ac2ce05cbaf8f152347219403e31491e9c9bf1 Mon Sep 17 00:00:00 2001
From: Kris Mok <kris.mok@databricks.com>
Date: Thu, 27 Apr 2017 12:08:16 -0700
Subject: [PATCH] [SPARK-20482][SQL] Resolving Casts is too strict on having
 time zone set

## What changes were proposed in this pull request?

Relax the requirement that a `TimeZoneAwareExpression` has to have its `timeZoneId` set to be considered resolved.
With this change, a `Cast` (which is a `TimeZoneAwareExpression`) can be considered resolved if the `(fromType, toType)` combination doesn't require time zone information.

Also de-relaxed test cases in `CastSuite` so Casts in that test suite don't get a default`timeZoneId = Option("GMT")`.

## How was this patch tested?

Ran the de-relaxed`CastSuite` and it's passing. Also ran the SQL unit tests and they're passing too.

Author: Kris Mok <kris.mok@databricks.com>

Closes #17777 from rednaxelafx/fix-catalyst-cast-timezone.
---
 .../spark/sql/catalyst/expressions/Cast.scala | 32 +++++++++++++++++++
 .../sql/catalyst/expressions/CastSuite.scala  |  4 +--
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index bb1273f5c3..a53ef426f7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -89,6 +89,31 @@ object Cast {
     case _ => false
   }
 
+  /**
+   * Return true if we need to use the `timeZone` information casting `from` type to `to` type.
+   * The patterns matched reflect the current implementation in the Cast node.
+   * c.f. usage of `timeZone` in:
+   * * Cast.castToString
+   * * Cast.castToDate
+   * * Cast.castToTimestamp
+   */
+  def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match {
+    case (StringType, TimestampType) => true
+    case (DateType, TimestampType) => true
+    case (TimestampType, StringType) => true
+    case (TimestampType, DateType) => true
+    case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType)
+    case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
+      needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue)
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).exists {
+          case (fromField, toField) =>
+            needsTimeZone(fromField.dataType, toField.dataType)
+        }
+    case _ => false
+  }
+
   /**
    * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int,
    * timestamp -> date.
@@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
     copy(timeZoneId = Option(timeZoneId))
 
+  // When this cast involves TimeZone, it's only resolved if the timeZoneId is set;
+  // Otherwise behave like Expression.resolved.
+  override lazy val resolved: Boolean =
+    childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined)
+
+  private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
+
   // [[func]] assumes the input is no longer null because eval already does the null check.
   @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T])
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 22f3f3514f..a7ffa884d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = Some("GMT")): Cast = {
+  private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = {
     v match {
       case lit: Expression => Cast(lit, targetType, timeZoneId)
       case _ => Cast(Literal(v), targetType, timeZoneId)
@@ -47,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   private def checkNullCast(from: DataType, to: DataType): Unit = {
-    checkEvaluation(cast(Literal.create(null, from), to), null)
+    checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null)
   }
 
   test("null cast") {
-- 
GitLab