From f45379173bc3a3e657b6229bec2faeb409b6ad53 Mon Sep 17 00:00:00 2001
From: gatorsmile <gatorsmile@gmail.com>
Date: Tue, 10 May 2016 11:57:01 +0800
Subject: [PATCH] [SPARK-15187][SQL] Disallow Dropping Default Database

#### What changes were proposed in this pull request?
In Hive Metastore, dropping default database is not allowed. However, in `InMemoryCatalog`, this is allowed.

This PR is to disallow users to drop default database.

#### How was this patch tested?
Previously, we already have a test case in HiveDDLSuite. Now, we also add the same one in DDLSuite

Author: gatorsmile <gatorsmile@gmail.com>

Closes #12962 from gatorsmile/dropDefaultDB.
---
 .../sql/catalyst/catalog/SessionCatalog.scala | 96 +++++++++++--------
 .../sql/execution/command/DDLSuite.scala      | 28 +++++-
 .../spark/sql/hive/HiveSessionCatalog.scala   | 15 +--
 .../sql/hive/execution/HiveDDLSuite.scala     | 19 +++-
 4 files changed, 106 insertions(+), 52 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 18524e4118..b267798e7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -82,7 +82,7 @@ class SessionCatalog(
       CatalogDatabase(defaultName, "default database", conf.warehousePath, Map())
     // Initialize default database if it doesn't already exist
     createDatabase(defaultDbDefinition, ignoreIfExists = true)
-    defaultName
+    formatDatabaseName(defaultName)
   }
 
   /**
@@ -92,6 +92,13 @@ class SessionCatalog(
     if (conf.caseSensitiveAnalysis) name else name.toLowerCase
   }
 
+  /**
+   * Format database name, taking into account case sensitivity.
+   */
+  protected[this] def formatDatabaseName(name: String): String = {
+    if (conf.caseSensitiveAnalysis) name else name.toLowerCase
+  }
+
   /**
    * This method is used to make the given path qualified before we
    * store this path in the underlying external catalog. So, when a path
@@ -112,25 +119,33 @@ class SessionCatalog(
 
   def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {
     val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString
+    val dbName = formatDatabaseName(dbDefinition.name)
     externalCatalog.createDatabase(
-      dbDefinition.copy(locationUri = qualifiedPath),
+      dbDefinition.copy(name = dbName, locationUri = qualifiedPath),
       ignoreIfExists)
   }
 
   def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {
-    externalCatalog.dropDatabase(db, ignoreIfNotExists, cascade)
+    val dbName = formatDatabaseName(db)
+    if (dbName == "default") {
+      throw new AnalysisException(s"Can not drop default database")
+    }
+    externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade)
   }
 
   def alterDatabase(dbDefinition: CatalogDatabase): Unit = {
-    externalCatalog.alterDatabase(dbDefinition)
+    val dbName = formatDatabaseName(dbDefinition.name)
+    externalCatalog.alterDatabase(dbDefinition.copy(name = dbName))
   }
 
   def getDatabaseMetadata(db: String): CatalogDatabase = {
-    externalCatalog.getDatabase(db)
+    val dbName = formatDatabaseName(db)
+    externalCatalog.getDatabase(dbName)
   }
 
   def databaseExists(db: String): Boolean = {
-    externalCatalog.databaseExists(db)
+    val dbName = formatDatabaseName(db)
+    externalCatalog.databaseExists(dbName)
   }
 
   def listDatabases(): Seq[String] = {
@@ -144,10 +159,11 @@ class SessionCatalog(
   def getCurrentDatabase: String = synchronized { currentDb }
 
   def setCurrentDatabase(db: String): Unit = {
-    if (!databaseExists(db)) {
-      throw new AnalysisException(s"Database '$db' does not exist.")
+    val dbName = formatDatabaseName(db)
+    if (!databaseExists(dbName)) {
+      throw new AnalysisException(s"Database '$dbName' does not exist.")
     }
-    synchronized { currentDb = db }
+    synchronized { currentDb = dbName }
   }
 
   /**
@@ -155,7 +171,7 @@ class SessionCatalog(
    * by users.
    */
   def getDefaultDBPath(db: String): String = {
-    val database = if (conf.caseSensitiveAnalysis) db else db.toLowerCase
+    val database = formatDatabaseName(db)
     new Path(new Path(conf.warehousePath), database + ".db").toString
   }
 
@@ -177,7 +193,7 @@ class SessionCatalog(
    * If no such database is specified, create it in the current database.
    */
   def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
-    val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableDefinition.identifier.table)
     val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
     externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
@@ -193,7 +209,7 @@ class SessionCatalog(
    * this becomes a no-op.
    */
   def alterTable(tableDefinition: CatalogTable): Unit = {
-    val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableDefinition.identifier.table)
     val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
     externalCatalog.alterTable(db, newTableDefinition)
@@ -205,7 +221,7 @@ class SessionCatalog(
    * If the specified table is not found in the database then an [[AnalysisException]] is thrown.
    */
   def getTableMetadata(name: TableIdentifier): CatalogTable = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
     externalCatalog.getTable(db, table)
   }
@@ -216,7 +232,7 @@ class SessionCatalog(
    * If the specified table is not found in the database then return None if it doesn't exist.
    */
   def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
     externalCatalog.getTableOption(db, table)
   }
@@ -231,7 +247,7 @@ class SessionCatalog(
       loadPath: String,
       isOverwrite: Boolean,
       holdDDLTime: Boolean): Unit = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
     externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime)
   }
@@ -249,14 +265,14 @@ class SessionCatalog(
       holdDDLTime: Boolean,
       inheritTableSpecs: Boolean,
       isSkewedStoreAsSubdir: Boolean): Unit = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
     externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime,
       inheritTableSpecs, isSkewedStoreAsSubdir)
   }
 
   def defaultTablePath(tableIdent: TableIdentifier): String = {
-    val dbName = tableIdent.database.getOrElse(getCurrentDatabase)
+    val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase))
     val dbLocation = getDatabaseMetadata(dbName).locationUri
 
     new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
