diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b64352a9e0dc27a082106b6e43e4d2a7ef96cb58..64d89f238ca79e56a56317c8b85a3bd3e62e3360 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -63,11 +63,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     false
   }
 
-  /**
-   * Whether the "prepare" method is called.
-   */
-  private val prepareCalled = new AtomicBoolean(false)
-
   /** Overridden make copy also propagates sqlContext to copied plan. */
   override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
     SQLContext.setActive(sqlContext)
@@ -131,7 +126,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * Execute a query after preparing the query and adding query plan information to created RDDs
    * for visualization.
    */
-  private final def executeQuery[T](query: => T): T = {
+  protected final def executeQuery[T](query: => T): T = {
     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
       prepare()
       waitForSubqueries()
@@ -165,7 +160,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /**
    * Blocks the thread until all subqueries finish evaluation and update the results.
    */
-  protected def waitForSubqueries(): Unit = {
+  protected def waitForSubqueries(): Unit = synchronized {
     // fill in the result of subqueries
     subqueryResults.foreach { case (e, futureResult) =>
       val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
@@ -184,14 +179,23 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     subqueryResults.clear()
   }
 
+  /**
+   * Whether the "prepare" method is called.
+   */
+  private var prepared = false
+
   /**
    * Prepare a SparkPlan for execution. It's idempotent.
    */
   final def prepare(): Unit = {
-    if (prepareCalled.compareAndSet(false, true)) {
-      doPrepare()
-      prepareSubqueries()
-      children.foreach(_.prepare())
+    // doPrepare() may depend on its children, we should call prepare() on all the children first.
+    children.foreach(_.prepare())
+    synchronized {
+      if (!prepared) {
+        prepareSubqueries()
+        doPrepare()
+        prepared = true
+      }
     }
   }
 
@@ -202,6 +206,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    *
    * Note: the prepare method has already walked down the tree, so the implementation doesn't need
    * to call children's prepare methods.
+   *
+   * This will only be called once, protected by `this`.
    */
   protected def doPrepare(): Unit = {}
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 23b2eabd0c809c6eb83414a71159bf9810edaead..944962b1c8844b25d467de761372e5a4bda6e50c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -79,10 +79,9 @@ trait CodegenSupport extends SparkPlan {
   /**
    * Returns Java source code to process the rows from input RDD.
    */
-  final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
+  final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
     this.parent = parent
     ctx.freshNamePrefix = variablePrefix
-    waitForSubqueries()
     s"""
        |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
        |${doProduce(ctx)}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 71b6a978529661e5676186ef40e67ff1f19fe70e..c023cc573c672123f4728cac1fef7299b4a8ae8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -48,15 +48,21 @@ case class ScalarSubquery(
   override def toString: String = s"subquery#${exprId.id}"
 
   // the first column in first row from `query`.
-  private var result: Any = null
+  @volatile private var result: Any = null
+  @volatile private var updated: Boolean = false
 
   def updateResult(v: Any): Unit = {
     result = v
+    updated = true
   }
 
-  override def eval(input: InternalRow): Any = result
+  override def eval(input: InternalRow): Any = {
+    require(updated, s"$this has not finished")
+    result
+  }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    require(updated, s"$this has not finished")
     Literal.create(result, dataType).doGenCode(ctx, ev)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index d1824957573a993a3d4bb1fb550d306487c8dc7b..f9bada156bf32b947f4aa85a2b9137bc9992b54f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -123,6 +123,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("SPARK-14791: scalar subquery inside broadcast join") {
+    val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)")
+    val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil
+    (1 to 10).foreach { _ =>
+      checkAnswer(r.join(df, $"c" === $"a"), expected)
+    }
+  }
+
   test("EXISTS predicate subquery") {
     checkAnswer(
       sql("select * from l where exists (select * from r where l.a = r.c)"),