diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index d88b5ffc0511ceae6948ad073d5b765b4b60a9c7..c0ebb2b1fa1ea694ebd882afd4a94b3b46ecd343 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -22,7 +22,7 @@ import javax.annotation.concurrent.GuardedBy
 import scala.collection.mutable
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
@@ -253,9 +253,27 @@ class SessionCatalog(
   def getTableMetadata(name: TableIdentifier): CatalogTable = {
     val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
     val table = formatTableName(name.table)
-    requireDbExists(db)
-    requireTableExists(TableIdentifier(table, Some(db)))
-    externalCatalog.getTable(db, table)
+    val tid = TableIdentifier(table)
+    if (isTemporaryTable(name)) {
+      CatalogTable(
+        identifier = tid,
+        tableType = CatalogTableType.VIEW,
+        storage = CatalogStorageFormat.empty,
+        schema = tempTables(table).output.map { c =>
+          CatalogColumn(
+            name = c.name,
+            dataType = c.dataType.catalogString,
+            nullable = c.nullable,
+            comment = Option(c.name)
+          )
+        },
+        properties = Map(),
+        viewText = None)
+    } else {
+      requireDbExists(db)
+      requireTableExists(TableIdentifier(table, Some(db)))
+      externalCatalog.getTable(db, table)
+    }
   }
 
   /**
@@ -432,10 +450,10 @@ class SessionCatalog(
   def tableExists(name: TableIdentifier): Boolean = synchronized {
     val db = formatDatabaseName(name.database.getOrElse(currentDb))
     val table = formatTableName(name.table)
-    if (name.database.isDefined || !tempTables.contains(table)) {
-      externalCatalog.tableExists(db, table)
+    if (isTemporaryTable(name)) {
+      true
     } else {
-      true // it's a temporary table
+      externalCatalog.tableExists(db, table)
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 05eb302c3c03ab972a52fa80082db8eda22711af..adce5df81cb7f6a0b7cb06a320b274e773ed5f1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -432,6 +432,39 @@ class SessionCatalogSuite extends SparkFunSuite {
     assert(catalog.tableExists(TableIdentifier("tbl3")))
   }
 
+  test("tableExists on temporary views") {
+    val catalog = new SessionCatalog(newBasicCatalog())
+    val tempTable = Range(1, 10, 2, 10)
+    assert(!catalog.tableExists(TableIdentifier("view1")))
+    assert(!catalog.tableExists(TableIdentifier("view1", Some("default"))))
+    catalog.createTempView("view1", tempTable, overrideIfExists = false)
+    assert(catalog.tableExists(TableIdentifier("view1")))
+    assert(!catalog.tableExists(TableIdentifier("view1", Some("default"))))
+  }
+
+  test("getTableMetadata on temporary views") {
+    val catalog = new SessionCatalog(newBasicCatalog())
+    val tempTable = Range(1, 10, 2, 10)
+    val m = intercept[AnalysisException] {
+      catalog.getTableMetadata(TableIdentifier("view1"))
+    }.getMessage
+    assert(m.contains("Table or view 'view1' not found in database 'default'"))
+
+    val m2 = intercept[AnalysisException] {
+      catalog.getTableMetadata(TableIdentifier("view1", Some("default")))
+    }.getMessage
+    assert(m2.contains("Table or view 'view1' not found in database 'default'"))
+
+    catalog.createTempView("view1", tempTable, overrideIfExists = false)
+    assert(catalog.getTableMetadata(TableIdentifier("view1")).identifier.table == "view1")
+    assert(catalog.getTableMetadata(TableIdentifier("view1")).schema(0).name == "id")
+
+    val m3 = intercept[AnalysisException] {
+      catalog.getTableMetadata(TableIdentifier("view1", Some("default")))
+    }.getMessage
+    assert(m3.contains("Table or view 'view1' not found in database 'default'"))
+  }
+
   test("list tables without pattern") {
     val catalog = new SessionCatalog(newBasicCatalog())
     val tempTable = Range(1, 10, 2, 10)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 91ed9b3258a12c537c5bc85d4f3e5b49b7725cae..1aed245fdd332dc8528c212b6fa3cd84e5cb6a9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -85,7 +85,8 @@ abstract class Catalog {
   def listFunctions(dbName: String): Dataset[Function]
 
   /**
-   * Returns a list of columns for the given table in the current database.
+   * Returns a list of columns for the given table in the current database or
+   * the given temporary table.
    *
    * @since 2.0.0
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index 44babcc93a1deaf9a64b4268e209d1381dd3f5a0..a6ae6fe2aad2e60fe6f9bd5f6e291dcefd6dba05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -138,7 +138,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    */
   @throws[AnalysisException]("table does not exist")
   override def listColumns(tableName: String): Dataset[Column] = {
-    listColumns(currentDatabase, tableName)
+    listColumns(TableIdentifier(tableName, None))
   }
 
   /**
@@ -147,7 +147,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   @throws[AnalysisException]("database or table does not exist")
   override def listColumns(dbName: String, tableName: String): Dataset[Column] = {
     requireTableExists(dbName, tableName)
-    val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, Some(dbName)))
+    listColumns(TableIdentifier(tableName, Some(dbName)))
+  }
+
+  private def listColumns(tableIdentifier: TableIdentifier): Dataset[Column] = {
+    val tableMetadata = sessionCatalog.getTableMetadata(tableIdentifier)
     val partitionColumnNames = tableMetadata.partitionColumnNames.toSet
     val bucketColumnNames = tableMetadata.bucketColumnNames.toSet
     val columns = tableMetadata.schema.map { c =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index d862e4cfa943a0b344241f7da0a01b29eda4cfc6..d75df56dd608aa6405374c33c85cf410b7f1c14f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -234,6 +234,11 @@ class CatalogSuite
     testListColumns("tab1", dbName = None)
   }
 
+  test("list columns in temporary table") {
+    createTempTable("temp1")
+    spark.catalog.listColumns("temp1")
+  }
+
   test("list columns in database") {
     createDatabase("db1")
     createTable("tab1", Some("db1"))
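
A minimal usage sketch of what this change enables (not part of the patch; assumes a local SparkSession named `spark`, and the view name `temp_view` plus the object name are illustrative): once a temporary view is registered, `Catalog.listColumns` now resolves it through the new temporary-view branch in `SessionCatalog.getTableMetadata` instead of failing the lookup against the current database.

import org.apache.spark.sql.SparkSession

object TempViewColumnsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("temp-view-columns")
      .master("local[*]")
      .getOrCreate()
    // Register a temporary view backed by a simple range Dataset.
    spark.range(0, 10).createOrReplaceTempView("temp_view")
    // Before this patch, listColumns resolved only against the current
    // database and threw AnalysisException for a temporary view; with the
    // patch it returns the view's single "id" column.
    spark.catalog.listColumns("temp_view").show()
    spark.stop()
  }
}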