diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 647fc0b9342c1908a388d46691559140947c0240..193082eb77024c09843ae6fd9132c19b887a504d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ -import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ @@ -1257,217 +1256,16 @@ class Analyzer( } /** - * Validates to make sure the outer references appearing inside the subquery - * are legal. This function also returns the list of expressions - * that contain outer references. These outer references would be kept as children - * of subquery expressions by the caller of this function. - */ - private def checkAndGetOuterReferences(sub: LogicalPlan): Seq[Expression] = { - val outerReferences = ArrayBuffer.empty[Expression] - - // Validate that correlated aggregate expression do not contain a mixture - // of outer and local references. 
- def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { - expr.foreach { - case a: AggregateExpression if containsOuter(a) => - val outer = a.collect { case OuterReference(e) => e.toAttribute } - val local = a.references -- outer - if (local.nonEmpty) { - val msg = - s""" - |Found an aggregate expression in a correlated predicate that has both - |outer and local references, which is not supported yet. - |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, - |Outer references: ${outer.map(_.sql).mkString(", ")}, - |Local references: ${local.map(_.sql).mkString(", ")}. - """.stripMargin.replace("\n", " ").trim() - failAnalysis(msg) - } - case _ => - } - } - - // Make sure a plan's subtree does not contain outer references - def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { - if (hasOuterReferences(p)) { - failAnalysis(s"Accessing outer query column is not allowed in:\n$p") - } - } - - // Make sure a plan's expressions do not contain : - // 1. Aggregate expressions that have mixture of outer and local references. - // 2. Expressions containing outer references on plan nodes other than Filter. - def failOnInvalidOuterReference(p: LogicalPlan): Unit = { - p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) - if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { - failAnalysis( - "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + - s"clauses:\n$p") - } - } - - // SPARK-17348: A potential incorrect result case. - // When a correlated predicate is a non-equality predicate, - // certain operators are not permitted from the operator - // hosting the correlated predicate up to the operator on the outer table. - // Otherwise, the pull up of the correlated predicate - // will generate a plan with a different semantics - // which could return incorrect result. 
- // Currently we check for Aggregate and Window operators - // - // Below shows an example of a Logical Plan during Analyzer phase that - // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] - // through the Aggregate (or Window) operator could alter the result of - // the Aggregate. - // - // Project [c1#76] - // +- Project [c1#87, c2#88] - // : (Aggregate or Window operator) - // : +- Filter [outer(c2#77) >= c2#88)] - // : +- SubqueryAlias t2, `t2` - // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] - // : +- LocalRelation [_1#84, _2#85] - // +- SubqueryAlias t1, `t1` - // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] - // +- LocalRelation [_1#73, _2#74] - def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { - if (found) { - // Report a non-supported case as an exception - failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") - } - } - - var foundNonEqualCorrelatedPred : Boolean = false - - // Simplify the predicates before validating any unsupported correlation patterns - // in the plan. - BooleanSimplification(sub).foreachUp { - - // Whitelist operators allowed in a correlated subquery - // There are 4 categories: - // 1. Operators that are allowed anywhere in a correlated subquery, and, - // by definition of the operators, they either do not contain - // any columns or cannot host outer references. - // 2. Operators that are allowed anywhere in a correlated subquery - // so long as they do not host outer references. - // 3. Operators that need special handlings. These operators are - // Project, Filter, Join, Aggregate, and Generate. - // - // Any operators that are not in the above list are allowed - // in a correlated subquery only if they are not on a correlation path. - // In other word, these operators are allowed only under a correlation point. 
- // - // A correlation path is defined as the sub-tree of all the operators that - // are on the path from the operator hosting the correlated expressions - // up to the operator producing the correlated values. - - // Category 1: - // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => - - // Category 2: - // These operators can be anywhere in a correlated subquery. - // so long as they do not host outer references in the operators. - case s: Sort => - failOnInvalidOuterReference(s) - case r: RepartitionByExpression => - failOnInvalidOuterReference(r) - - // Category 3: - // Filter is one of the two operators allowed to host correlated expressions. - // The other operator is Join. Filter can be anywhere in a correlated subquery. - case f: Filter => - // Find all predicates with an outer reference. - val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) - - // Find any non-equality correlated predicates - foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { - case _: EqualTo | _: EqualNullSafe => false - case _ => true - } - - failOnInvalidOuterReference(f) - // The aggregate expressions are treated in a special way by getOuterReferences. If the - // aggregate expression contains only outer reference attributes then the entire aggregate - // expression is isolated as an OuterReference. - // i.e min(OuterReference(b)) => OuterReference(min(b)) - outerReferences ++= getOuterReferences(correlated) - - // Project cannot host any correlated expressions - // but can be anywhere in a correlated subquery. - case p: Project => - failOnInvalidOuterReference(p) - - // Aggregate cannot host any correlated expressions - // It can be on a correlation path if the correlation contains - // only equality correlated predicates. 
- // It cannot be on a correlation path if the correlation has - // non-equality correlated predicates. - case a: Aggregate => - failOnInvalidOuterReference(a) - failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) - - // Join can host correlated expressions. - case j @ Join(left, right, joinType, _) => - joinType match { - // Inner join, like Filter, can be anywhere. - case _: InnerLike => - failOnInvalidOuterReference(j) - - // Left outer join's right operand cannot be on a correlation path. - // LeftAnti and ExistenceJoin are special cases of LeftOuter. - // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame - // so it should not show up here in Analysis phase. This is just a safety net. - // - // LeftSemi does not allow output from the right operand. - // Any correlated references in the subplan - // of the right operand cannot be pulled up. - case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(right) - - // Likewise, Right outer join's left operand cannot be on a correlation path. - case RightOuter => - failOnInvalidOuterReference(j) - failOnOuterReferenceInSubTree(left) - - // Any other join types not explicitly listed above, - // including Full outer join, are treated as Category 4. - case _ => - failOnOuterReferenceInSubTree(j) - } - - // Generator with join=true, i.e., expressed with - // LATERAL VIEW [OUTER], similar to inner join, - // allows to have correlation under it - // but must not host any outer references. - // Note: - // Generator with join=false is treated as Category 4. - case g: Generate if g.join => - failOnInvalidOuterReference(g) - - // Category 4: Any other operators not in the above 3 categories - // cannot be on a correlation path, that is they are allowed only - // under a correlation point but they and their descendant operators - // are not allowed to have any correlated expressions. 
- case p => - failOnOuterReferenceInSubTree(p) - } - outerReferences - } - - /** - * Resolves the subquery. The subquery is resolved using its outer plans. This method - * will resolve the subquery by alternating between the regular analyzer and by applying the - * resolveOuterReferences rule. + * Resolves the subquery plan that is referenced in a subquery expression. The normal + * attribute references are resolved using regular analyzer and the outer references are + * resolved from the outer plans using the resolveOuterReferences method. * * Outer references from the correlated predicates are updated as children of * Subquery expression. */ private def resolveSubQuery( e: SubqueryExpression, - plans: Seq[LogicalPlan], - requiredColumns: Int = 0)( + plans: Seq[LogicalPlan])( f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = { // Step 1: Resolve the outer expressions. var previous: LogicalPlan = null @@ -1488,15 +1286,8 @@ class Analyzer( // Step 2: If the subquery plan is fully resolved, pull the outer references and record // them as children of SubqueryExpression. if (current.resolved) { - // Make sure the resolved query has the required number of output columns. This is only - // needed for Scalar and IN subqueries. - if (requiredColumns > 0 && requiredColumns != current.output.size) { - failAnalysis(s"The number of columns in the subquery (${current.output.size}) " + - s"does not match the required number of columns ($requiredColumns)") - } - // Validate the outer reference and record the outer references as children of - // subquery expression. - f(current, checkAndGetOuterReferences(current)) + // Record the outer references as children of subquery expression. 
+ f(current, SubExprUtils.getOuterReferences(current)) } else { e.withNewPlan(current) } @@ -1514,16 +1305,11 @@ class Analyzer( private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): LogicalPlan = { plan transformExpressions { case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved => - resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId)) + resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId) if !sub.resolved => resolveSubQuery(e, plans)(Exists(_, _, exprId)) case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved => - // Get the left hand side expressions. - val expressions = value match { - case cns : CreateNamedStruct => cns.valExprs - case expr => Seq(expr) - } - val expr = resolveSubQuery(l, plans, expressions.size)(ListQuery(_, _, exprId)) + val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId)) In(value, Seq(expr)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2e3ac3e474866606d6da5fa199fcc7594669ed76..fb81a7006bc5e5b99c2f67e66289a57e87c3aed2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -21,6 +21,8 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ +import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -129,61 +131,8 @@ trait CheckAnalysis extends PredicateHelper { case None => w } - case s @ ScalarSubquery(query, 
conditions, _) => - checkAnalysis(query) - - // If no correlation, the output must be exactly one column - if (conditions.isEmpty && query.output.size != 1) { - failAnalysis( - s"Scalar subquery must return only one column, but got ${query.output.size}") - } else if (conditions.nonEmpty) { - def checkAggregate(agg: Aggregate): Unit = { - // Make sure correlated scalar subqueries contain one row for every outer row by - // enforcing that they are aggregates containing exactly one aggregate expression. - // The analyzer has already checked that subquery contained only one output column, - // and added all the grouping expressions to the aggregate. - val aggregates = agg.expressions.flatMap(_.collect { - case a: AggregateExpression => a - }) - if (aggregates.isEmpty) { - failAnalysis("The output of a correlated scalar subquery must be aggregated") - } - - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) - // Collect the local references from the correlated predicate in the subquery. - val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) - .filterNot(conditions.flatMap(_.references).contains) - val correlatedCols = AttributeSet(subqueryColumns) - val invalidCols = groupByCols -- correlatedCols - // GROUP BY columns must be a subset of columns in the predicates - if (invalidCols.nonEmpty) { - failAnalysis( - "A GROUP BY clause in a scalar correlated subquery " + - "cannot contain non-correlated columns: " + - invalidCols.mkString(",")) - } - } - - // Skip subquery aliases added by the Analyzer. - // For projects, do the necessary mapping and skip to its child. 
- def cleanQuery(p: LogicalPlan): LogicalPlan = p match { - case s: SubqueryAlias => cleanQuery(s.child) - case p: Project => cleanQuery(p.child) - case child => child - } - - cleanQuery(query) match { - case a: Aggregate => checkAggregate(a) - case Filter(_, a: Aggregate) => checkAggregate(a) - case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") - } - } - s - case s: SubqueryExpression => - checkAnalysis(s.plan) + checkSubqueryExpression(operator, s) s } @@ -291,19 +240,6 @@ trait CheckAnalysis extends PredicateHelper { case LocalLimit(limitExpr, _) => checkLimitClause(limitExpr) - case p if p.expressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => - p match { - case _: Filter | _: Aggregate | _: Project => // Ok - case other => failAnalysis( - s"Correlated scalar sub-queries can only be used in a Filter/Aggregate/Project: $p") - } - - case p if p.expressions.exists(SubqueryExpression.hasInOrExistsSubquery) => - p match { - case _: Filter => // Ok - case _ => failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") - } - case _: Union | _: SetOperation if operator.children.length > 1 => def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) def ordinalNumber(i: Int): String = i match { @@ -414,4 +350,272 @@ trait CheckAnalysis extends PredicateHelper { plan.foreach(_.setAnalyzed()) } + + /** + * Validates subquery expressions in the plan. Upon failure, returns an user facing error. + */ + private def checkSubqueryExpression(plan: LogicalPlan, expr: SubqueryExpression): Unit = { + def checkAggregateInScalarSubquery( + conditions: Seq[Expression], + query: LogicalPlan, agg: Aggregate): Unit = { + // Make sure correlated scalar subqueries contain one row for every outer row by + // enforcing that they are aggregates containing exactly one aggregate expression. 
+ val aggregates = agg.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + failAnalysis("The output of a correlated scalar subquery must be aggregated") + } + + // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns + // are not part of the correlated columns. + val groupByCols = AttributeSet(agg.groupingExpressions.flatMap(_.references)) + // Collect the local references from the correlated predicate in the subquery. + val subqueryColumns = getCorrelatedPredicates(query).flatMap(_.references) + .filterNot(conditions.flatMap(_.references).contains) + val correlatedCols = AttributeSet(subqueryColumns) + val invalidCols = groupByCols -- correlatedCols + // GROUP BY columns must be a subset of columns in the predicates + if (invalidCols.nonEmpty) { + failAnalysis( + "A GROUP BY clause in a scalar correlated subquery " + + "cannot contain non-correlated columns: " + + invalidCols.mkString(",")) + } + } + + // Skip subquery aliases added by the Analyzer. + // For projects, do the necessary mapping and skip to its child. + def cleanQueryInScalarSubquery(p: LogicalPlan): LogicalPlan = p match { + case s: SubqueryAlias => cleanQueryInScalarSubquery(s.child) + case p: Project => cleanQueryInScalarSubquery(p.child) + case child => child + } + + // Validate the subquery plan. + checkAnalysis(expr.plan) + + expr match { + case ScalarSubquery(query, conditions, _) => + // Scalar subquery must return one column as output. 
+ if (query.output.size != 1) { + failAnalysis( + s"Scalar subquery must return only one column, but got ${query.output.size}") + } + + if (conditions.nonEmpty) { + cleanQueryInScalarSubquery(query) match { + case a: Aggregate => checkAggregateInScalarSubquery(conditions, query, a) + case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(conditions, query, a) + case fail => failAnalysis(s"Correlated scalar subqueries must be aggregated: $fail") + } + + // Only certain operators are allowed to host subquery expression containing + // outer references. + plan match { + case _: Filter | _: Aggregate | _: Project => // Ok + case other => failAnalysis( + "Correlated scalar sub-queries can only be used in a " + + s"Filter/Aggregate/Project: $plan") + } + } + + case inSubqueryOrExistsSubquery => + plan match { + case _: Filter => // Ok + case _ => + failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in a Filter: $plan") + } + } + + // Validate to make sure the correlations appearing in the query are valid and + // allowed by spark. + checkCorrelationsInSubquery(expr.plan) + } + + /** + * Validates to make sure the outer references appearing inside the subquery + * are allowed. + */ + private def checkCorrelationsInSubquery(sub: LogicalPlan): Unit = { + // Validate that correlated aggregate expression do not contain a mixture + // of outer and local references. + def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = { + expr.foreach { + case a: AggregateExpression if containsOuter(a) => + val outer = a.collect { case OuterReference(e) => e.toAttribute } + val local = a.references -- outer + if (local.nonEmpty) { + val msg = + s""" + |Found an aggregate expression in a correlated predicate that has both + |outer and local references, which is not supported yet. 
+ |Aggregate expression: ${SubExprUtils.stripOuterReference(a).sql}, + |Outer references: ${outer.map(_.sql).mkString(", ")}, + |Local references: ${local.map(_.sql).mkString(", ")}. + """.stripMargin.replace("\n", " ").trim() + failAnalysis(msg) + } + case _ => + } + } + + // Make sure a plan's subtree does not contain outer references + def failOnOuterReferenceInSubTree(p: LogicalPlan): Unit = { + if (hasOuterReferences(p)) { + failAnalysis(s"Accessing outer query column is not allowed in:\n$p") + } + } + + // Make sure a plan's expressions do not contain : + // 1. Aggregate expressions that have mixture of outer and local references. + // 2. Expressions containing outer references on plan nodes other than Filter. + def failOnInvalidOuterReference(p: LogicalPlan): Unit = { + p.expressions.foreach(checkMixedReferencesInsideAggregateExpr) + if (!p.isInstanceOf[Filter] && p.expressions.exists(containsOuter)) { + failAnalysis( + "Expressions referencing the outer query are not supported outside of WHERE/HAVING " + + s"clauses:\n$p") + } + } + + // SPARK-17348: A potential incorrect result case. + // When a correlated predicate is a non-equality predicate, + // certain operators are not permitted from the operator + // hosting the correlated predicate up to the operator on the outer table. + // Otherwise, the pull up of the correlated predicate + // will generate a plan with a different semantics + // which could return incorrect result. + // Currently we check for Aggregate and Window operators + // + // Below shows an example of a Logical Plan during Analyzer phase that + // show this problem. Pulling the correlated predicate [outer(c2#77) >= ..] + // through the Aggregate (or Window) operator could alter the result of + // the Aggregate. 
+ // + // Project [c1#76] + // +- Project [c1#87, c2#88] + // : (Aggregate or Window operator) + // : +- Filter [outer(c2#77) >= c2#88)] + // : +- SubqueryAlias t2, `t2` + // : +- Project [_1#84 AS c1#87, _2#85 AS c2#88] + // : +- LocalRelation [_1#84, _2#85] + // +- SubqueryAlias t1, `t1` + // +- Project [_1#73 AS c1#76, _2#74 AS c2#77] + // +- LocalRelation [_1#73, _2#74] + def failOnNonEqualCorrelatedPredicate(found: Boolean, p: LogicalPlan): Unit = { + if (found) { + // Report a non-supported case as an exception + failAnalysis(s"Correlated column is not allowed in a non-equality predicate:\n$p") + } + } + + var foundNonEqualCorrelatedPred: Boolean = false + + // Simplify the predicates before validating any unsupported correlation patterns + // in the plan. + BooleanSimplification(sub).foreachUp { + // Whitelist operators allowed in a correlated subquery + // There are 4 categories: + // 1. Operators that are allowed anywhere in a correlated subquery, and, + // by definition of the operators, they either do not contain + // any columns or cannot host outer references. + // 2. Operators that are allowed anywhere in a correlated subquery + // so long as they do not host outer references. + // 3. Operators that need special handling. These operators are + // Filter, Join, Aggregate, and Generate. + // + // Any operators that are not in the above list are allowed + // in a correlated subquery only if they are not on a correlation path. + // In other words, these operators are allowed only under a correlation point. + // + // A correlation path is defined as the sub-tree of all the operators that + // are on the path from the operator hosting the correlated expressions + // up to the operator producing the correlated values. 
+ + // Category 1: + // ResolvedHint, Distinct, LeafNode, Repartition, and SubqueryAlias + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + + // Category 2: + // These operators can be anywhere in a correlated subquery. + // so long as they do not host outer references in the operators. + case p: Project => + failOnInvalidOuterReference(p) + + case s: Sort => + failOnInvalidOuterReference(s) + + case r: RepartitionByExpression => + failOnInvalidOuterReference(r) + + // Category 3: + // Filter is one of the two operators allowed to host correlated expressions. + // The other operator is Join. Filter can be anywhere in a correlated subquery. + case f: Filter => + val (correlated, _) = splitConjunctivePredicates(f.condition).partition(containsOuter) + + // Find any non-equality correlated predicates + foundNonEqualCorrelatedPred = foundNonEqualCorrelatedPred || correlated.exists { + case _: EqualTo | _: EqualNullSafe => false + case _ => true + } + failOnInvalidOuterReference(f) + + // Aggregate cannot host any correlated expressions + // It can be on a correlation path if the correlation contains + // only equality correlated predicates. + // It cannot be on a correlation path if the correlation has + // non-equality correlated predicates. + case a: Aggregate => + failOnInvalidOuterReference(a) + failOnNonEqualCorrelatedPredicate(foundNonEqualCorrelatedPred, a) + + // Join can host correlated expressions. + case j @ Join(left, right, joinType, _) => + joinType match { + // Inner join, like Filter, can be anywhere. + case _: InnerLike => + failOnInvalidOuterReference(j) + + // Left outer join's right operand cannot be on a correlation path. + // LeftAnti and ExistenceJoin are special cases of LeftOuter. + // Note that ExistenceJoin cannot be expressed externally in both SQL and DataFrame + // so it should not show up here in Analysis phase. This is just a safety net. 
+ // + // LeftSemi does not allow output from the right operand. + // Any correlated references in the subplan + // of the right operand cannot be pulled up. + case LeftOuter | LeftSemi | LeftAnti | ExistenceJoin(_) => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(right) + + // Likewise, Right outer join's left operand cannot be on a correlation path. + case RightOuter => + failOnInvalidOuterReference(j) + failOnOuterReferenceInSubTree(left) + + // Any other join types not explicitly listed above, + // including Full outer join, are treated as Category 4. + case _ => + failOnOuterReferenceInSubTree(j) + } + + // Generator with join=true, i.e., expressed with + // LATERAL VIEW [OUTER], similar to inner join, + // allows to have correlation under it + // but must not host any outer references. + // Note: + // Generator with join=false is treated as Category 4. + case g: Generate if g.join => + failOnInvalidOuterReference(g) + + // Category 4: Any other operators not in the above 3 categories + // cannot be on a correlation path, that is they are allowed only + // under a correlation point but they and their descendant operators + // are not allowed to have any correlated expressions. 
+ case p => + failOnOuterReferenceInSubTree(p) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c15ee2ab270bca78644d17932958ef29afc86e68..f3fe58caa6fe2e4abe6a17a3a8d09e36b4cfba12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -144,27 +144,39 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { case cns: CreateNamedStruct => cns.valExprs case expr => Seq(expr) } - - val mismatchedColumns = valExprs.zip(sub.output).flatMap { - case (l, r) if l.dataType != r.dataType => - s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" - case _ => None - } - - if (mismatchedColumns.nonEmpty) { + if (valExprs.length != sub.output.length) { TypeCheckResult.TypeCheckFailure( s""" - |The data type of one or more elements in the left hand side of an IN subquery - |is not compatible with the data type of the output of the subquery - |Mismatched columns: - |[${mismatchedColumns.mkString(", ")}] - |Left side: - |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. - |Right side: - |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + |The number of columns in the left hand side of an IN subquery does not match the + |number of columns in the output of subquery. + |#columns in left hand side: ${valExprs.length}. + |#columns in right hand side: ${sub.output.length}. + |Left side columns: + |[${valExprs.map(_.sql).mkString(", ")}]. + |Right side columns: + |[${sub.output.map(_.sql).mkString(", ")}]. 
""".stripMargin) } else { - TypeCheckResult.TypeCheckSuccess + val mismatchedColumns = valExprs.zip(sub.output).flatMap { + case (l, r) if l.dataType != r.dataType => + s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})" + case _ => None + } + if (mismatchedColumns.nonEmpty) { + TypeCheckResult.TypeCheckFailure( + s""" + |The data type of one or more elements in the left hand side of an IN subquery + |is not compatible with the data type of the output of the subquery + |Mismatched columns: + |[${mismatchedColumns.mkString(", ")}] + |Left side: + |[${valExprs.map(_.dataType.catalogString).mkString(", ")}]. + |Right side: + |[${sub.output.map(_.dataType.catalogString).mkString(", ")}]. + """.stripMargin) + } else { + TypeCheckResult.TypeCheckSuccess + } } case _ => if (list.exists(l => l.dataType != value.dataType)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 5050318d963589a64edd65500542d53392387fcc..4ed995e20d7ce342f3105c21d95e9f4aba69f979 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -111,8 +111,7 @@ class AnalysisErrorSuite extends AnalysisTest { "scalar subquery with 2 columns", testRelation.select( (ScalarSubquery(testRelation.select('a, dateLit.as('b))) + Literal(1)).as('a)), - "The number of columns in the subquery (2)" :: - "does not match the required number of columns (1)":: Nil) + "Scalar subquery must return only one column, but got 2" :: Nil) errorTest( "scalar subquery with no column", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 
55693121431a2bdacd653d1f153122f5a31d042b..1bf8d76da04d8781b3f308f5dfbad3e9b4af2411 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -35,7 +35,7 @@ class ResolveSubquerySuite extends AnalysisTest { test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") { val expr = Filter(In(a, Seq(ListQuery(Project(Seq(UnresolvedAttribute("a")), t2)))), t1) val m = intercept[AnalysisException] { - SimpleAnalyzer.ResolveSubquery(expr) + SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr)) }.getMessage assert(m.contains( "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses")) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql new file mode 100644 index 0000000000000000000000000000000000000000..b15f4da81dd93ebc70e9c2b5c521bee595a0291b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/subq-input-typecheck.sql @@ -0,0 +1,47 @@ +-- The test file contains negative test cases +-- of invalid queries where error messages are expected. 
+ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c); + +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c); + +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c); + +-- TC 01.01 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.02 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1; + +-- TC 01.03 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a); + +-- TC 01.04 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a); + diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out new file mode 100644 index 0000000000000000000000000000000000000000..9ea9d3c4c6f4089c5f8f53a2f77a012ff5d98075 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out @@ -0,0 +1,106 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES + (1, 2, 3) +AS t1(t1a, t1b, t1c) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES + (1, 0, 1) +AS t2(t2a, t2b, t2c) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES + (3, 1, 2) +AS t3(t3a, t3b, t3c) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +SELECT + ( SELECT max(t2b), min(t2b) + FROM t2 + WHERE t2.t2b = t1.t1b + GROUP BY t2.t2b + ) +FROM t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 4 +SELECT + ( SELECT max(t2b), 
min(t2b) + FROM t2 + WHERE t2.t2b > 0 + GROUP BY t2.t2b + ) +FROM t1 +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Scalar subquery must return only one column, but got 2; + + +-- !query 5 +SELECT * FROM t1 +WHERE +t1a IN (SELECT t2a, t2b + FROM t2 + WHERE t1a = t2a) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +cannot resolve '(t1.`t1a` IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 1. +#columns in right hand side: 2. +Left side columns: +[t1.`t1a`]. +Right side columns: +[t2.`t2a`, t2.`t2b`]. + ; + + +-- !query 6 +SELECT * FROM T1 +WHERE +(t1a, t1b) IN (SELECT t2a + FROM t2 + WHERE t1a = t2a) +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException +cannot resolve '(named_struct('t1a', t1.`t1a`, 't1b', t1.`t1b`) IN (listquery(t1.`t1a`)))' due to data type mismatch: +The number of columns in the left hand side of an IN subquery does not match the +number of columns in the output of subquery. +#columns in left hand side: 2. +#columns in right hand side: 1. +Left side columns: +[t1.`t1a`, t1.`t1b`]. +Right side columns: +[t2.`t2a`]. 
+ ; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 4629a8c0dbe5ff8d8ca2c3990c66c5720c7d3b04..820cff655c4ffe9ce3b934fd5d7e38630ee082d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -517,7 +517,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { val msg1 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") } - assert(msg1.getMessage.contains("Correlated scalar subqueries must be Aggregated")) + assert(msg1.getMessage.contains("Correlated scalar subqueries must be aggregated")) val msg2 = intercept[AnalysisException] { sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1")