diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index aa77a6efef347add5ca255b5595dde59e12788e1..65a2a7b04dd8f496c8eb916590c3f87bd8876a27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -117,66 +117,72 @@ trait CheckAnalysis extends PredicateHelper { failAnalysis(s"Window specification $s is not valid because $m") case None => w } - case s @ ScalarSubquery(query, conditions, _) + + case s @ ScalarSubquery(query, conditions, _) => // If no correlation, the output must be exactly one column - if (conditions.isEmpty && query.output.size != 1) => + if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") + } + else if (conditions.nonEmpty) { + // Collect the columns from the subquery for further checking. + var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) + + def checkAggregate(agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates containing exactly one aggregate expression. + // The analyzer has already checked that subquery contained only one output column, + // and added all the grouping expressions to the aggregate. + val aggregates = agg.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } - case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty => - - // Collect the columns from the subquery for further checking. - var subqueryColumns = conditions.flatMap(_.references).filter(query.output.contains) - - def checkAggregate(agg: Aggregate): Unit = { - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates which contain exactly one aggregate expressions. - // The analyzer has already checked that subquery contained only one output column, - // and added all the grouping expressions to the aggregate. - val aggregates = agg.expressions.flatMap(_.collect { - case a: AggregateExpression => a - }) - if (aggregates.isEmpty) { - failAnalysis("The output of a correlated scalar subquery must be aggregated") + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } } - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) - val correlatedCols = AttributeSet(subqueryColumns) - val invalidCols = groupByCols -- correlatedCols - // GROUP BY columns must be a subset of columns in the predicates - if (invalidCols.nonEmpty) { - failAnalysis( - "A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - } + // Skip subquery aliases added by the Analyzer and the SQLBuilder. + // For projects, do the necessary mapping and skip to its child. + def cleanQuery(p: LogicalPlan): LogicalPlan = p match { + case s: SubqueryAlias => cleanQuery(s.child) + case p: Project => + // SPARK-18814: Map any aliases to their AttributeReference children + // for the checking in the Aggregate operators below this Project. + subqueryColumns = subqueryColumns.map { + xs => p.projectList.collectFirst { + case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => + child + }.getOrElse(xs) + } - // Skip subquery aliases added by the Analyzer and the SQLBuilder. - // For projects, do the necessary mapping and skip to its child. - def cleanQuery(p: LogicalPlan): LogicalPlan = p match { - case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => - // SPARK-18814: Map any aliases to their AttributeReference children - // for the checking in the Aggregate operators below this Project. - subqueryColumns = subqueryColumns.map { - xs => p.projectList.collectFirst { - case e @ Alias(child : AttributeReference, _) if e.exprId == xs.exprId => - child - }.getOrElse(xs) - } + cleanQuery(p.child) + case child => child + } - cleanQuery(p.child) - case child => child + cleanQuery(query) match { + case a: Aggregate => checkAggregate(a) + case Filter(_, a: Aggregate) => checkAggregate(a) + case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") + } } + checkAnalysis(query) + s - cleanQuery(query) match { - case a: Aggregate => checkAggregate(a) - case Filter(_, a: Aggregate) => checkAggregate(a) - case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") - } + case s: SubqueryExpression => + checkAnalysis(s.plan) s } diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql new file mode 100644 index 0000000000000000000000000000000000000000..cf93c5a835971c2f4e1098da47b19ab1784608a3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql @@ -0,0 +1,42 @@ +-- The test file contains negative test cases +-- of invalid queries where error messages are expected. + +create temporary view t1 as select * from values + (1, 2, 3) +as t1(t1a, t1b, t1c); + +create temporary view t2 as select * from values + (1, 0, 1) +as t2(t2a, t2b, t2c); + +create temporary view t3 as select * from values + (3, 1, 2) +as t3(t3a, t3b, t3c); + +-- TC 01.01 +-- The column t2b in the SELECT of the subquery is invalid +-- because it is neither an aggregate function nor a GROUP BY column. +select t1a, t2b +from t1, t2 +where t1b = t2c +and t2b = (select max(avg) + from (select t2b, avg(t2b) avg + from t2 + where t2a = t1.t1b + ) + ) +; + +-- TC 01.02 +-- Invalid due to the column t2b not part of the output from table t2. +select * +from t1 +where t1a in (select min(t2a) + from t2 + group by t2c + having t2c in (select max(t3c) + from t3 + group by t3b + having t3b > t2b )) +; + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..50ae01e181bcf2d9953dff8174e7d6e0eecddf3b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -0,0 +1,66 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 5 + + +-- !query 0 +create temporary view t1 as select * from values + (1, 2, 3) +as t1(t1a, t1b, t1c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +create temporary view t2 as select * from values + (1, 0, 1) +as t2(t2a, t2b, t2c) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +create temporary view t3 as select * from values + (3, 1, 2) +as t3(t3a, t3b, t3c) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +select t1a, t2b +from t1, t2 +where t1b = t2c +and t2b = (select max(avg) + from (select t2b, avg(t2b) avg + from t2 + where t2a = t1.t1b + ) + ) +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +expression 't2.`t2b`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 4 +select * +from t1 +where t1a in (select min(t2a) + from t2 + group by t2c + having t2c in (select max(t3c) + from t3 + group by t3b + having t3b > t2b )) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter predicate-subquery#x [(t2c#x = max(t3c)#x) && (t3b#x > t2b#x)]; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index fdf940a7f9504ea59680d2a45b15140a329839df..91aecca537fb27ec43fe771c336998de521ad7b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -228,7 +228,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { } catch { case a: AnalysisException if a.plan.nonEmpty => // Do not output the logical plan tree which contains expression IDs. - (StructType(Seq.empty), Seq(a.getClass.getName, a.getSimpleMessage)) + // Also implement a crude way of masking expression IDs in the error message + // with a generic pattern "###". + (StructType(Seq.empty), + Seq(a.getClass.getName, a.getSimpleMessage.replaceAll("#\\d+", "#x"))) case NonFatal(e) => // If there is an exception, put the exception class followed by the message. (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage))