diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala index 527626beeb596fad279388ddf07a05d96c7f8266..93fc5e8a5e37628e39d93817da6fab5aa755ce65 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ObjectHashAggregateSuite.scala @@ -25,11 +25,10 @@ import org.scalatest.Matchers._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction -import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, ExpressionInfo, Literal} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveSessionCatalog import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -43,6 +42,14 @@ class ObjectHashAggregateSuite import testImplicits._ + protected override def beforeAll(): Unit = { + sql(s"CREATE TEMPORARY FUNCTION hive_max AS '${classOf[GenericUDAFMax].getName}'") + } + + protected override def afterAll(): Unit = { + sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_max") + } + test("typed_count without grouping keys") { val df = Seq((1: Integer, 2), (null, 2), (3: Integer, 4)).toDF("a", "b") @@ -199,10 +206,7 @@ class ObjectHashAggregateSuite val typed = percentile_approx($"c0", 0.5) // A Hive UDAF without partial aggregation support - val withoutPartial = { - registerHiveFunction("hive_max", classOf[GenericUDAFMax]) - function("hive_max", $"c1") - } + val withoutPartial = function("hive_max", $"c1") // A Spark SQL native aggregate function with partial aggregation support that can be executed // by the Tungsten `HashAggregateExec` @@ -420,13 +424,6 @@ class ObjectHashAggregateSuite } } - private def registerHiveFunction(functionName: String, clazz: Class[_]): Unit = { - val sessionCatalog = spark.sessionState.catalog.asInstanceOf[HiveSessionCatalog] - val builder = sessionCatalog.makeFunctionBuilder(functionName, clazz.getName) - val info = new ExpressionInfo(clazz.getName, functionName) - sessionCatalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) - } - private def function(name: String, args: Column*): Column = { Column(UnresolvedFunction(FunctionIdentifier(name), args.map(_.expr), isDistinct = false)) }