Commit 98bcc188 authored by Liang-Chi Hsieh, committed by Herman van Hovell

[SPARK-19758][SQL] Resolving timezone aware expressions with time zone when resolving inline table

## What changes were proposed in this pull request?

When we resolve inline tables in the analyzer, we evaluate the expressions that make up the table.

If one of those expressions is a `TimeZoneAwareExpression`, evaluation fails because no time zone has been associated with it yet.

So we need to resolve these `TimeZoneAwareExpression`s with the session-local time zone when resolving inline tables.
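
For example, before this change a query such as `select * from values (timestamp('1991-12-06 00:00:00.0')) as data(a)` (see the regression test added below) failed during analysis, because the string-to-timestamp `Cast` is time zone aware.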

## How was this patch tested?

Jenkins tests.


Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #17114 from viirya/resolve-timeawareexpr-inline-table.
parent 776fac39
Analyzer.scala
@@ -146,7 +146,7 @@ class Analyzer(
       GlobalAggregates ::
       ResolveAggregateFunctions ::
       TimeWindowing ::
-      ResolveInlineTables ::
+      ResolveInlineTables(conf) ::
       TypeCoercion.typeCoercionRules ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
ResolveInlineTables.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.util.control.NonFatal

-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
 /**
  * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]].
  */
-object ResolveInlineTables extends Rule[LogicalPlan] {
+case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case table: UnresolvedInlineTable if table.expressionsResolved =>
       validateInputDimension(table)
@@ -95,11 +95,15 @@ object ResolveInlineTables extends Rule[LogicalPlan] {
     InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
       val targetType = fields(ci).dataType
       try {
-        if (e.dataType.sameType(targetType)) {
-          e.eval()
+        val castedExpr = if (e.dataType.sameType(targetType)) {
+          e
         } else {
-          Cast(e, targetType).eval()
+          Cast(e, targetType)
         }
+        castedExpr.transform {
+          case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
+            e.withTimeZone(conf.sessionLocalTimeZone)
+        }.eval()
       } catch {
         case NonFatal(ex) =>
           table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
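
The hunk above is the core of the fix: bind a time zone to every unbound time-zone-aware node before calling `eval()`. A minimal standalone sketch of that pattern, assuming a Spark 2.x catalyst classpath; the hard-coded zone id is purely illustrative, whereas the rule itself uses `conf.sessionLocalTimeZone`:

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, TimeZoneAwareExpression}
import org.apache.spark.sql.types.TimestampType

// A string-to-timestamp Cast is a TimeZoneAwareExpression; evaluating it
// before a time zone id is bound fails, which is exactly what happened when
// the analyzer evaluated inline-table expressions.
val cast = Cast(Literal("1991-12-06 00:00:00.0"), TimestampType)

// Rewrite the tree so every unbound TimeZoneAwareExpression gets a zone,
// then evaluate. ("America/Los_Angeles" is an arbitrary example zone id.)
val bound = cast.transform {
  case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty =>
    e.withTimeZone("America/Los_Angeles")
}
val micros = bound.eval()  // timestamp value as microseconds since the epoch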
ResolveInlineTablesSuite.scala
@@ -20,68 +20,67 @@ package org.apache.spark.sql.catalyst.analysis
 import org.scalatest.BeforeAndAfter

 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.types.{LongType, NullType}
+import org.apache.spark.sql.types.{LongType, NullType, TimestampType}

 /**
  * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in
  * end-to-end tests (in sql/core module) for verifying the correct error messages are shown
  * in negative cases.
  */
-class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
+class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

   private def lit(v: Any): Literal = Literal(v)

   test("validate inputs are foldable") {
-    ResolveInlineTables.validateInputEvaluable(
+    ResolveInlineTables(conf).validateInputEvaluable(
       UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

     // nondeterministic (rand) should not work
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputEvaluable(
+      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
     }

     // aggregate should not work
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputEvaluable(
+      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
     }

     // unresolved attribute should not work
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputEvaluable(
+      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
     }
   }

   test("validate input dimensions") {
-    ResolveInlineTables.validateInputDimension(
+    ResolveInlineTables(conf).validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

     // num alias != data dimension
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputDimension(
+      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
     }

     // num alias == data dimension, but data themselves are inconsistent
     intercept[AnalysisException] {
-      ResolveInlineTables.validateInputDimension(
+      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
     }
   }

   test("do not fire the rule if not all expressions are resolved") {
     val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
-    assert(ResolveInlineTables(table) == table)
+    assert(ResolveInlineTables(conf)(table) == table)
   }

   test("convert") {
     val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
-    val converted = ResolveInlineTables.convert(table)
+    val converted = ResolveInlineTables(conf).convert(table)

     assert(converted.output.map(_.dataType) == Seq(LongType))
     assert(converted.data.size == 2)
@@ -89,13 +88,24 @@ class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {
     assert(converted.data(1).getLong(0) == 2L)
   }

+  test("convert TimeZoneAwareExpression") {
+    val table = UnresolvedInlineTable(Seq("c1"),
+      Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
+    val converted = ResolveInlineTables(conf).convert(table)
+    val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
+      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
+    assert(converted.output.map(_.dataType) == Seq(TimestampType))
+    assert(converted.data.size == 1)
+    assert(converted.data(0).getLong(0) == correct)
+  }
+
   test("nullability inference in convert") {
     val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
-    val converted1 = ResolveInlineTables.convert(table1)
+    val converted1 = ResolveInlineTables(conf).convert(table1)
     assert(!converted1.schema.fields(0).nullable)

     val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
-    val converted2 = ResolveInlineTables.convert(table2)
+    val converted2 = ResolveInlineTables(conf).convert(table2)
     assert(converted2.schema.fields(0).nullable)
   }
 }
inline-table.sql
@@ -46,3 +46,6 @@ select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b)
 -- error reporting: aggregate expression
 select * from values ("one", count(1)), ("two", 2) as data(a, b);
+
+-- string to timestamp
+select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b);
inline-table.sql.out
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 16
+-- Number of queries: 17

 -- !query 0
@@ -143,3 +143,11 @@ struct<>
 -- !query 15 output
 org.apache.spark.sql.AnalysisException
 cannot evaluate expression count(1) in inline table definition; line 1 pos 29
+
+
+-- !query 16
+select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b)
+-- !query 16 schema
+struct<a:timestamp,b:array<timestamp>>
+-- !query 16 output
+1991-12-06 00:00:00	[1991-12-06 01:00:00.0,1991-12-06 12:00:00.0]