From 66449b8dcdbc3dca126c34b42c4d0419c7648696 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Thu, 28 Jan 2016 22:20:52 -0800
Subject: [PATCH] [SPARK-12968][SQL] Implement command to set current database

JIRA: https://issues.apache.org/jira/browse/SPARK-12968

Implement command to set current database.

Author: Liang-Chi Hsieh <viirya@gmail.com>
Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #10916 from viirya/ddl-use-database.
---
 .../spark/sql/catalyst/analysis/Catalog.scala    |  4 ++++
 .../org/apache/spark/sql/execution/SparkQl.scala |  3 +++
 .../apache/spark/sql/execution/commands.scala    | 10 ++++++++++
 .../spark/sql/hive/thriftserver/CliSuite.scala   |  2 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala    |  4 ++++
 .../scala/org/apache/spark/sql/hive/HiveQl.scala |  2 --
 .../spark/sql/hive/client/ClientInterface.scala  |  3 +++
 .../spark/sql/hive/client/ClientWrapper.scala    |  9 +++++++++
 .../sql/hive/execution/HiveQuerySuite.scala      | 16 ++++++++++++++++
 9 files changed, 50 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index a8f89ce6de..f2f9ec5941 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -46,6 +46,10 @@ trait Catalog {
 
   def lookupRelation(tableIdent: TableIdentifier, alias: Option[String] = None): LogicalPlan
 
+  def setCurrentDatabase(databaseName: String): Unit = {
+    throw new UnsupportedOperationException
+  }
+
   /**
    * Returns tuples of (tableName, isTemporary) for all tables in the given database.
    * isTemporary is a Boolean value indicates if a table is a temporary or not.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
index f6055306b6..a5bd8ee42d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala
@@ -55,6 +55,9 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly
           getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
         ExplainCommand(nodeToPlan(query), extended = extended.isDefined)
 
+      case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) =>
+        SetDatabaseCommand(cleanIdentifier(database))
+
       case Token("TOK_DESCTABLE", describeArgs) =>
         // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
         val Some(tableType) :: formatted :: extended :: pretty :: Nil =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 3cfa3dfd9c..703e4643cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -408,3 +408,13 @@ case class DescribeFunction(
     }
   }
 }
+
+case class SetDatabaseCommand(databaseName: String) extends RunnableCommand {
+
+  override def run(sqlContext: SQLContext): Seq[Row] = {
+    sqlContext.catalog.setCurrentDatabase(databaseName)
+    Seq.empty[Row]
+  }
+
+  override val output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index ab31d45a79..72da266da4 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
       "CREATE DATABASE hive_test_db;"
         -> "OK",
       "USE hive_test_db;"
-        -> "OK",
+        -> "",
       "CREATE TABLE hive_test(key INT, val STRING);"
         -> "OK",
       "SHOW TABLES;"
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index a9c0e9ab7c..848aa4ec6f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -711,6 +711,10 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
   }
 
   override def unregisterAllTables(): Unit = {}
+
+  override def setCurrentDatabase(databaseName: String): Unit = {
+    client.setCurrentDatabase(databaseName)
+  }
 }
 
 /**
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 22841ed211..752c037a84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -155,8 +155,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
     "TOK_SHOWLOCKS",
     "TOK_SHOWPARTITIONS",
 
-    "TOK_SWITCHDATABASE",
-
     "TOK_UNLOCKTABLE"
   )
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index 9d9a55edd7..4eec3fef74 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -109,6 +109,9 @@ private[hive] trait ClientInterface {
   /** Returns the name of the active database. */
   def currentDatabase: String
 
+  /** Sets the name of current database. */
+  def setCurrentDatabase(databaseName: String): Unit
+
   /** Returns the metadata for specified database, throwing an exception if it doesn't exist */
   def getDatabase(name: String): HiveDatabase = {
     getDatabaseOption(name).getOrElse(throw new NoSuchDatabaseException)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index ce7a305d43..5307e924e7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader}
 import org.apache.hadoop.security.UserGroupInformation
 
 import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.execution.QueryExecutionException
 import org.apache.spark.util.{CircularBuffer, Utils}
@@ -229,6 +230,14 @@ private[hive] class ClientWrapper(
     state.getCurrentDatabase
   }
 
+  override def setCurrentDatabase(databaseName: String): Unit = withHiveState {
+    if (getDatabaseOption(databaseName).isDefined) {
+      state.setCurrentDatabase(databaseName)
+    } else {
+      throw new NoSuchDatabaseException
+    }
+  }
+
   override def createDatabase(database: HiveDatabase): Unit = withHiveState {
     client.createDatabase(
       new Database(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 4659d745fe..9632d27a2f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SparkException, SparkFiles}
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException
 import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin
@@ -1262,6 +1263,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
 
   }
 
+  test("use database") {
+    val currentDatabase = sql("select current_database()").first().getString(0)
+
+    sql("CREATE DATABASE hive_test_db")
+    sql("USE hive_test_db")
+    assert("hive_test_db" == sql("select current_database()").first().getString(0))
+
+    intercept[NoSuchDatabaseException] {
+      sql("USE not_existing_db")
+    }
+
+    sql(s"USE $currentDatabase")
+    assert(currentDatabase == sql("select current_database()").first().getString(0))
+  }
+
   test("lookup hive UDF in another thread") {
     val e = intercept[AnalysisException] {
       range(1).selectExpr("not_a_udf()")
-- 
GitLab