diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index ffaefeb09aedb16fbc47cc5045f9344bcea34231..d88b5ffc0511ceae6948ad073d5b765b4b60a9c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -34,6 +34,10 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.util.StringUtils +object SessionCatalog { + val DEFAULT_DATABASE = "default" +} + /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary @@ -47,6 +51,7 @@ class SessionCatalog( functionRegistry: FunctionRegistry, conf: CatalystConf, hadoopConf: Configuration) extends Logging { + import SessionCatalog._ import CatalogTypes.TablePartitionSpec // For testing only. @@ -77,7 +82,7 @@ class SessionCatalog( // the corresponding item in the current database. @GuardedBy("this") protected var currentDb = { - val defaultName = "default" + val defaultName = DEFAULT_DATABASE val defaultDbDefinition = CatalogDatabase(defaultName, "default database", conf.warehousePath, Map()) // Initialize default database if it doesn't already exist @@ -146,8 +151,10 @@ class SessionCatalog( def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { val dbName = formatDatabaseName(db) - if (dbName == "default") { + if (dbName == DEFAULT_DATABASE) { throw new AnalysisException(s"Can not drop default database") + } else if (dbName == getCurrentDatabase) { + throw new AnalysisException(s"Can not drop current database `${dbName}`") } externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) } @@ -878,14 +885,14 @@ class SessionCatalog( * This is mainly used for tests. */ private[sql] def reset(): Unit = synchronized { - val default = "default" - listDatabases().filter(_ != default).foreach { db => + setCurrentDatabase(DEFAULT_DATABASE) + listDatabases().filter(_ != DEFAULT_DATABASE).foreach { db => dropDatabase(db, ignoreIfNotExists = false, cascade = true) } - listTables(default).foreach { table => + listTables(DEFAULT_DATABASE).foreach { table => dropTable(table, ignoreIfNotExists = false) } - listFunctions(default).map(_._1).foreach { func => + listFunctions(DEFAULT_DATABASE).map(_._1).foreach { func => if (func.database.isDefined) { dropFunction(func, ignoreIfNotExists = false) } else { @@ -902,7 +909,6 @@ class SessionCatalog( require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) } - setCurrentDatabase(default) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 7d1f1d1e62fc7d093ac66d880aaf6c1c5d5bf663..b4294ed7ff1aa2c1c8a105055701510a233f6adb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1270,6 +1270,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { "WITH SERDEPROPERTIES ('spark.sql.sources.me'='anything')") } + test("drop current database") { + sql("CREATE DATABASE temp") + sql("USE temp") + val m = intercept[AnalysisException] { + sql("DROP DATABASE temp") + }.getMessage + assert(m.contains("Can not drop current database `temp`")) + } + test("drop default database") { Seq("true", "false").foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala index 3aa8174702513782e00d9fea2ca276fd8e2e013a..57363b7259c61e2a55ea89ed7e42041307e37050 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextCompatibilitySuite.scala @@ -93,6 +93,7 @@ class HiveContextCompatibilitySuite extends SparkFunSuite with BeforeAndAfterEac hc.sql("DROP TABLE mee_table") val tables2 = hc.sql("SHOW TABLES IN mee_db").collect().map(_.getString(0)) assert(tables2.isEmpty) + hc.sql("USE default") hc.sql("DROP DATABASE mee_db CASCADE") val databases3 = hc.sql("SHOW DATABASES").collect().map(_.getString(0)) assert(databases3.toSeq == Seq("default")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 93e50f4ee907b9096e6775218ee1bd5618897405..343d7bae98bff8381c71bde192aed8a71a78b52a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -472,6 +472,7 @@ class HiveDDLSuite sql(s"DROP TABLE $tabName") assert(tmpDir.listFiles.isEmpty) + sql("USE default") sql(s"DROP DATABASE $dbName") assert(!fs.exists(new Path(tmpDir.toString))) } @@ -526,6 +527,7 @@ class HiveDDLSuite assert(!tableDirectoryExists(TableIdentifier(tabName), Option(expectedDBLocation))) } + sql(s"USE default") val sqlDropDatabase = s"DROP DATABASE $dbName ${if (cascade) "CASCADE" else "RESTRICT"}" if (tableExists && !cascade) { val message = intercept[AnalysisException] {