Skip to content
Snippets Groups Projects
Commit 205e6d58 authored by Cheng Lian's avatar Cheng Lian Committed by Yin Huai
Browse files

[SPARK-18338][SQL][TEST-MAVEN] Fix test case initialization order under Maven builds

## What changes were proposed in this pull request?

Test case initialization order under Maven and SBT are different. Maven always creates instances of all test cases and then run them all together.

This fails `ObjectHashAggregateSuite` because the randomized test cases there register a temporary Hive function right before creating a test case, and can be cleared while initializing other successive test cases. In SBT, this is fine since the created test case is executed immediately after creating the temporary function.

To fix this issue, we should put initialization/destruction code into `beforeAll()` and `afterAll()`.

## How was this patch tested?

Existing tests.

Author: Cheng Lian <lian@databricks.com>

Closes #15802 from liancheng/fix-flaky-object-hash-agg-suite.
parent 02c5325b
No related branches found
No related tags found
No related merge requests found
...@@ -25,11 +25,10 @@ import org.scalatest.Matchers._ ...@@ -25,11 +25,10 @@ import org.scalatest.Matchers._
import org.apache.spark.sql._ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction
import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, ExpressionInfo, Literal} import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HiveSessionCatalog
import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.test.SQLTestUtils
...@@ -43,6 +42,14 @@ class ObjectHashAggregateSuite ...@@ -43,6 +42,14 @@ class ObjectHashAggregateSuite
import testImplicits._ import testImplicits._
protected override def beforeAll(): Unit = {
sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'")
}
protected override def afterAll(): Unit = {
sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max")
}
test("typed_count without grouping keys") { test("typed_count without grouping keys") {
val df = Seq((1: Integer, 2), (null, 2), (3: Integer, 4)).toDF("a", "b") val df = Seq((1: Integer, 2), (null, 2), (3: Integer, 4)).toDF("a", "b")
...@@ -199,10 +206,7 @@ class ObjectHashAggregateSuite ...@@ -199,10 +206,7 @@ class ObjectHashAggregateSuite
val typed = percentile_approx($"c0", 0.5) val typed = percentile_approx($"c0", 0.5)
// A Hive UDAF without partial aggregation support // A Hive UDAF without partial aggregation support
val withoutPartial = { val withoutPartial = function("hive_max", $"c1")
registerHiveFunction("hive_max", classOf[GenericUDAFMax])
function("hive_max", $"c1")
}
// A Spark SQL native aggregate function with partial aggregation support that can be executed // A Spark SQL native aggregate function with partial aggregation support that can be executed
// by the Tungsten `HashAggregateExec` // by the Tungsten `HashAggregateExec`
...@@ -420,13 +424,6 @@ class ObjectHashAggregateSuite ...@@ -420,13 +424,6 @@ class ObjectHashAggregateSuite
} }
} }
private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = {
val sessionCatalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog]
val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName)
val info = new ExpressionInfo(clazz.getName, functionName)
sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false)
}
private def function(name: String, args: Column*): Column = { private def function(name: String, args: Column*): Column = {
Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false)) Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false))
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment