Commit 2432c2e2 authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-8382] [SQL] Improve Analysis Unit test framework

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8025 from cloud-fan/analysis and squashes the following commits:

51461b1 [Wenchen Fan] move test file to test folder
ec88ace [Wenchen Fan] Improve Analysis Unit test framework
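
Throughout the hunks below, the inline `intercept[AnalysisException]` checks in AnalysisErrorSuite are replaced by calls to an `assertAnalysisError` helper provided through `AnalysisTest`. The helper's definition falls outside the hunks shown on this page, so the following is only a rough sketch of how such a helper could be written against the two analyzers the trait already builds (the wrapper trait name `AnalysisErrorChecks` is invented for illustration); it is not the committed implementation.

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch only: the committed helper lives in AnalysisTest. This mirrors the checks that the
// hunks below delete from AnalysisErrorSuite, reusing the trait's case-(in)sensitive analyzers.
trait AnalysisErrorChecks { self: AnalysisTest =>
  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    // Pick the analyzer that AnalysisTest constructs (see the @@ -59,8 +30,8 @@ hunk below).
    val analyzer = if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
    // Analysis of an invalid plan must fail; match each expected fragment case-insensitively,
    // exactly as the old inline assertions did.
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }
    expectedErrors.foreach { m =>
      assert(e.getMessage.toLowerCase.contains(m.toLowerCase))
    }
  }
}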
parent 76eaa701
AnalysisErrorSuite.scala

@@ -42,8 +42,8 @@ case class UnresolvedTestPlan() extends LeafNode {
   override def output: Seq[Attribute] = Nil
 }
 
-class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
-  import AnalysisSuite._
+class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
+  import TestRelations._
 
   def errorTest(
       name: String,
@@ -51,15 +51,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
       errorMessages: Seq[String],
       caseSensitive: Boolean = true): Unit = {
     test(name) {
-      val error = intercept[AnalysisException] {
-        if (caseSensitive) {
-          caseSensitiveAnalyze(plan)
-        } else {
-          caseInsensitiveAnalyze(plan)
-        }
-      }
-      errorMessages.foreach(m => assert(error.getMessage.toLowerCase.contains(m.toLowerCase)))
+      assertAnalysisError(plan, errorMessages, caseSensitive)
     }
   }
@@ -69,21 +61,21 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
     "single invalid type, single arg",
     testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
     "cannot resolve" :: "testfunction" :: "argument 1" :: "requires int type" ::
-      "'null' is of date type" ::Nil)
+      "'null' is of date type" :: Nil)
 
   errorTest(
     "single invalid type, second arg",
     testRelation.select(
       TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
     "cannot resolve" :: "testfunction" :: "argument 2" :: "requires int type" ::
-      "'null' is of date type" ::Nil)
+      "'null' is of date type" :: Nil)
 
   errorTest(
     "multiple invalid type",
     testRelation.select(
       TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
     "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
-      "requires int type" :: "'null' is of date type" ::Nil)
+      "requires int type" :: "'null' is of date type" :: Nil)
 
   errorTest(
     "unresolved window function",
@@ -169,11 +161,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
     assert(plan.resolved)
 
-    val message = intercept[AnalysisException] {
-      caseSensitiveAnalyze(plan)
-    }.getMessage
-    assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
+    assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil)
   }
 
   test("error test for self-join") {
@@ -194,10 +182,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
         AttributeReference("a", BinaryType)(exprId = ExprId(2)),
         AttributeReference("b", IntegerType)(exprId = ExprId(1))))
-    val error = intercept[AnalysisException] {
-      caseSensitiveAnalyze(plan)
-    }
-    assert(error.message.contains("binary type expression a cannot be used in grouping expression"))
+    assertAnalysisError(plan,
+      "binary type expression a cannot be used in grouping expression" :: Nil)
 
     val plan2 =
       Aggregate(
@@ -207,10 +193,8 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
         AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
         AttributeReference("b", IntegerType)(exprId = ExprId(1))))
-    val error2 = intercept[AnalysisException] {
-      caseSensitiveAnalyze(plan2)
-    }
-    assert(error2.message.contains("map type expression a cannot be used in grouping expression"))
+    assertAnalysisError(plan2,
+      "map type expression a cannot be used in grouping expression" :: Nil)
   }
 
   test("Join can't work on binary and map types") {
@@ -226,10 +210,7 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
         Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
           AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
-    val error = intercept[AnalysisException] {
-      caseSensitiveAnalyze(plan)
-    }
-    assert(error.message.contains("binary type expression a cannot be used in join conditions"))
+    assertAnalysisError(plan, "binary type expression a cannot be used in join conditions" :: Nil)
 
     val plan2 =
       Join(
@@ -243,9 +224,6 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
         Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
           AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
-    val error2 = intercept[AnalysisException] {
-      caseSensitiveAnalyze(plan2)
-    }
-    assert(error2.message.contains("map type expression a cannot be used in join conditions"))
+    assertAnalysisError(plan2, "map type expression a cannot be used in join conditions" :: Nil)
   }
 }
AnalysisSuite.scala

@@ -24,61 +24,8 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 
-// todo: remove this and use AnalysisTest instead.
-object AnalysisSuite {
-  val caseSensitiveConf = new SimpleCatalystConf(true)
-  val caseInsensitiveConf = new SimpleCatalystConf(false)
-
-  val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
-  val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
-
-  val caseSensitiveAnalyzer =
-    new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
-      override val extendedResolutionRules = EliminateSubQueries :: Nil
-    }
-  val caseInsensitiveAnalyzer =
-    new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) {
-      override val extendedResolutionRules = EliminateSubQueries :: Nil
-    }
-
-  def caseSensitiveAnalyze(plan: LogicalPlan): Unit =
-    caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer.execute(plan))
-
-  def caseInsensitiveAnalyze(plan: LogicalPlan): Unit =
-    caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer.execute(plan))
-
-  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
-  val testRelation2 = LocalRelation(
-    AttributeReference("a", StringType)(),
-    AttributeReference("b", StringType)(),
-    AttributeReference("c", DoubleType)(),
-    AttributeReference("d", DecimalType(10, 2))(),
-    AttributeReference("e", ShortType)())
-
-  val nestedRelation = LocalRelation(
-    AttributeReference("top", StructType(
-      StructField("duplicateField", StringType) ::
-        StructField("duplicateField", StringType) ::
-        StructField("differentCase", StringType) ::
-        StructField("differentcase", StringType) :: Nil
-    ))())
-
-  val nestedRelation2 = LocalRelation(
-    AttributeReference("top", StructType(
-      StructField("aField", StringType) ::
-        StructField("bField", StringType) ::
-        StructField("cField", StringType) :: Nil
-    ))())
-
-  val listRelation = LocalRelation(
-    AttributeReference("list", ArrayType(IntegerType))())
-
-  caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
-  caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
-}
-
 class AnalysisSuite extends AnalysisTest {
+  import TestRelations._
+
   test("union project *") {
     val plan = (1 to 100)
...
AnalysisTest.scala

@@ -17,40 +17,11 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.SimpleCatalystConf
-import org.apache.spark.sql.types._
 
 trait AnalysisTest extends PlanTest {
-
-  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
-
-  val testRelation2 = LocalRelation(
-    AttributeReference("a", StringType)(),
-    AttributeReference("b", StringType)(),
-    AttributeReference("c", DoubleType)(),
-    AttributeReference("d", DecimalType(10, 2))(),
-    AttributeReference("e", ShortType)())
-
-  val nestedRelation = LocalRelation(
-    AttributeReference("top", StructType(
-      StructField("duplicateField", StringType) ::
-        StructField("duplicateField", StringType) ::
-        StructField("differentCase", StringType) ::
-        StructField("differentcase", StringType) :: Nil
-    ))())
-
-  val nestedRelation2 = LocalRelation(
-    AttributeReference("top", StructType(
-      StructField("aField", StringType) ::
-        StructField("bField", StringType) ::
-        StructField("cField", StringType) :: Nil
-    ))())
-
-  val listRelation = LocalRelation(
-    AttributeReference("list", ArrayType(IntegerType))())
-
   val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
     val caseSensitiveConf = new SimpleCatalystConf(true)
@@ -59,8 +30,8 @@ trait AnalysisTest extends PlanTest {
     val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
     val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
-    caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
-    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+    caseSensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation)
+    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation)
 
     new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
       override val extendedResolutionRules = EliminateSubQueries :: Nil
...
TestRelations.scala (new file)

+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.types._
+
+object TestRelations {
+  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+
+  val testRelation2 = LocalRelation(
+    AttributeReference("a", StringType)(),
+    AttributeReference("b", StringType)(),
+    AttributeReference("c", DoubleType)(),
+    AttributeReference("d", DecimalType(10, 2))(),
+    AttributeReference("e", ShortType)())
+
+  val nestedRelation = LocalRelation(
+    AttributeReference("top", StructType(
+      StructField("duplicateField", StringType) ::
+        StructField("duplicateField", StringType) ::
+        StructField("differentCase", StringType) ::
+        StructField("differentcase", StringType) :: Nil
+    ))())
+
+  val nestedRelation2 = LocalRelation(
+    AttributeReference("top", StructType(
+      StructField("aField", StringType) ::
+        StructField("bField", StringType) ::
+        StructField("cField", StringType) :: Nil
+    ))())
+
+  val listRelation = LocalRelation(
+    AttributeReference("list", ArrayType(IntegerType))())
+}
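
With the fixtures extracted into TestRelations, a suite can mix in AnalysisTest and import the shared relations instead of redefining them, as AnalysisErrorSuite and AnalysisSuite now do above. A minimal sketch of such a suite follows; the suite name and test are invented for illustration and are not part of this commit.

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// Hypothetical suite: the fixtures come from TestRelations, the analyzers from AnalysisTest.
class ProjectionAnalysisSuite extends AnalysisTest {
  import TestRelations._

  test("a projection over the shared relation resolves") {
    // 'a refers to the single integer column that TestRelations.testRelation defines.
    val analyzed = caseSensitiveAnalyzer.execute(testRelation.select('a))
    assert(analyzed.resolved)
  }
}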
BooleanSimplificationSuite.scala

@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, EliminateSubQueries}
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -88,20 +89,24 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
       ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5))
   }
 
-  private def caseInsensitiveAnalyse(plan: LogicalPlan) =
-    AnalysisSuite.caseInsensitiveAnalyzer.execute(plan)
+  private val caseInsensitiveAnalyzer =
+    new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false))
 
   test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
-    val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5)))
+    val plan = caseInsensitiveAnalyzer.execute(
+      testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5)))
     val actual = Optimize.execute(plan)
-    val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
+    val expected = caseInsensitiveAnalyzer.execute(
+      testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
     comparePlans(actual, expected)
   }
 
   test("(a || b) && (a || c) => a || (b && c) when case insensitive") {
-    val plan = caseInsensitiveAnalyse(testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
+    val plan = caseInsensitiveAnalyzer.execute(
+      testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
     val actual = Optimize.execute(plan)
-    val expected = caseInsensitiveAnalyse(testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
+    val expected = caseInsensitiveAnalyzer.execute(
+      testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
     comparePlans(actual, expected)
   }
 }