@@ -290,8 +306,8 @@ class SessionCatalog(
    * This assumes the database specified in `oldName` matches the one specified in `newName`.
    */
   def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized {
-    val db = oldName.database.getOrElse(currentDb)
-    val newDb = newName.database.getOrElse(currentDb)
+    val db = formatDatabaseName(oldName.database.getOrElse(currentDb))
+    val newDb = formatDatabaseName(newName.database.getOrElse(currentDb))
     if (db != newDb) {
       throw new AnalysisException(
         s"RENAME TABLE source and destination databases do not match: '$db' != '$newDb'")
@@ -324,7 +340,7 @@ class SessionCatalog(
    * the same name, then, if that does not exist, drop the table from the current database.
    */
   def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized {
-    val db = name.database.getOrElse(currentDb)
+    val db = formatDatabaseName(name.database.getOrElse(currentDb))
     val table = formatTableName(name.table)
     if (name.database.isDefined || !tempTables.contains(table)) {
       // When ignoreIfNotExists is false, no exception is issued when the table does not exist.
@@ -348,7 +364,7 @@ class SessionCatalog(
    */
   def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
     synchronized {
-      val db = name.database.getOrElse(currentDb)
+      val db = formatDatabaseName(name.database.getOrElse(currentDb))
       val table = formatTableName(name.table)
       val relation =
         if (name.database.isDefined || !tempTables.contains(table)) {
@@ -373,7 +389,7 @@ class SessionCatalog(
    * contain the table.
    */
   def tableExists(name: TableIdentifier): Boolean = synchronized {
-    val db = name.database.getOrElse(currentDb)
+    val db = formatDatabaseName(name.database.getOrElse(currentDb))
     val table = formatTableName(name.table)
     if (name.database.isDefined || !tempTables.contains(table)) {
       externalCatalog.tableExists(db, table)
@@ -395,14 +411,15 @@ class SessionCatalog(
   /**
    * List all tables in the specified database, including temporary tables.
    */
-  def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*")
+  def listTables(db: String): Seq[TableIdentifier] = listTables(formatDatabaseName(db), "*")
 
   /**
    * List all matching tables in the specified database, including temporary tables.
    */
   def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
+    val dbName = formatDatabaseName(db)
     val dbTables =
-      externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
+      externalCatalog.listTables(dbName, pattern).map { t => TableIdentifier(t, Some(dbName)) }
     synchronized {
       val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
         .map { t => TableIdentifier(t) }
@@ -458,7 +475,7 @@ class SessionCatalog(
       tableName: TableIdentifier,
       parts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit = {
-    val db = tableName.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableName.table)
     externalCatalog.createPartitions(db, table, parts, ignoreIfExists)
   }
@@ -471,7 +488,7 @@ class SessionCatalog(
       tableName: TableIdentifier,
       parts: Seq[TablePartitionSpec],
       ignoreIfNotExists: Boolean): Unit = {
-    val db = tableName.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableName.table)
     externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists)
   }
@@ -486,7 +503,7 @@ class SessionCatalog(
       tableName: TableIdentifier,
       specs: Seq[TablePartitionSpec],
       newSpecs: Seq[TablePartitionSpec]): Unit = {
-    val db = tableName.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableName.table)
     externalCatalog.renamePartitions(db, table, specs, newSpecs)
   }
@@ -501,7 +518,7 @@ class SessionCatalog(
    * this becomes a no-op.
    */
   def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
-    val db = tableName.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableName.table)
     externalCatalog.alterPartitions(db, table, parts)
   }
@@ -511,7 +528,7 @@ class SessionCatalog(
    * If no database is specified, assume the table is in the current database.
    */
   def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
-    val db = tableName.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableName.table)
     externalCatalog.getPartition(db, table, spec)
   }
@@ -526,7 +543,7 @@ class SessionCatalog(
   def listPartitions(
       tableName: TableIdentifier,
       partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
-    val db = tableName.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(tableName.table)
     externalCatalog.listPartitions(db, table, partialSpec)
   }
@@ -549,7 +566,7 @@ class SessionCatalog(
    * If no such database is specified, create it in the current database.
    */
   def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
-    val db = funcDefinition.identifier.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase))
     val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
     val newFuncDefinition = funcDefinition.copy(identifier = identifier)
     if (!functionExists(identifier)) {
@@ -564,7 +581,7 @@ class SessionCatalog(
    * If no database is specified, assume the function is in the current database.
    */
   def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val identifier = name.copy(database = Some(db))
     if (functionExists(identifier)) {
       // TODO: registry should just take in FunctionIdentifier for type safety
@@ -588,7 +605,7 @@ class SessionCatalog(
    * If no database is specified, this will return the function in the current database.
    */
   def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     externalCatalog.getFunction(db, name.funcName)
   }
 
@@ -596,7 +613,7 @@ class SessionCatalog(
    * Check if the specified function exists.
    */
   def functionExists(name: FunctionIdentifier): Boolean = {
-    val db = name.database.getOrElse(getCurrentDatabase)
+    val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     functionRegistry.functionExists(name.unquotedString) ||
       externalCatalog.functionExists(db, name.funcName)
   }
@@ -661,7 +678,8 @@ class SessionCatalog(
    */
   private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
     // TODO: just make function registry take in FunctionIdentifier instead of duplicating this
-    val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
+    val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
+    val qualifiedName = name.copy(database = database)
     functionRegistry.lookupFunction(name.funcName)
       .orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString))
       .getOrElse {
@@ -700,7 +718,8 @@ class SessionCatalog(
     }
 
     // If the name itself is not qualified, add the current database to it.
-    val qualifiedName = if (name.database.isEmpty) name.copy(database = Some(currentDb)) else name
+    val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
+    val qualifiedName = name.copy(database = database)
 
     if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
       // This function has been already loaded into the function registry.
@@ -740,8 +759,9 @@ class SessionCatalog(
    * List all matching functions in the specified database, including temporary functions.
    */
   def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
-    val dbFunctions =
-      externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
+    val dbName = formatDatabaseName(db)
+    val dbFunctions = externalCatalog.listFunctions(dbName, pattern)
+      .map { f => FunctionIdentifier(f, Some(dbName)) }
     val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
       .map { f => FunctionIdentifier(f) }
     dbFunctions ++ loadedFunctions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index d1155678e7..3586ddf7b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -644,16 +644,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
 
     checkAnswer(
       sql("SHOW DATABASES LIKE '*db1A'"),
-      Row("showdb1A") :: Nil)
+      Row("showdb1a") :: Nil)
 
     checkAnswer(
       sql("SHOW DATABASES LIKE 'showdb1A'"),
-      Row("showdb1A") :: Nil)
+      Row("showdb1a") :: Nil)
 
     checkAnswer(
       sql("SHOW DATABASES LIKE '*db1A|*db2B'"),
-      Row("showdb1A") ::
-        Row("showdb2B") :: Nil)
+      Row("showdb1a") ::
+        Row("showdb2b") :: Nil)
 
     checkAnswer(
       sql("SHOW DATABASES LIKE 'non-existentdb'"),
@@ -1000,4 +1000,24 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
         Row("Usage: a ^ b - Bitwise exclusive OR.") :: Nil
     )
   }
+
+  test("drop default database") {
+    Seq("true", "false").foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        var message = intercept[AnalysisException] {
+          sql("DROP DATABASE default")
+        }.getMessage
+        assert(message.contains("Can not drop default database"))
+
+        message = intercept[AnalysisException] {
+          sql("DROP DATABASE DeFault")
+        }.getMessage
+        if (caseSensitive == "true") {
+          assert(message.contains("Database 'DeFault' does not exist"))
+        } else {
+          assert(message.contains("Can not drop default database"))
+        }
+      }
+    }
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 3220f143aa..75a252ccba 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -62,7 +62,8 @@ private[sql] class HiveSessionCatalog(
   override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = {
     val table = formatTableName(name.table)
     if (name.database.isDefined || !tempTables.contains(table)) {
-      val newName = name.copy(table = table)
+      val database = name.database.map(formatDatabaseName)
+      val newName = name.copy(database = database, table = table)
       metastoreCatalog.lookupRelation(newName, alias)
     } else {
       val relation = tempTables(table)
@@ -181,10 +182,12 @@ private[sql] class HiveSessionCatalog(
     //   // This function is a Hive builtin function.
     //   ...
     // }
-    Try(super.lookupFunction(name, children)) match {
+    val database = name.database.map(formatDatabaseName)
+    val funcName = name.copy(database = database)
+    Try(super.lookupFunction(funcName, children)) match {
       case Success(expr) => expr
       case Failure(error) =>
-        if (functionRegistry.functionExists(name.unquotedString)) {
+        if (functionRegistry.functionExists(funcName.unquotedString)) {
           // If the function actually exists in functionRegistry, it means that there is an
           // error when we create the Expression using the given children.
           // We need to throw the original exception.
@@ -193,7 +196,7 @@ private[sql] class HiveSessionCatalog(
           // This function is not in functionRegistry, let's try to load it as a Hive's
           // built-in function.
           // Hive is case insensitive.
-          val functionName = name.unquotedString.toLowerCase
+          val functionName = funcName.unquotedString.toLowerCase
           // TODO: This may not really work for current_user because current_user is not evaluated
           // with session info.
           // We do not need to use executionHive at here because we only load
@@ -201,12 +204,12 @@ private[sql] class HiveSessionCatalog(
           val functionInfo = {
             try {
               Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse(
-                failFunctionLookup(name.unquotedString))
+                failFunctionLookup(funcName.unquotedString))
             } catch {
               // If HiveFunctionRegistry.getFunctionInfo throws an exception,
               // we are failing to load a Hive builtin function, which means that
               // the given function is not a Hive builtin function.
-              case NonFatal(e) => failFunctionLookup(name.unquotedString)
+              case NonFatal(e) => failFunctionLookup(funcName.unquotedString)
             }
           }
           val className = functionInfo.getFunctionClass.getName
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index aa5b5e6309..a8ba952b49 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -498,10 +498,21 @@ class HiveDDLSuite
   }
 
   test("drop default database") {
-    val message = intercept[AnalysisException] {
-      sql("DROP DATABASE default")
-    }.getMessage
-    assert(message.contains("Can not drop default database"))
+    Seq("true", "false").foreach { caseSensitive =>
+      withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+        var message = intercept[AnalysisException] {
+          sql("DROP DATABASE default")
+        }.getMessage
+        assert(message.contains("Can not drop default database"))
+
+        // SQLConf.CASE_SENSITIVE does not affect the result
+        // because the Hive metastore is not case sensitive.
+        message = intercept[AnalysisException] {
+          sql("DROP DATABASE DeFault")
+        }.getMessage
+        assert(message.contains("Can not drop default database"))
+      }
+    }
   }
 
   test("desc table for data source table") {
-- 
GitLab