Commit 01ff0350 authored by Xiao Li

[SPARK-20349][SQL] ListFunctions returns duplicate functions after using persistent functions

### What changes were proposed in this pull request?
The session catalog caches some persistent functions in the `FunctionRegistry`, so there can be duplicates. Our Catalog API `listFunctions` does not handle them.

It would be better if the `SessionCatalog` API de-duplicated the records itself, instead of leaving it to each API caller. In `FunctionRegistry`, functions are identified by their unquoted string names. Thus, this PR parses each name using our parser interface and then de-duplicates the resulting identifiers.
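For illustration, here is a minimal sketch (not part of the patch; the helper name is made up) of the normalization this relies on: parsing the registry's unquoted strings into `FunctionIdentifier`s makes them comparable with the identifiers built from the external catalog, so duplicates can be removed with `distinct`.

```scala
import scala.util.{Failure, Success, Try}

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.parser.ParserInterface

// Sketch: normalize an unquoted registry name into a FunctionIdentifier.
// "db.fn" parses to FunctionIdentifier("fn", Some("db")); names the parser
// cannot handle (e.g. built-ins like "%") fall back to a raw identifier.
def toIdentifier(parser: ParserInterface, name: String): FunctionIdentifier =
  Try(parser.parseFunctionIdentifier(name)) match {
    case Success(ident) => ident
    case Failure(_) => FunctionIdentifier(name)
  }
```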

### How was this patch tested?
Added test cases.

Author: Xiao Li <gatorsmile@gmail.com>

Closes #17646 from gatorsmile/showFunctions.
parent 24f09b39
```diff
@@ -22,6 +22,7 @@ import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
+import scala.util.{Failure, Success, Try}
 
 import com.google.common.cache.{Cache, CacheBuilder}
 import org.apache.hadoop.conf.Configuration
@@ -1202,15 +1203,25 @@ class SessionCatalog(
   def listFunctions(db: String, pattern: String): Seq[(FunctionIdentifier, String)] = {
     val dbName = formatDatabaseName(db)
     requireDbExists(dbName)
-    val dbFunctions = externalCatalog.listFunctions(dbName, pattern)
-      .map { f => FunctionIdentifier(f, Some(dbName)) }
-    val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
-      .map { f => FunctionIdentifier(f) }
+    val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f =>
+      FunctionIdentifier(f, Some(dbName)) }
+    val loadedFunctions =
+      StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f =>
+        // In functionRegistry, function names are stored as an unquoted format.
+        Try(parser.parseFunctionIdentifier(f)) match {
+          case Success(e) => e
+          case Failure(_) =>
+            // The names of some built-in functions are not parsable by our parser, e.g., %
+            FunctionIdentifier(f)
+        }
+      }
     val functions = dbFunctions ++ loadedFunctions
+    // The session catalog caches some persistent functions in the FunctionRegistry
+    // so there can be duplicates.
     functions.map {
       case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM")
       case f => (f, "USER")
-    }
+    }.distinct
   }
```
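For context, a minimal sketch of the scenario this de-duplication addresses (the UDF class and function names are hypothetical, and a Hive-enabled session is assumed): once a persistent function is used, the session catalog caches it in the `FunctionRegistry`, so before this change it could be listed twice.

```scala
import org.apache.spark.sql.SparkSession

object ListFunctionsDedupDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()

    // Hypothetical persistent UDF; the class name is illustrative only.
    spark.sql("CREATE FUNCTION myUpper AS 'com.example.MyUpperUDF'")
    spark.sql("SELECT myUpper('a')")  // first use caches it in the FunctionRegistry

    // Before this patch the cached copy could show up as a second entry;
    // after it, each function is listed exactly once.
    val names = spark.catalog.listFunctions().collect().map(_.name)
    assert(names.distinct.length == names.length)
  }
}
```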
```diff
@@ -207,8 +207,6 @@ case class ShowFunctionsCommand(
       case (f, "USER") if showUserFunctions => f.unquotedString
       case (f, "SYSTEM") if showSystemFunctions => f.unquotedString
     }
-    // The session catalog caches some persistent functions in the FunctionRegistry
-    // so there can be duplicates.
-    functionNames.distinct.sorted.map(Row(_))
+    functionNames.sorted.map(Row(_))
   }
 }
```
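With de-duplication handled inside `SessionCatalog.listFunctions`, the command only needs to sort the names; every other caller of the catalog API (e.g. `spark.catalog.listFunctions()`) now gets de-duplicated results as well.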
```diff
@@ -573,6 +573,23 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
       checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1))
     }
   }
+
+  test("Show persistent functions") {
+    val testData = spark.sparkContext.parallelize(StringCaseClass("") :: Nil).toDF()
+    withTempView("inputTable") {
+      testData.createOrReplaceTempView("inputTable")
+      withUserDefinedFunction("testUDFToListInt" -> false) {
+        val numFunc = spark.catalog.listFunctions().count()
+        sql(s"CREATE FUNCTION testUDFToListInt AS '${classOf[UDFToListInt].getName}'")
+        assert(spark.catalog.listFunctions().count() == numFunc + 1)
+        checkAnswer(
+          sql("SELECT testUDFToListInt(s) FROM inputTable"),
+          Seq(Row(Seq(1, 2, 3))))
+        assert(sql("show functions").count() == numFunc + 1)
+        assert(spark.catalog.listFunctions().count() == numFunc + 1)
+      }
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
```