From f7c145d8ce14b23019099c509d5a2b6dfb1fe62c Mon Sep 17 00:00:00 2001
From: Herman van Hovell <hvanhovell@databricks.com>
Date: Tue, 1 Nov 2016 15:41:45 +0100
Subject: [PATCH] [SPARK-17996][SQL] Fix unqualified catalog.getFunction(...)

## What changes were proposed in this pull request?

Currently an unqualified `getFunction(..)` call returns an incorrect result: the returned function is reported as a temporary function without a database. For example:

```
scala> sql("create function fn1 as 'org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs'")
res0: org.apache.spark.sql.DataFrame = []

scala> spark.catalog.getFunction("fn1")
res1: org.apache.spark.sql.catalog.Function = Function[name='fn1', className='org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs', isTemporary='true']
```
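
With this patch, the same lookup reports the function's database and marks it as persistent. A sketch of the expected output, assuming `default` is the current database:

```
scala> spark.catalog.getFunction("fn1")
res1: org.apache.spark.sql.catalog.Function = Function[name='fn1', database='default', className='org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs', isTemporary='false']
```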

This PR fixes this by adding the database to `ExpressionInfo` (which is used to store the function metadata).
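
A minimal sketch of the resulting `ExpressionInfo` behaviour (the constructor shapes match the Java diff below; a null database marks a temporary function):

```
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo

val persistent = new ExpressionInfo(
  "org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs", "default", "fn1")
val temporary = new ExpressionInfo(
  "org.apache.hadoop.hive.ql.udf.generic.GenericUDFAbs", "fn1")
assert(persistent.getDb == "default")
assert(temporary.getDb == null) // CatalogImpl maps this to isTemporary = true
```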

## How was this patch tested?

Added more thorough tests to `CatalogSuite`.
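These can be run with sbt, for example `build/sbt "sql/testOnly *CatalogSuite"` (one possible invocation).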

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15542 from hvanhovell/SPARK-17996.
---
 .../sql/catalyst/expressions/ExpressionInfo.java  | 14 ++++++++++++--
 .../sql/catalyst/analysis/FunctionRegistry.scala  |  2 +-
 .../sql/catalyst/catalog/SessionCatalog.scala     | 10 ++++++++--
 .../spark/sql/execution/command/functions.scala   |  5 +++--
 .../apache/spark/sql/internal/CatalogImpl.scala   |  6 +++---
 .../apache/spark/sql/internal/CatalogSuite.scala  | 15 ++++++++++++---
 6 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
index ba8e9cb4be..4565ed4487 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java
@@ -25,6 +25,7 @@ public class ExpressionInfo {
     private String usage;
     private String name;
     private String extended;
+    private String db;
 
     public String getClassName() {
         return className;
@@ -42,14 +43,23 @@ public class ExpressionInfo {
         return extended;
     }
 
-    public ExpressionInfo(String className, String name, String usage, String extended) {
+    public String getDb() {
+        return db;
+    }
+
+    public ExpressionInfo(String className, String db, String name, String usage, String extended) {
         this.className = className;
+        this.db = db;
         this.name = name;
         this.usage = usage;
         this.extended = extended;
     }
 
     public ExpressionInfo(String className, String name) {
-        this(className, name, null, null);
+        this(className, null, name, null, null);
+    }
+
+    public ExpressionInfo(String className, String db, String name) {
+        this(className, db, name, null, null);
     }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b05f4f61f6..3e836ca375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -495,7 +495,7 @@ object FunctionRegistry {
     val clazz = scala.reflect.classTag[T].runtimeClass
     val df = clazz.getAnnotation(classOf[ExpressionDescription])
     if (df != null) {
-      new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended())
+      new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended())
     } else {
       new ExpressionInfo(clazz.getCanonicalName, name)
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 3d6eec81c0..714ef825ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -943,7 +943,10 @@ class SessionCatalog(
         requireDbExists(db)
         if (externalCatalog.functionExists(db, name.funcName)) {
           val metadata = externalCatalog.getFunction(db, name.funcName)
-          new ExpressionInfo(metadata.className, qualifiedName.unquotedString)
+          new ExpressionInfo(
+            metadata.className,
+            qualifiedName.database.orNull,
+            qualifiedName.identifier)
         } else {
           failFunctionLookup(name.funcName)
         }
@@ -1000,7 +1003,10 @@ class SessionCatalog(
     // catalog. So, it is possible that qualifiedName is not exactly the same as
     // catalogFunction.identifier.unquotedString (difference is on case-sensitivity).
     // At here, we preserve the input from the user.
-    val info = new ExpressionInfo(catalogFunction.className, qualifiedName.unquotedString)
+    val info = new ExpressionInfo(
+      catalogFunction.className,
+      qualifiedName.database.orNull,
+      qualifiedName.funcName)
     val builder = makeFunctionBuilder(qualifiedName.unquotedString, catalogFunction.className)
     createTempFunction(qualifiedName.unquotedString, info, builder, ignoreIfExists = false)
     // Now, we need to create the Expression.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
index 26593d2918..24d825f5cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala
@@ -118,14 +118,15 @@ case class DescribeFunctionCommand(
       case _ =>
         try {
           val info = sparkSession.sessionState.catalog.lookupFunctionInfo(functionName)
+          val name = if (info.getDb != null) info.getDb + "." + info.getName else info.getName
           val result =
-            Row(s"Function: ${info.getName}") ::
+            Row(s"Function: $name") ::
               Row(s"Class: ${info.getClassName}") ::
               Row(s"Usage: ${replaceFunctionName(info.getUsage, info.getName)}") :: Nil
 
           if (isExtended) {
             result :+
-              Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}")
+              Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, name)}")
           } else {
             result
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index f6c297e91b..44fd38dfb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -133,11 +133,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   private def makeFunction(funcIdent: FunctionIdentifier): Function = {
     val metadata = sessionCatalog.lookupFunctionInfo(funcIdent)
     new Function(
-      name = funcIdent.identifier,
-      database = funcIdent.database.orNull,
+      name = metadata.getName,
+      database = metadata.getDb,
       description = null, // for now, this is always undefined
       className = metadata.getClassName,
-      isTemporary = funcIdent.database.isEmpty)
+      isTemporary = metadata.getDb == null)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 214bc736bd..89ec162c8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -386,15 +386,24 @@ class CatalogSuite
         createFunction("fn2", Some(db))
 
         // Find a temporary function
-        assert(spark.catalog.getFunction("fn1").name === "fn1")
+        val fn1 = spark.catalog.getFunction("fn1")
+        assert(fn1.name === "fn1")
+        assert(fn1.database === null)
+        assert(fn1.isTemporary)
 
         // Find a qualified function
-        assert(spark.catalog.getFunction(db, "fn2").name === "fn2")
+        val fn2 = spark.catalog.getFunction(db, "fn2")
+        assert(fn2.name === "fn2")
+        assert(fn2.database === db)
+        assert(!fn2.isTemporary)
 
         // Find an unqualified function using the current database
         intercept[AnalysisException](spark.catalog.getFunction("fn2"))
         spark.catalog.setCurrentDatabase(db)
-        assert(spark.catalog.getFunction("fn2").name === "fn2")
+        val unqualified = spark.catalog.getFunction("fn2")
+        assert(unqualified.name === "fn2")
+        assert(unqualified.database === db)
+        assert(!unqualified.isTemporary)
       }
     }
   }
-- 
GitLab