diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 61797bc34dc2773f220bc318b19e96308c2743b0..ea4560aac7259fe0baf363fa4adb0ad8851f3220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -130,12 +130,13 @@ trait CheckAnalysis extends PredicateHelper { } case s @ ScalarSubquery(query, conditions, _) => + checkAnalysis(query) + // If no correlation, the output must be exactly one column if (conditions.isEmpty && query.output.size != 1) { failAnalysis( s"Scalar subquery must return only one column, but got ${query.output.size}") - } - else if (conditions.nonEmpty) { + } else if (conditions.nonEmpty) { def checkAggregate(agg: Aggregate): Unit = { // Make sure correlated scalar subqueries contain one row for every outer row by // enforcing that they are aggregates containing exactly one aggregate expression. @@ -179,7 +180,6 @@ trait CheckAnalysis extends PredicateHelper { case fail => failAnalysis(s"Correlated scalar subqueries must be Aggregated: $fail") } } - checkAnalysis(query) s case s: SubqueryExpression => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 131abf7c1e5d3ef8524fddf6ecece9208c114968..a01eb2a2162670b69771af75c37500a37117db2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -72,7 +72,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext { } } - test("rdd deserialization does not crash [SPARK-15791]") { + test("SPARK-15791: rdd deserialization does not crash") { sql("select (select 1 as b) as b").rdd.count() } @@ -867,4 +867,12 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select * from l, r where l.a = r.c + 1 AND (exists (select * from r) OR l.a = r.c)"), Row(3, 3.0, 2, 3.0) :: Row(3, 3.0, 2, 3.0) :: Nil) } + + test("SPARK-20688: correctly check analysis for scalar sub-queries") { + withTempView("t") { + Seq(1 -> "a").toDF("i", "j").createTempView("t") + val e = intercept[AnalysisException](sql("SELECT (SELECT count(*) FROM t WHERE a = 1)")) + assert(e.message.contains("cannot resolve '`a`' given input columns: [i, j]")) + } + } }