diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 2d2120dda8bde73cf0e8091009f22b59e65ed14e..c8b61d8df35857340045227455af4596fbe94179 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -923,6 +923,24 @@ class SessionCatalog(
     }
   }
 
+  /**
+   * Returns whether the given function is temporary; returns false if it does not exist.
+   */
+  def isTemporaryFunction(name: FunctionIdentifier): Boolean = {
+    // Copied from HiveSessionCatalog.
+    val hiveFunctions = Seq(
+      "hash",
+      "histogram_numeric",
+      "percentile")
+
+    // A temporary function is a function that has been registered in functionRegistry
+    // without a database name, and is neither a built-in function nor a Hive function.
+    name.database.isEmpty &&
+      functionRegistry.functionExists(name.funcName) &&
+      !FunctionRegistry.builtin.functionExists(name.funcName) &&
+      !hiveFunctions.contains(name.funcName.toLowerCase)
+  }
+
   protected def failFunctionLookup(name: String): Nothing = {
     throw new NoSuchFunctionException(db = currentDb, func = name)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index b77fef225a0c8d3d7c9d99fb2fc448917cb35499..001d9c47785d236efa3f395c8d5ecfff994984d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -919,6 +919,34 @@ class SessionCatalogSuite extends SparkFunSuite {
       catalog.lookupFunction(FunctionIdentifier("temp1"), arguments) === Literal(arguments.length))
   }
 
+  test("isTemporaryFunction") {
+    val externalCatalog = newBasicCatalog()
+    val sessionCatalog = new SessionCatalog(externalCatalog)
+
+    // Returns false when the function does not exist
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1")))
+
+    val tempFunc1 = (e: Seq[Expression]) => e.head
+    val info1 = new ExpressionInfo("tempFunc1", "temp1")
+    sessionCatalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false)
+
+    // Returns true when the function is temporary
+    assert(sessionCatalog.isTemporaryFunction(FunctionIdentifier("temp1")))
+
+    // Returns false when the function is permanent
+    assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2"))))
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("db2.func1")))
+    sessionCatalog.setCurrentDatabase("db2")
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("func1")))
+
+    // Returns false when the function is built-in or Hive
+    assert(FunctionRegistry.builtin.functionExists("sum"))
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("sum")))
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric")))
+    assert(!sessionCatalog.isTemporaryFunction(FunctionIdentifier("percentile")))
+  }
+
   test("drop function") {
     val externalCatalog = newBasicCatalog()
     val sessionCatalog = new SessionCatalog(externalCatalog)
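For context, a minimal sketch of how the new predicate answers, built from the same catalyst
test helpers exercised in the suite above (newBasicCatalog, createTempFunction); this is
illustrative only, not part of the change:

    import org.apache.spark.sql.catalyst.FunctionIdentifier
    import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}

    val catalog = new SessionCatalog(newBasicCatalog())
    val info = new ExpressionInfo("tempFunc1", "temp1")
    catalog.createTempFunction("temp1", info, (e: Seq[Expression]) => e.head, ignoreIfExists = false)

    catalog.isTemporaryFunction(FunctionIdentifier("temp1"))              // true: unqualified and registered
    catalog.isTemporaryFunction(FunctionIdentifier("func1", Some("db2"))) // false: database-qualified
    catalog.isTemporaryFunction(FunctionIdentifier("sum"))                // false: built-in
    catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric"))  // false: Hive fallback function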
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index bbcd9c4ef564cb70f93d7741e399dfbd8f0dfd64..30472ec45ce443d3103688f6e15dbc8cf9bc60f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.command
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{SQLBuilder, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.expressions.Alias
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
-import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation}
-import org.apache.spark.sql.types.{MetadataBuilder, StructType}
+import org.apache.spark.sql.types.MetadataBuilder
 
 
 /**
@@ -131,6 +131,10 @@ case class CreateViewCommand(
         s"specified by CREATE VIEW (num: `${userSpecifiedColumns.length}`).")
     }
 
+    // A permanent view is not allowed to reference temporary objects.
+    // This call should be made after `qe.assertAnalyzed()` (i.e., `child` can be resolved).
+    verifyTemporaryObjectsNotExists(sparkSession)
+
     val aliasedPlan = if (userSpecifiedColumns.isEmpty) {
       analyzedPlan
     } else {
@@ -172,6 +176,34 @@ case class CreateViewCommand(
     Seq.empty[Row]
   }
 
+  /**
+   * Permanent views are not allowed to reference temp objects, including temp functions and views
+   */
+  private def verifyTemporaryObjectsNotExists(sparkSession: SparkSession): Unit = {
+    if (!isTemporary) {
+      // This function traverses the unresolved plan `child`, for the following reasons:
+      // 1) The Analyzer replaces unresolved temporary views by a SubqueryAlias with the
+      // corresponding logical plan. After replacement, it is impossible to detect whether the
+      // SubqueryAlias was added/generated from a temporary view.
+      // 2) Temp functions are represented by multiple classes, most of which are inaccessible
+      // from this package (e.g., HiveGenericUDF).
+      child.collect {
+        // Disallow creating permanent views based on temporary views.
+        case s: UnresolvedRelation
+          if sparkSession.sessionState.catalog.isTemporaryTable(s.tableIdentifier) =>
+          throw new AnalysisException(s"Not allowed to create a permanent view $name by " +
+            s"referencing a temporary view ${s.tableIdentifier}")
+        case other if !other.resolved => other.expressions.flatMap(_.collect {
+          // Disallow creating permanent views based on temporary UDFs.
+          case e: UnresolvedFunction
+            if sparkSession.sessionState.catalog.isTemporaryFunction(e.name) =>
+            throw new AnalysisException(s"Not allowed to create a permanent view $name by " +
+              s"referencing a temporary function `${e.name}`")
+        })
+      }
+    }
+  }
+
   /**
    * Returns a [[CatalogTable]] that can be used to save in the catalog. This method canonicalizes
    * SQL based on the analyzed plan, and also creates the proper schema for the view.
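To make the new check concrete, a sketch of the user-visible behavior (assuming a SparkSession
named spark and the table jt created in SQLViewSuite's setup below; the error text is the one
asserted by the tests):

    spark.sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3")

    // Still allowed: a temporary view may reference another temporary view.
    spark.sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6")

    // Now rejected with an AnalysisException:
    //   Not allowed to create a permanent view `jtv1` by referencing a temporary view `temp_jtv1`
    spark.sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6")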
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 9df20ce1553ecd7360961292aa2d65d713a3f298..4a9b28a455a44777a5da6d3382d5fa76e027cf85 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -232,6 +232,7 @@ private[sql] class HiveSessionCatalog(
   // current_user, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field,
   // in_file, index, matchpath, ngrams, noop, noopstreaming, noopwithmap,
   // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction.
+  // Note: don't forget to update SessionCatalog.isTemporaryFunction
   private val hiveFunctions = Seq(
     "histogram_numeric",
     "percentile"
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
index 2af935da689c919d8010c01f760019ba08d1bfa3..ba65db71ede7f8e13ca27e0bf932dd1172f66345 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLViewSuite.scala
@@ -38,21 +38,46 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     spark.sql(s"DROP TABLE IF EXISTS jt")
   }
 
-  test("nested views (interleaved with temporary views)") {
-    withView("jtv1", "jtv2", "jtv3", "temp_jtv1", "temp_jtv2", "temp_jtv3") {
+  test("create a permanent view on a permanent view") {
+    withView("jtv1", "jtv2") {
       sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3")
       sql("CREATE VIEW jtv2 AS SELECT * FROM jtv1 WHERE id < 6")
       checkAnswer(sql("select count(*) FROM jtv2"), Row(2))
+    }
+  }
 
-      // Checks temporary views
+  test("create a temp view on a permanent view") {
+    withView("jtv1", "temp_jtv1") {
+      sql("CREATE VIEW jtv1 AS SELECT * FROM jt WHERE id > 3")
+      sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jtv1 WHERE id < 6")
+      checkAnswer(sql("select count(*) FROM temp_jtv1"), Row(2))
+    }
+  }
+
+  test("create a temp view on a temp view") {
+    withView("temp_jtv1", "temp_jtv2") {
       sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3")
       sql("CREATE TEMPORARY VIEW temp_jtv2 AS SELECT * FROM temp_jtv1 WHERE id < 6")
       checkAnswer(sql("select count(*) FROM temp_jtv2"), Row(2))
+    }
+  }
+
+  test("create a permanent view on a temp view") {
+    withView("jtv1", "temp_jtv1", "global_temp_jtv1") {
+      sql("CREATE TEMPORARY VIEW temp_jtv1 AS SELECT * FROM jt WHERE id > 3")
+      var e = intercept[AnalysisException] {
+        sql("CREATE VIEW jtv1 AS SELECT * FROM temp_jtv1 WHERE id < 6")
+      }.getMessage
+      assert(e.contains("Not allowed to create a permanent view `jtv1` by " +
+        "referencing a temporary view `temp_jtv1`"))
 
-      // Checks interleaved temporary view and normal view
-      sql("CREATE TEMPORARY VIEW temp_jtv3 AS SELECT * FROM jt WHERE id > 3")
-      sql("CREATE VIEW jtv3 AS SELECT * FROM temp_jtv3 WHERE id < 6")
-      checkAnswer(sql("select count(*) FROM jtv3"), Row(2))
+      val globalTempDB = spark.sharedState.globalTempViewManager.database
+      sql("CREATE GLOBAL TEMP VIEW global_temp_jtv1 AS SELECT * FROM jt WHERE id > 0")
+      e = intercept[AnalysisException] {
+        sql(s"CREATE VIEW jtv1 AS SELECT * FROM $globalTempDB.global_temp_jtv1 WHERE id < 6")
+      }.getMessage
+      assert(e.contains(s"Not allowed to create a permanent view `jtv1` by referencing " +
+        s"a temporary view `global_temp`.`global_temp_jtv1`"))
     }
   }
 
@@ -439,7 +464,7 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     }
   }
 
-  test("SPARK-14933 - create view from hive parquet tabale") {
+  test("SPARK-14933 - create view from hive parquet table") {
     withTable("t_part") {
       withView("v_part") {
         spark.sql("create table t_part stored as parquet as select 1 as a, 2 as b")
@@ -451,7 +476,7 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
     }
   }
 
-  test("SPARK-14933 - create view from hive orc tabale") {
+  test("SPARK-14933 - create view from hive orc table") {
     withTable("t_orc") {
       withView("v_orc") {
         spark.sql("create table t_orc stored as orc as select 1 as a, 2 as b")
@@ -462,4 +487,60 @@ class SQLViewSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
       }
     }
   }
+
+  test("create a permanent/temp view using a hive, built-in, and permanent user function") {
+    val permanentFuncName = "myUpper"
+    val permanentFuncClass =
+      classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName
+    val builtInFuncNameInLowerCase = "abs"
+    val builtInFuncNameInMixedCase = "aBs"
+    val hiveFuncName = "histogram_numeric"
+
+    withUserDefinedFunction(permanentFuncName -> false) {
+      sql(s"CREATE FUNCTION $permanentFuncName AS '$permanentFuncClass'")
+      withTable("tab1") {
+        (1 to 10).map(i => (s"$i", i)).toDF("str", "id").write.saveAsTable("tab1")
+        Seq("VIEW", "TEMPORARY VIEW").foreach { viewMode =>
+          withView("view1") {
+            sql(
+              s"""
+                 |CREATE $viewMode view1
+                 |AS SELECT
+                 |$permanentFuncName(str),
+                 |$builtInFuncNameInLowerCase(id),
+                 |$builtInFuncNameInMixedCase(id) as aBs,
+                 |$hiveFuncName(id, 5) over()
+                 |FROM tab1
+               """.stripMargin)
+            checkAnswer(sql("select count(*) FROM view1"), Row(10))
+          }
+        }
+      }
+    }
+  }
+
+  test("create a permanent/temp view using a temporary function") {
+    val tempFunctionName = "temp"
+    val functionClass =
+      classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName
+    withUserDefinedFunction(tempFunctionName -> true) {
+      sql(s"CREATE TEMPORARY FUNCTION $tempFunctionName AS '$functionClass'")
+      withView("view1", "tempView1") {
+        withTable("tab1") {
+          (1 to 10).map(i => s"$i").toDF("id").write.saveAsTable("tab1")
+
+          // temporary view
+          sql(s"CREATE TEMPORARY VIEW tempView1 AS SELECT $tempFunctionName(id) from tab1")
+          checkAnswer(sql("select count(*) FROM tempView1"), Row(10))
+
+          // permanent view
+          val e = intercept[AnalysisException] {
+            sql(s"CREATE VIEW view1 AS SELECT $tempFunctionName(id) from tab1")
+          }.getMessage
+          assert(e.contains("Not allowed to create a permanent view `view1` by referencing " +
+            s"a temporary function `$tempFunctionName`"))
+        }
+      }
+    }
+  }
 }
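Likewise for temporary functions, mirroring the last test above (again assuming a SparkSession
named spark and an existing table tab1; GenericUDFUpper is the Hive UDF the test registers):

    spark.sql("CREATE TEMPORARY FUNCTION temp AS " +
      "'org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper'")

    // Allowed: a temporary view may reference a temporary function.
    spark.sql("CREATE TEMPORARY VIEW tempView1 AS SELECT temp(id) FROM tab1")

    // Rejected with an AnalysisException:
    //   Not allowed to create a permanent view `view1` by referencing a temporary function `temp`
    spark.sql("CREATE VIEW view1 AS SELECT temp(id) FROM tab1")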