Commit 29b1f6b0 authored by gatorsmile

[SPARK-21256][SQL] Add withSQLConf to Catalyst Test

### What changes were proposed in this pull request?
SQLConf has been moved to Catalyst, and we keep adding test cases that verify conf-specific behaviors. This PR adds a withSQLConf helper function to Catalyst's PlanTest so those test cases no longer need to save, set, and restore the conf by hand.
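For illustration, this is the boilerplate the helper removes; both snippets follow the pattern in the diffs below, with the test body elided:

    // Before: manual save/set/restore around every conf-sensitive block.
    try {
      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
      // ... assertions that depend on the conf ...
    } finally {
      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
    }

    // After: withSQLConf scopes the change and restores the previous value,
    // even if the body throws.
    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
      // ... assertions that depend on the conf ...
    }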

### How was this patch tested?
N/A

Author: gatorsmile <gatorsmile@gmail.com>

Closes #18469 from gatorsmile/withSQLConf.
parent d492cc5a
Showing 9 changed files with 64 additions and 63 deletions
@@ -206,13 +206,10 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
   }
 
   test("No inferred filter when constraint propagation is disabled") {
-    try {
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze
       val optimized = Optimize.execute(originalQuery)
       comparePlans(optimized, originalQuery)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
-    }
+    }
   }
 }
@@ -234,9 +234,7 @@ class OuterJoinEliminationSuite extends PlanTest {
   }
 
   test("no outer join elimination if constraint propagation is disabled") {
-    try {
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val x = testRelation.subquery('x)
       val y = testRelation1.subquery('y)
 
@@ -251,8 +249,6 @@ class OuterJoinEliminationSuite extends PlanTest {
       val optimized = Optimize.execute(originalQuery.analyze)
 
       comparePlans(optimized, originalQuery.analyze)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
-    }
+    }
   }
 }
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.SQLConf.CONSTRAINT_PROPAGATION_ENABLED
 
 class PruneFiltersSuite extends PlanTest {
@@ -149,8 +148,7 @@ class PruneFiltersSuite extends PlanTest {
         ("tr1.a".attr > 10 || "tr1.c".attr < 10) &&
           'd.attr < 100)
 
-    SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
-    try {
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       val optimized = Optimize.execute(queryWithUselessFilter.analyze)
       // When constraint propagation is disabled, the useless filter won't be pruned.
       // It gets pushed down. Because the rule `CombineFilters` runs only once, there are redundant
@@ -160,8 +158,6 @@
           .join(tr2.where('d.attr < 100).where('d.attr < 100),
             Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze
       comparePlans(optimized, correctAnswer)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
-    }
+    }
   }
 }
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType}
 
-class ConstraintPropagationSuite extends SparkFunSuite {
+class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
 
   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
     resolveColumn(tr.analyze, columnName)
@@ -400,26 +400,26 @@ class ConstraintPropagationSuite extends SparkFunSuite {
   }
 
   test("enable/disable constraint propagation") {
-    try {
-      val tr = LocalRelation('a.int, 'b.string, 'c.int)
-      val filterRelation = tr.where('a.attr > 10)
+    val tr = LocalRelation('a.int, 'b.string, 'c.int)
+    val filterRelation = tr.where('a.attr > 10)
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
       assert(filterRelation.analyze.constraints.nonEmpty)
+    }
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       assert(filterRelation.analyze.constraints.isEmpty)
+    }
 
-      val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
-        .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
+    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3)
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, true)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") {
       assert(aliasedRelation.analyze.constraints.nonEmpty)
+    }
 
-      SQLConf.get.setConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED, false)
+    withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") {
       assert(aliasedRelation.analyze.constraints.isEmpty)
-    } finally {
-      SQLConf.get.unsetConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED)
     }
   }
 }
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -28,8 +29,9 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * Provides helper methods for comparing plans.
  */
-abstract class PlanTest extends SparkFunSuite with PredicateHelper {
+trait PlanTest extends SparkFunSuite with PredicateHelper {
 
+  // TODO(gatorsmile): remove this from PlanTest and all the analyzer/optimizer rules
   protected val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)
 
   /**
@@ -142,4 +144,32 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
       plan1 == plan2
     }
   }
+
+  /**
+   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
+   * configurations.
+   */
+  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    val conf = SQLConf.get
+    val (keys, values) = pairs.unzip
+    val currentValues = keys.map { key =>
+      if (conf.contains(key)) {
+        Some(conf.getConfString(key))
+      } else {
+        None
+      }
+    }
+    (keys, values).zipped.foreach { (k, v) =>
+      if (SQLConf.staticConfKeys.contains(k)) {
+        throw new AnalysisException(s"Cannot modify the value of a static config: $k")
+      }
+      conf.setConfString(k, v)
+    }
+    try f finally {
+      keys.zip(currentValues).foreach {
+        case (key, Some(value)) => conf.setConfString(key, value)
+        case (key, None) => conf.unsetConf(key)
+      }
+    }
+  }
 }
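The helper records each previous value as an Option[String], which lets restoration distinguish a key that was explicitly set (put back with setConfString) from one that was unset (cleared with unsetConf). A minimal, hypothetical usage sketch from a Catalyst suite (the suite name and test body are illustrative; SQLConf.CASE_SENSITIVE is an existing non-static conf):

    class MyRuleSuite extends PlanTest {
      test("conf-specific behavior") {
        withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
          // Inside the block the override is visible through SQLConf.get.
          assert(SQLConf.get.getConfString(SQLConf.CASE_SENSITIVE.key) == "false")
        }
        // On exit the previous value (or the unset state) is restored,
        // even if the body throws.
      }
    }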
@@ -19,12 +19,13 @@ package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
 import org.apache.spark.sql.internal.SQLConf
 
-class AggregateEstimationSuite extends StatsEstimationTestBase {
+class AggregateEstimationSuite extends StatsEstimationTestBase with PlanTest {
 
   /** Columns for testing */
   private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq(
@@ -100,9 +101,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       size = Some(4 * (8 + 4)),
       attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
 
-    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
-    try {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
       val noGroupAgg = Aggregate(groupingExpressions = Nil,
         aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
       assert(noGroupAgg.stats ==
@@ -114,8 +113,6 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       assert(hasGroupAgg.stats ==
         // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
         Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
-    } finally {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
-    }
+    }
   }
@@ -18,12 +18,13 @@
 package org.apache.spark.sql.catalyst.statsEstimation
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType
 
-class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase {
   val attribute = attr("key")
   val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
     nullCount = 0, avgLen = 4, maxLen = 4)
@@ -82,18 +83,15 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
       plan: LogicalPlan,
       expectedStatsCboOn: Statistics,
       expectedStatsCboOff: Statistics): Unit = {
-    val originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
-    try {
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
       // Invalidate statistics
       plan.invalidateStatsCache()
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
       assert(plan.stats == expectedStatsCboOn)
+    }
 
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "false") {
       plan.invalidateStatsCache()
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, false)
       assert(plan.stats == expectedStatsCboOff)
-    } finally {
-      SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
     }
   }
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Test cases for the builder pattern of [[SparkSession]].
@@ -67,6 +68,8 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
     assert(activeSession != defaultSession)
     assert(session == activeSession)
     assert(session.conf.get("spark-config2") == "a")
+    assert(session.sessionState.conf == SQLConf.get)
+    assert(SQLConf.get.getConfString("spark-config2") == "a")
     SparkSession.clearActiveSession()
 
     assert(SparkSession.builder().getOrCreate() == defaultSession)
@@ -35,9 +35,11 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
 import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.{UninterruptibleThread, Utils}
 
 /**
@@ -53,7 +55,8 @@ import org.apache.spark.util.{UninterruptibleThread, Utils}
 private[sql] trait SQLTestUtils
   extends SparkFunSuite with Eventually
   with BeforeAndAfterAll
-  with SQLTestData { self =>
+  with SQLTestData
+  with PlanTest { self =>
 
   protected def sparkContext = spark.sparkContext
 
@@ -89,28 +92,9 @@ private[sql] trait SQLTestUtils
     }
   }
 
-  /**
-   * Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
-   * configurations.
-   *
-   * @todo Probably this method should be moved to a more general place
-   */
-  protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
-    val (keys, values) = pairs.unzip
-    val currentValues = keys.map { key =>
-      if (spark.conf.contains(key)) {
-        Some(spark.conf.get(key))
-      } else {
-        None
-      }
-    }
-    (keys, values).zipped.foreach(spark.conf.set)
-    try f finally {
-      keys.zip(currentValues).foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
-      }
-    }
-  }
+  protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
+    SparkSession.setActiveSession(spark)
+    super.withSQLConf(pairs: _*)(f)
+  }
 
   /**