From cad88f17e87e6cb96550b70e35d3ed75305dc59d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang <xingbo.jiang@databricks.com> Date: Wed, 21 Jun 2017 09:40:06 -0700 Subject: [PATCH] [SPARK-17851][SQL][TESTS] Make sure all test sqls in catalyst pass checkAnalysis ## What changes were proposed in this pull request? Currently several tens of test sqls in catalyst fail at `SimpleAnalyzer.checkAnalysis`; we should make sure they are valid. This PR makes the following changes: 1. Apply `checkAnalysis` on plans that test `Optimizer` rules, but don't require the test cases for `Parser`/`Analyzer` to pass `checkAnalysis`; 2. Fix test cases for `Optimizer` that would otherwise have failed. ## How was this patch tested? Apply `SimpleAnalyzer.checkAnalysis` on plans in `PlanTest.comparePlans`, and update invalid test cases. Author: Xingbo Jiang <xingbo.jiang@databricks.com> Author: jiangxingbo <jiangxb1987@gmail.com> Closes #15417 from jiangxb1987/cptest. --- .../sql/catalyst/analysis/AnalysisTest.scala | 8 +++ .../analysis/DecimalPrecisionSuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 2 +- .../catalog/SessionCatalogSuite.scala | 2 +- .../optimizer/AggregateOptimizeSuite.scala | 4 +- .../BooleanSimplificationSuite.scala | 57 ++++++++++--------- .../optimizer/ColumnPruningSuite.scala | 4 +- .../optimizer/ConstantPropagationSuite.scala | 9 ++- .../optimizer/FilterPushdownSuite.scala | 11 ++-- .../optimizer/LimitPushdownSuite.scala | 12 ++-- .../optimizer/OptimizeCodegenSuite.scala | 4 +- .../optimizer/OuterJoinEliminationSuite.scala | 4 +- .../optimizer/SimplifyCastsSuite.scala | 9 ++- .../sql/catalyst/parser/PlanParserSuite.scala | 6 +- .../spark/sql/catalyst/plans/PlanTest.scala | 14 ++++- .../apache/spark/sql/DataFrameHintSuite.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 5 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 20 +++---- 18 files changed, 101 insertions(+), 76 deletions(-) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index edfa8c45f9..549a4355df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -59,6 +59,14 @@ trait AnalysisTest extends PlanTest { comparePlans(actualPlan, expectedPlan) } + protected override def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = false): Unit = { + // Analysis tests may have not been fully resolved, so skip checkAnalysis. + super.comparePlans(plan1, plan2, checkAnalysis) + } + protected def assertAnalysisSuccess( inputPlan: LogicalPlan, caseSensitive: Boolean = true): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 8f43171f30..ccf3c3fb09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Unio import org.apache.spark.sql.types._ -class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { +class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) private val analyzer = new Analyzer(catalog, conf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 7358f401ed..b3994ab082 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -class TypeCoercionSuite extends PlanTest { +class TypeCoercionSuite extends AnalysisTest { // scalastyle:off line.size.limit // The following table shows all implicit data type conversions that are not visible to the user. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index dce73b3635..a6dc21b03d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -44,7 +44,7 @@ class InMemorySessionCatalogSuite extends SessionCatalogSuite { * signatures but do not extend a common parent. This is largely by design but * unfortunately leads to very similar test code in two places. 
*/ -abstract class SessionCatalogSuite extends PlanTest { +abstract class SessionCatalogSuite extends AnalysisTest { protected val utils: CatalogTestUtils protected val isHiveExternalCatalog = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index e6132ab2e4..a3184a4266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -59,9 +59,9 @@ class AggregateOptimizeSuite extends PlanTest { } test("Remove aliased literals") { - val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) + val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze + val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 1df0a89cf0..c6345b60b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -41,7 +41,8 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string, + 'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean) val 
testRelationWithData = LocalRelation.fromExternalRows( testRelation.output, Seq(Row(1, 2, 3, "abc")) @@ -101,52 +102,52 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) } - test("a && (!a || b)") { - checkCondition('a && (!'a || 'b ), 'a && 'b) + test("e && (!e || f)") { + checkCondition('e && (!'e || 'f ), 'e && 'f) - checkCondition('a && ('b || !'a ), 'a && 'b) + checkCondition('e && ('f || !'e ), 'e && 'f) - checkCondition((!'a || 'b ) && 'a, 'b && 'a) + checkCondition((!'e || 'f ) && 'e, 'f && 'e) - checkCondition(('b || !'a ) && 'a, 'b && 'a) + checkCondition(('f || !'e ) && 'e, 'f && 'e) } - test("a < 1 && (!(a < 1) || b)") { - checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b) + test("a < 1 && (!(a < 1) || f)") { + checkCondition('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f) + checkCondition('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f) - checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b) + checkCondition('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f) + checkCondition('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f) - checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b) + checkCondition('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f) + checkCondition('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b) + checkCondition('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f) + checkCondition('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f) } - test("a < 1 && ((a >= 1) || b)") { - checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b) - checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b) + test("a < 1 && ((a >= 1) || f)") { + checkCondition('a < 1 && ('a >= 1 || 'f 
), ('a < 1) && 'f) + checkCondition('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f) - checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b) - checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f) + checkCondition('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f) - checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b) - checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b) + checkCondition('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f) + checkCondition('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f) - checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b) - checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b) + checkCondition('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f) + checkCondition('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f) } test("DeMorgan's law") { - checkCondition(!('a && 'b), !'a || !'b) + checkCondition(!('e && 'f), !'e || !'f) - checkCondition(!('a || 'b), !'a && !'b) + checkCondition(!('e || 'f), !'e && !'f) - checkCondition(!(('a && 'b) || ('c && 'd)), (!'a || !'b) && (!'c || !'d)) + checkCondition(!(('e && 'f) || ('g && 'h)), (!'e || !'f) && (!'g || !'h)) - checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) + checkCondition(!(('e || 'f) && ('g || 'h)), (!'e && !'f) || (!'g && !'h)) } private val caseInsensitiveConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> false) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index a0a0daea7d..0b419e9631 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -266,8 +266,8 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning on Window with useless aggregate functions") { val input = 
LocalRelation('a.int, 'b.string, 'c.double, 'd.int) - val winSpec = windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winSpec = windowSpec('a :: Nil, 'd.asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count('d), winSpec) val originalQuery = input.groupBy('a, 'c, 'd)('a, 'c, 'd, winExpr.as('window)).select('a, 'c) val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 81d2f3667e..94174eec8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -35,7 +35,6 @@ class ConstantPropagationSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantPropagation", FixedPoint(10), - ColumnPruning, ConstantPropagation, ConstantFolding, BooleanSimplification) :: Nil @@ -43,9 +42,9 @@ class ConstantPropagationSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - private val columnA = 'a.int - private val columnB = 'b.int - private val columnC = 'c.int + private val columnA = 'a + private val columnB = 'b + private val columnC = 'c test("basic test") { val query = testRelation @@ -160,7 +159,7 @@ class ConstantPropagationSuite extends PlanTest { val correctAnswer = testRelation .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)) + .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index d4d281e7e0..3553d23560 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -629,14 +629,14 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('c > 6)) + .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where('b >= 5) .generate(Explode('c_arr), true, false, Some("arr")) - .where('a + Rand(10).as("rnd") > 6 && 'c > 6) + .where('a + Rand(10).as("rnd") > 6 && 'col > 6) .analyze } @@ -676,7 +676,7 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = { testRelationWithArrayType .generate(Explode('c_arr), true, false, Some("arr")) - .where(('c > 6) || ('b > 5)).analyze + .where(('col > 6) || ('b > 5)).analyze } val optimized = Optimize.execute(originalQuery) @@ -1129,6 +1129,9 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = x.where("x.a".attr === 5).join(y.where("y.a".attr === 5), condition = Some("x.a".attr === Rand(10) && "y.b".attr === 5)) - comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze) + // CheckAnalysis will ensure nondeterministic expressions not appear in join condition. + // TODO support nondeterministic expressions in join condition. 
+ comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze, + checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 2885fd6841..fb34c82de4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -70,19 +70,21 @@ class LimitPushdownSuite extends PlanTest { } test("Union: no limit to both sides if children having smaller limit values") { - val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2) + val unionQuery = + Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + Limit(2, Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("Union: limit to each sides if children having larger limit values") { - val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) - val unionQuery = testLimitUnion.limit(2) + val unionQuery = + Union(testRelation.limit(3), testRelation2.select('d, 'e, 'f).limit(4)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze + Limit(2, Union( + LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d, 'e, 'f)))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala index f3b65cc797..9dc6738ba0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -50,10 +50,10 @@ class OptimizeCodegenSuite extends PlanTest { test("Nested CaseWhen Codegen.") { assertEquivalent( CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))), + Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral), Literal(3))), CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), CaseWhen( - Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))), + Seq((CaseWhen(Seq((TrueLiteral, TrueLiteral)), FalseLiteral).toCodegen(), Literal(3))), CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index a37bc4bca2..623ff3d446 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -201,7 +201,7 @@ class OuterJoinEliminationSuite extends PlanTest { val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) - .where(Coalesce("y.e".attr :: "x.a".attr :: Nil)) + .where(Coalesce("y.e".attr :: "x.a".attr :: Nil) === 0) val optimized = Optimize.execute(originalQuery.analyze) @@ -209,7 +209,7 @@ class OuterJoinEliminationSuite extends PlanTest { val right = testRelation1 val correctAnswer = left.join(right, FullOuter, Option("a".attr === "d".attr)) - .where(Coalesce("e".attr :: "a".attr :: Nil)).analyze + .where(Coalesce("e".attr :: "a".attr :: Nil) === 
0).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index e84f11272d..7b3f5b084b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -44,7 +44,9 @@ class SimplifyCastsSuite extends PlanTest { val input = LocalRelation('a.array(ArrayType(IntegerType, true))) val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze val optimized = Optimize.execute(plan) - comparePlans(optimized, plan) + // Though cast from `ArrayType(IntegerType, true)` to `ArrayType(IntegerType, false)` is not + // allowed, here we just ensure that `SimplifyCasts` rule respect the plan. + comparePlans(optimized, plan, checkAnalysis = false) } test("non-nullable value map to nullable value map cast") { @@ -61,7 +63,10 @@ class SimplifyCastsSuite extends PlanTest { val plan = input.select('m.cast(MapType(StringType, StringType, false)) .as("casted")).analyze val optimized = Optimize.execute(plan) - comparePlans(optimized, plan) + // Though cast from `MapType(StringType, StringType, true)` to + // `MapType(StringType, StringType, false)` is not allowed, here we just ensure that + // `SimplifyCasts` rule respect the plan. 
+ comparePlans(optimized, plan, checkAnalysis = false) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fef39a5b6a..0a4ae098d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -29,13 +29,13 @@ import org.apache.spark.sql.types.IntegerType * * There is also SparkSqlParserSuite in sql/core module for parser rules defined in sql/core module. 
*/ -class PlanParserSuite extends PlanTest { +class PlanParserSuite extends AnalysisTest { import CatalystSqlParser._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ private def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan) + comparePlans(parsePlan(sqlCommand), plan, checkAnalysis = false) } private def intercept(sqlCommand: String, messages: String*): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index f44428c351..25313af2be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ @@ -90,7 +91,16 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** Fails the test if the two plans do not match */ - protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { + protected def comparePlans( + plan1: LogicalPlan, + plan2: LogicalPlan, + checkAnalysis: Boolean = true): Unit = { + if (checkAnalysis) { + // Make sure both plan pass checkAnalysis. 
+ SimpleAnalyzer.checkAnalysis(plan1) + SimpleAnalyzer.checkAnalysis(plan2) + } + val normalized1 = normalizePlan(normalizeExprIds(plan1)) val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { @@ -104,7 +114,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { /** Fails the test if the two expressions do not match */ protected def compareExpressions(e1: Expression, e2: Expression): Unit = { - comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation)) + comparePlans(Filter(e1, OneRowRelation), Filter(e2, OneRowRelation), checkAnalysis = false) } /** Fails the test if the join order in the two plans do not match */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala index 60f6f23860..0dd5bdcba2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameHintSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.test.SharedSQLContext -class DataFrameHintSuite extends PlanTest with SharedSQLContext { +class DataFrameHintSuite extends AnalysisTest with SharedSQLContext { import testImplicits._ lazy val df = spark.range(10) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index b32fb90e10..bd9c2ebd6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SaveMode import 
org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder} import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable @@ -36,7 +35,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType * See [[org.apache.spark.sql.catalyst.parser.PlanParserSuite]] for rules * defined in the Catalyst module. 
*/ -class SparkSqlParserSuite extends PlanTest { +class SparkSqlParserSuite extends AnalysisTest { val newConf = new SQLConf private lazy val parser = new SparkSqlParser(newConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index d97b11e447..bee470d8e1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, ScriptTransformation} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} @@ -59,6 +59,11 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle }.head } + private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = { + val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null) + comparePlans(plan, expected, checkAnalysis = false) + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view @@ -253,22 +258,15 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle } test("transform query spec") { - val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = 
null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val p = ScriptTransformation( Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), "func", Seq.empty, plans.table("e"), null) - comparePlans(plan1, + compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - comparePlans(plan2, + compareTransformQuery("map a, b using 'func' as c, d from e", p.copy(output = Seq('c.string, 'd.string))) - comparePlans(plan3, + compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) } -- GitLab