diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 562711a1b990d0388222a8763e01116e05c8e2ae..23e4709bbd8829a73f6942f05e99d426c5f09c0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -123,13 +123,12 @@ class Analyzer(
           }
           substituted.getOrElse(u)
         case other =>
-          // This can't be done in ResolveSubquery because that does not know the CTE.
+          // This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
          other transformExpressions {
            case e: SubqueryExpression =>
              e.withNewPlan(substituteCTE(e.query, cteRelations))
          }
      }
-      }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index d7d768babc115ae950d3a5d50c1bb469b95c76e6..37bfe98d3ab24b7da85d6c0f374bf18efcd8596a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -255,14 +255,3 @@ case class Literal protected (value: Any, dataType: DataType)
     case _ => value.toString
   }
 }
-
-// TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
-  extends LeafExpression with CodegenFallback {
-
-  def update(expression: Expression, input: InternalRow): Unit = {
-    value = expression.eval(input)
-  }
-
-  override def eval(input: InternalRow): Any = value
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index d0c44b032866b974348c0e4692f909d0717dafe9..ddf214a4b30ac9369013b3aeacdf2a1282f41411 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -45,9 +45,8 @@ abstract class SubqueryExpression extends LeafExpression {
 }
 
 /**
- * A subquery that will return only one row and one column.
- *
- * This will be converted into [[execution.ScalarSubquery]] during physical planning.
+ * A subquery that will return only one row and one column. This will be converted into a physical
+ * scalar subquery during planning.
  *
  * Note: `exprId` is used to have unique name in explain string output.
  */
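The Analyzer change above makes CTE substitution recurse into the plans nested inside subquery expressions. A minimal sketch of that rewrite pattern on a toy AST; none of these types are Catalyst's, they only mirror the shape of the fix:

// Toy AST standing in for Catalyst plans and expressions (illustrative only).
sealed trait Plan
case class Table(name: String) extends Plan
case class Filter(cond: Expr, child: Plan) extends Plan

sealed trait Expr
case class ScalarSub(query: Plan) extends Expr
case class Gt(column: String, value: Expr) extends Expr

object CteSubstitutionDemo {
  // Replace table references with their CTE definitions, recursing into
  // subquery expressions the same way the Analyzer fix does.
  def substituteCTE(plan: Plan, ctes: Map[String, Plan]): Plan = plan match {
    case Table(n) => ctes.getOrElse(n, Table(n))
    case Filter(cond, child) => Filter(substituteExpr(cond, ctes), substituteCTE(child, ctes))
  }

  private def substituteExpr(e: Expr, ctes: Map[String, Plan]): Expr = e match {
    case ScalarSub(q) => ScalarSub(substituteCTE(q, ctes)) // the key recursion into the nested plan
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // where a > (select ... from t2), with t2 defined as a CTE
    val plan = Filter(Gt("a", ScalarSub(Table("t2"))), Table("t"))
    println(substituteCTE(plan, Map("t2" -> Table("t2_cte_definition"))))
  }
}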
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 872ccde88306091483cdda966cf461c9cb3a5622..477a9460d7dd8e3c55e0bb63276dfc131ad1a086 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -46,7 +46,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * populated by the query planning infrastructure.
    */
   @transient
-  protected[spark] final val sqlContext = SQLContext.getActive().getOrElse(null)
+  protected[spark] final val sqlContext = SQLContext.getActive().orNull
 
   protected def sparkContext = sqlContext.sparkContext
 
@@ -120,44 +120,49 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     }
   }
 
-  // All the subqueries and their Future of results.
-  @transient private val queryResults = ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]()
+  /**
+   * List of (uncorrelated scalar subquery, future holding the subquery result) for this plan node.
+   * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
+   */
+  @transient
+  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
 
   /**
-   * Collects all the subqueries and create a Future to take the first two rows of them.
+   * Finds scalar subquery expressions in this plan node and starts evaluating them.
+   * The subqueries found are added to [[subqueryResults]].
    */
   protected def prepareSubqueries(): Unit = {
     val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
     allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
       val futureResult = Future {
-        // We only need the first row, try to take two rows so we can throw an exception if there
-        // are more than one rows returned.
+        // Each subquery should return only one row (and one column). We take two rows here and
+        // throw an exception later if the number of rows is greater than one.
         e.executedPlan.executeTake(2)
       }(SparkPlan.subqueryExecutionContext)
-      queryResults += e -> futureResult
+      subqueryResults += e -> futureResult
     }
   }
 
   /**
-   * Waits for all the subqueries to finish and updates the results.
+   * Blocks the thread until all subqueries finish evaluation, then updates the results.
    */
   protected def waitForSubqueries(): Unit = {
     // fill in the result of subqueries
-    queryResults.foreach {
-      case (e, futureResult) =>
-        val rows = Await.result(futureResult, Duration.Inf)
-        if (rows.length > 1) {
-          sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-        }
-        if (rows.length == 1) {
-          assert(rows(0).numFields == 1, "Analyzer should make sure this only returns one column")
-          e.updateResult(rows(0).get(0, e.dataType))
-        } else {
-          // There is no rows returned, the result should be null.
-          e.updateResult(null)
-        }
+    subqueryResults.foreach { case (e, futureResult) =>
+      val rows = Await.result(futureResult, Duration.Inf)
+      if (rows.length > 1) {
+        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
+      }
+      if (rows.length == 1) {
+        assert(rows(0).numFields == 1,
+          s"Expected 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+        e.updateResult(rows(0).get(0, e.dataType))
+      } else {
+        // If no rows are returned, the result should be null.
+        e.updateResult(null)
+      }
     }
-    queryResults.clear()
+    subqueryResults.clear()
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 9c645c78e87324051750c0c028c36b4d85538bef..e6d7480b0422c10d5e03be1397b1665be47abd05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -62,7 +62,7 @@ case class ScalarSubquery(
 
 /**
  * Convert the subquery from logical plan into executed plan.
  */
-private[sql] case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
+case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
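The prepare/await split above is the core of the change: prepareSubqueries fires every uncorrelated scalar subquery on a dedicated execution context without blocking, and waitForSubqueries later blocks for each result and enforces the at-most-one-row contract. A self-contained sketch of that lifecycle with plain Futures; no Spark types, and all names here are illustrative:

import java.util.concurrent.Executors
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}

object SubqueryLifecycleDemo {
  private val ec = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(4))
  private val results = new ArrayBuffer[(String, Future[Array[Int]])]

  // "prepare": kick off every subquery without blocking, as prepareSubqueries does.
  def prepare(subqueries: Map[String, () => Array[Int]]): Unit =
    subqueries.foreach { case (name, run) => results += name -> Future(run())(ec) }

  // "wait": block for each result and enforce the single-row contract,
  // as waitForSubqueries does.
  def await(): Map[String, Option[Int]] = {
    val out = results.map { case (name, f) =>
      val rows = Await.result(f, Duration.Inf)
      if (rows.length > 1) sys.error(s"more than one row returned by subquery $name")
      name -> rows.headOption // zero rows evaluates to null (None here)
    }.toMap
    results.clear()
    out
  }

  def main(args: Array[String]): Unit = {
    prepare(Map("max_b" -> (() => Array(2)), "empty" -> (() => Array.empty[Int])))
    println(await()) // Map(max_b -> Some(2), empty -> None)
    ec.shutdown()
  }
}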
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index e851eb02f01b34ae8c5d13123c3047c7c6aa0838..21b19fe7df8b2aa12ef74ebf94af904eba79d920 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -20,65 +20,64 @@ package org.apache.spark.sql
 import org.apache.spark.sql.test.SharedSQLContext
 
 class SubquerySuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("simple uncorrelated scalar subquery") {
     assertResult(Array(Row(1))) {
       sql("select (select 1 as b) as b").collect()
     }
 
-    assertResult(Array(Row(1))) {
-      sql("with t2 as (select 1 as b, 2 as c) " +
-        "select a from (select 1 as a union all select 2 as a) t " +
-        "where a = (select max(b) from t2) ").collect()
-    }
-
     assertResult(Array(Row(3))) {
       sql("select (select (select 1) + 1) + 1").collect()
     }
 
-    // more than one columns
-    val error = intercept[AnalysisException] {
-      sql("select (select 1, 2) as b").collect()
-    }
-    assert(error.message contains "Scalar subquery must return only one column, but got 2")
-
-    // more than one rows
-    val error2 = intercept[RuntimeException] {
-      sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
-    }
-    assert(error2.getMessage contains
-      "more than one row returned by a subquery used as an expression")
-
     // string type
     assertResult(Array(Row("s"))) {
       sql("select (select 's' as s) as b").collect()
     }
+  }
 
-    // zero rows
+  test("uncorrelated scalar subquery in CTE") {
+    assertResult(Array(Row(1))) {
+      sql("with t2 as (select 1 as b, 2 as c) " +
+        "select a from (select 1 as a union all select 2 as a) t " +
+        "where a = (select max(b) from t2)").collect()
+    }
+  }
+
+  test("uncorrelated scalar subquery should return null if there are 0 rows") {
     assertResult(Array(Row(null))) {
       sql("select (select 's' as s limit 0) as b").collect()
     }
   }
 
-  test("uncorrelated scalar subquery on testData") {
-    // initialize test Data
-    testData
+  test("runtime error when the number of rows is greater than 1") {
+    val error = intercept[RuntimeException] {
+      sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect()
+    }
+    assert(error.getMessage.contains(
+      "more than one row returned by a subquery used as an expression"))
+  }
+
+  test("uncorrelated scalar subquery on a DataFrame generated query") {
+    val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value")
+    df.registerTempTable("subqueryData")
 
-    assertResult(Array(Row(5))) {
-      sql("select (select key from testData where key > 3 limit 1) + 1").collect()
+    assertResult(Array(Row(4))) {
+      sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect()
     }
 
-    assertResult(Array(Row(-100))) {
-      sql("select -(select max(key) from testData)").collect()
+    assertResult(Array(Row(-3))) {
+      sql("select -(select max(key) from subqueryData)").collect()
     }
 
     assertResult(Array(Row(null))) {
-      sql("select (select value from testData limit 0)").collect()
+      sql("select (select value from subqueryData limit 0)").collect()
    }
 
-    assertResult(Array(Row("99"))) {
-      sql("select (select min(value) from testData" +
" where key = (select max(key) from testData) - 1)").collect() + assertResult(Array(Row("two"))) { + sql("select (select min(value) from subqueryData" + + " where key = (select max(key) from subqueryData) - 1)").collect() } } }