diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index c494e5d7042136e0a38651a7cebe86748d9d099e..b423f0fa04f6978b09dadd5f05aba6136555b239 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -48,6 +48,15 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + def parseTableIdentifier(input: String): TableIdentifier = { + // Initialize the Keywords. + initLexical + phrase(tableIdentifier)(new lexical.Scanner(input)) match { + case Success(ident, _) => ident + case failureOrError => sys.error(failureOrError.toString) + } + } + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val ALL = Keyword("ALL") @@ -444,4 +453,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } + + protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? ~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala new file mode 100644 index 0000000000000000000000000000000000000000..aebcdeb9d070fe085b4925f9ad4e331cfe13d48f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +/** + * Identifies a `table` in `database`. If `database` is not defined, the current database is used. 
+ */ +private[sql] case class TableIdentifier(table: String, database: Option[String] = None) { + def withDatabase(database: String): TableIdentifier = this.copy(database = Some(database)) + + def toSeq: Seq[String] = database.toSeq :+ table + + override def toString: String = toSeq.map("`" + _ + "`").mkString(".") + + def unquotedString: String = toSeq.mkString(".") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 1541491608b24ce3073af9978378032ffdb8a1a8..5766e6a2dd51a40c8807e5703348c5f1b7d8aa62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -23,8 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.EmptyConf +import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** @@ -54,7 +53,7 @@ trait Catalog { */ def getTables(databaseName: Option[String]): Seq[(String, Boolean)] - def refreshTable(databaseName: String, tableName: String): Unit + def refreshTable(tableIdent: TableIdentifier): Unit def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit @@ -132,7 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { result } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } } @@ -241,7 +240,7 @@ object EmptyCatalog extends Catalog { override def unregisterAllTables(): Unit = {} - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 05da05d7b80508785234ca60c5d76c5d423b0b6b..7e3318cefe62c2f336d1db66a3e9b0fa65a387bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.Properties import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} @@ -159,15 +160,19 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - val partitions = - partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) - val overwrite = (mode == SaveMode.Overwrite) - df.sqlContext.executePlan(InsertIntoTable( - UnresolvedRelation(Seq(tableName)), - partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, - overwrite, - ifNotExists = false)).toRdd + insertInto(new SqlParser().parseTableIdentifier(tableName)) + } + + private def insertInto(tableIdent: TableIdentifier): Unit = { + val partitions = 
partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap)
+    val overwrite = mode == SaveMode.Overwrite
+    df.sqlContext.executePlan(
+      InsertIntoTable(
+        UnresolvedRelation(tableIdent.toSeq),
+        partitions.getOrElse(Map.empty[String, Option[String]]),
+        df.logicalPlan,
+        overwrite,
+        ifNotExists = false)).toRdd
   }
 
   /**
@@ -183,35 +188,37 @@ final class DataFrameWriter private[sql](df: DataFrame) {
    * @since 1.4.0
    */
   def saveAsTable(tableName: String): Unit = {
-    if (df.sqlContext.catalog.tableExists(tableName :: Nil) && mode != SaveMode.Overwrite) {
-      mode match {
-        case SaveMode.Ignore =>
-          // Do nothing
-
-        case SaveMode.ErrorIfExists =>
-          throw new AnalysisException(s"Table $tableName already exists.")
-
-        case SaveMode.Append =>
-          // If it is Append, we just ask insertInto to handle it. We will not use insertInto
-          // to handle saveAsTable with Overwrite because saveAsTable can change the schema of
-          // the table. But, insertInto with Overwrite requires the schema of data be the same
-          // the schema of the table.
-          insertInto(tableName)
-
-        case SaveMode.Overwrite =>
-          throw new UnsupportedOperationException("overwrite mode unsupported.")
-      }
-    } else {
-      val cmd =
-        CreateTableUsingAsSelect(
-          tableName,
-          source,
-          temporary = false,
-          partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
-          mode,
-          extraOptions.toMap,
-          df.logicalPlan)
-      df.sqlContext.executePlan(cmd).toRdd
+    saveAsTable(new SqlParser().parseTableIdentifier(tableName))
+  }
+
+  private def saveAsTable(tableIdent: TableIdentifier): Unit = {
+    val tableExists = df.sqlContext.catalog.tableExists(tableIdent.toSeq)
+
+    (tableExists, mode) match {
+      case (true, SaveMode.Ignore) =>
+        // Do nothing
+
+      case (true, SaveMode.ErrorIfExists) =>
+        throw new AnalysisException(s"Table $tableIdent already exists.")
+
+      case (true, SaveMode.Append) =>
+        // If it is Append, we just ask insertInto to handle it. We will not use insertInto
+        // to handle saveAsTable with Overwrite because saveAsTable can change the schema of
+        // the table, while insertInto with Overwrite requires that the schema of the data be
+        // the same as the schema of the table.
+        insertInto(tableIdent)
+
+      case _ =>
+        val cmd =
+          CreateTableUsingAsSelect(
+            tableIdent.unquotedString,
+            source,
+            temporary = false,
+            partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]),
+            mode,
+            extraOptions.toMap,
+            df.logicalPlan)
+        df.sqlContext.executePlan(cmd).toRdd
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 0e25e06e99ab201071cacd32d5e41d409aada472..dbb2a09846548e505fe8a9a7571aad86ac86ee0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -798,8 +798,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group ddl_ops
    * @since 1.3.0
    */
-  def table(tableName: String): DataFrame =
-    DataFrame(this, catalog.lookupRelation(Seq(tableName)))
+  def table(tableName: String): DataFrame = {
+    val tableIdent = new SqlParser().parseTableIdentifier(tableName)
+    DataFrame(this, catalog.lookupRelation(tableIdent.toSeq))
+  }
 
   /**
    * Returns a [[DataFrame]] containing names of existing tables in the current database.
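Usage sketch (illustrative only, not part of the patch): the behavior of the new parseTableIdentifier method and the TableIdentifier helpers introduced above can be exercised as follows. The object name TableIdentifierSketch is hypothetical; the snippet is placed in the org.apache.spark.sql package because TableIdentifier is private[sql], and everything else follows directly from the definitions in this diff.

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}

// Hypothetical sketch exercising the API added in this patch.
object TableIdentifierSketch {
  def main(args: Array[String]): Unit = {
    val parser = new SqlParser()

    // An unqualified name parses with no database component.
    assert(parser.parseTableIdentifier("t") == TableIdentifier("t", None))

    // "db.t" splits into an optional database part and the table name.
    val ident = parser.parseTableIdentifier("db.t")
    assert(ident == TableIdentifier("t", Some("db")))

    // Helpers defined in TableIdentifier.scala above:
    assert(ident.toSeq == Seq("db", "t"))        // shape used for catalog lookups
    assert(ident.toString == "`db`.`t`")         // back-quoted rendering
    assert(ident.unquotedString == "db.t")       // form passed to CreateTableUsingAsSelect
    assert(TableIdentifier("t").withDatabase("db") == ident)
  }
}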
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1f2797ec5527a6568ade2ec9d034210616bc2a24..e73b3704d4dfe80a72d53c1163bae47f15e5cef5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -21,16 +21,17 @@ import scala.language.{existentials, implicitConversions} import scala.util.matching.Regex import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -151,7 +152,7 @@ private[sql] class DDLParser( protected lazy val refreshTable: Parser[LogicalPlan] = REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { case maybeDatabaseName ~ tableName => - RefreshTable(maybeDatabaseName.getOrElse("default"), tableName) + RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) } protected lazy val options: Parser[Map[String, String]] = @@ -442,16 +443,16 @@ private[sql] case class CreateTempTableUsingAsSelect( } } -private[sql] case class RefreshTable(databaseName: String, tableName: String) +private[sql] case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. - sqlContext.catalog.refreshTable(databaseName, tableName) + sqlContext.catalog.refreshTable(tableIdent) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) + val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent.toSeq) // Use lookupCachedData directly since RefreshTable also takes databaseName. val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { @@ -461,7 +462,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) // Uncache the logicalPlan. sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) // Cache it again. 
- sqlContext.cacheManager.cacheQuery(df, Some(tableName)) + sqlContext.cacheManager.cacheQuery(df, Some(tableIdent.table)) } Seq.empty[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index eb15a1609f1d0a2b05a0f1d6c01509e0d4b833b7..64e94056f209a07edc8a9b6a7d2296018c617245 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -22,6 +22,7 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{DataFrame, SaveMode} @@ -32,8 +33,7 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { - +private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fa01823e9417cdc62d8ebd55ce4778e7fd2af486..4c11acdab9ec03b0c34e2cf63dd72ca70cb34b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.test import java.io.File +import java.util.UUID import scala.util.Try +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils -trait SQLTestUtils { +trait SQLTestUtils { this: SparkFunSuite => def sqlContext: SQLContext protected def configuration = sqlContext.sparkContext.hadoopConfiguration @@ -87,4 +89,29 @@ trait SQLTestUtils { } } } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + sqlContext.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. 
+ */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1b8edefef40936bb972c39234481083f2585a85b..110f51a305861995a32a736270c7e7ecdd91523a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -40,7 +40,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} @@ -267,7 +267,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - catalog.refreshTable(catalog.client.currentDatabase, tableName) + val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 262923531216f61734f4efe787356770639f56ad..9c707a7a2eca1f391eb7d45049981264db392052 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -29,13 +29,13 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ @@ -43,7 +43,6 @@ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} - private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -115,7 +114,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. 
// Since we also cache ParquetRelations converted from Hive Parquet tables and @@ -124,7 +123,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(databaseName, tableName) + invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -144,7 +143,27 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName) + createDataSourceTable( + new SqlParser().parseTableIdentifier(tableName), + userSpecifiedSchema, + partitionColumns, + provider, + options, + isExternal) + } + + private def createDataSourceTable( + tableIdent: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + val (dbName, tblName) = { + val database = tableIdent.database.getOrElse(client.currentDatabase) + processDatabaseAndTableName(database, tableIdent.table) + } + val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) @@ -177,7 +196,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // partitions when we load the table. However, if there are specified partition columns, // we simplily ignore them and provide a warning message.. logWarning( - s"The schema and partitions of table $tableName will be inferred when it is loaded. " + + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } Seq.empty[HiveColumn] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..73852f13ad20d1a804456659dfd1ce2a5d7a3f3f --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} + +class MultiDatabaseSuite extends QueryTest with SQLTestUtils { + override val sqlContext: SQLContext = TestHive + + import sqlContext.sql + + private val df = sqlContext.range(10).coalesce(1) + + test(s"saveAsTable() to non-default database - with USE - Overwrite") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + } + } + + test(s"saveAsTable() to non-default database - without USE - Overwrite") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + } + } + + test(s"saveAsTable() to non-default database - with USE - Append") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + df.write.mode(SaveMode.Append).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df.unionAll(df)) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test(s"saveAsTable() to non-default database - without USE - Append") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test(s"insertInto() non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + } + + test(s"insertInto() non-default database - without USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + } + + assert(sqlContext.tableNames(db).contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test("Looks up tables in non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql("CREATE TABLE t (key INT)") + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + } + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + } + } + + test("Drops a table in a non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql(s"CREATE TABLE t (key INT)") + assert(sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(sqlContext.tableNames(db).contains("t")) + + activateDatabase(db) { + sql(s"DROP TABLE t") + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames(db).contains("t")) + } + } + + test("Refreshes a table in a non-default database") { + import org.apache.spark.sql.functions.lit + + withTempDatabase 
{ db => + withTempPath { dir => + val path = dir.getCanonicalPath + + activateDatabase(db) { + sql( + s"""CREATE EXTERNAL TABLE t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql("ALTER TABLE t ADD PARTITION (p=1)") + sql("REFRESH TABLE t") + checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 9d76d6503a3e6e4d57feaf93ffb7f7ba5841e7b3..145965388da01229d265b7565254e8f7d186976d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,14 +22,15 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SQLTestUtils -private[sql] trait OrcTest extends SQLTestUtils { +private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - import sqlContext.sparkContext import sqlContext.implicits._ + import sqlContext.sparkContext /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f`