diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 921c84895598c7699c9e07e66458aa38103ba2f9..00f0acab21aa2645bdd650686099f05e51b760f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.types.MetadataBuilder @@ -154,6 +154,10 @@ case class CreateViewCommand( } else if (tableMetadata.tableType != CatalogTableType.VIEW) { throw new AnalysisException(s"$name is not a view") } else if (replace) { + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val viewIdent = tableMetadata.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) } else { @@ -283,6 +287,10 @@ case class AlterViewAsCommand( throw new AnalysisException(s"${viewMeta.identifier} is not a view.") } + // Detect cyclic view reference on ALTER VIEW. + val viewIdent = viewMeta.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + val newProperties = generateViewProperties(viewMeta.properties, session, analyzedPlan) val updatedViewMeta = viewMeta.copy( @@ -358,4 +366,53 @@ object ViewHelper { generateViewDefaultDatabase(viewDefaultDatabase) ++ generateQueryColumnNames(queryOutput) } + + /** + * Recursively search the logical plan to detect cyclic view references, throw an + * AnalysisException if cycle detected. + * + * A cyclic view reference is a cycle of reference dependencies, for example, if the following + * statements are executed: + * CREATE VIEW testView AS SELECT id FROM tbl + * CREATE VIEW testView2 AS SELECT id FROM testView + * ALTER VIEW testView AS SELECT * FROM testView2 + * The view `testView` references `testView2`, and `testView2` also references `testView`, + * therefore a reference cycle (testView -> testView2 -> testView) exists. + * + * @param plan the logical plan we detect cyclic view references from. + * @param path the path between the altered view and current node. + * @param viewIdent the table identifier of the altered view, we compare two views by the + * `desc.identifier`. + */ + def checkCyclicViewReference( + plan: LogicalPlan, + path: Seq[TableIdentifier], + viewIdent: TableIdentifier): Unit = { + plan match { + case v: View => + val ident = v.desc.identifier + val newPath = path :+ ident + // If the table identifier equals to the `viewIdent`, current view node is the same with + // the altered view. We detect a view reference cycle, should throw an AnalysisException. + if (ident == viewIdent) { + throw new AnalysisException(s"Recursive view $viewIdent detected " + + s"(cycle: ${newPath.mkString(" -> ")})") + } else { + v.children.foreach { child => + checkCyclicViewReference(child, newPath, viewIdent) + } + } + case _ => + plan.children.foreach(child => checkCyclicViewReference(child, path, viewIdent)) + } + + // Detect cyclic references from subqueries. + plan.expressions.foreach { expr => + expr match { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 0e5a1dc6ab629484a46a34935b11fcc0da04aa88..2ca2206bb9d44fcc798a2e0a9363cc744a49ec96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -609,12 +609,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } - // TODO: Check for cyclic view references on ALTER VIEW. - ignore("correctly handle a cyclic view reference") { - withView("view1", "view2") { + test("correctly handle a cyclic view reference") { + withView("view1", "view2", "view3") { sql("CREATE VIEW view1 AS SELECT * FROM jt") sql("CREATE VIEW view2 AS SELECT * FROM view1") - intercept[AnalysisException](sql("ALTER VIEW view1 AS SELECT * FROM view2")) + sql("CREATE VIEW view3 AS SELECT * FROM view2") + + // Detect cyclic view reference on ALTER VIEW. + val e1 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM view2") + }.getMessage + assert(e1.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + + // Detect the most left cycle when there exists multiple cyclic view references. + val e2 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM view3 JOIN view2") + }.getMessage + assert(e2.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view3` -> `default`.`view2` -> `default`.`view1`)")) + + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val e3 = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW view1 AS SELECT * FROM view2") + }.getMessage + assert(e3.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + + // Detect cyclic view reference from subqueries. + val e4 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM jt WHERE EXISTS (SELECT 1 FROM view2)") + }.getMessage + assert(e4.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) } } }