Skip to content
Snippets Groups Projects
Commit 0419d631 authored by Davies Liu's avatar Davies Liu Committed by Davies Liu
Browse files

[SPARK-14791] [SQL] fix race condition between broadcast and subquery

## What changes were proposed in this pull request?

SparkPlan.prepare() could be called from different threads (BroadcastExchange will call it in a thread pool), but it only makes sure that doPrepare() is called once; a second call to prepare() may return before all the children have finished their own prepare(). Some operator may then call doProduce() before prepareSubqueries() completes, so `null` is used as the result of the subquery, which is wrong. This causes TPCDS Q23B to return a wrong answer sometimes.

This PR adds synchronization to prepare(), making sure all the children have finished prepare() before it returns. It also calls prepare() in produce() (similar to execute()).

Added a check to ScalarSubquery to make sure that the subquery has finished before its result is used.

## How was this patch tested?

Manually tested with Q23B; no wrong answers anymore.

Author: Davies Liu <davies@databricks.com>

Closes #12600 from davies/fix_risk.
parent c417cec0
No related branches found
No related tags found
No related merge requests found
...@@ -63,11 +63,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ ...@@ -63,11 +63,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
false false
} }
/**
* Whether the "prepare" method is called.
*/
private val prepareCalled = new AtomicBoolean(false)
/** Overridden make copy also propagates sqlContext to copied plan. */ /** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SQLContext.setActive(sqlContext) SQLContext.setActive(sqlContext)
...@@ -131,7 +126,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ ...@@ -131,7 +126,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* Execute a query after preparing the query and adding query plan information to created RDDs * Execute a query after preparing the query and adding query plan information to created RDDs
* for visualization. * for visualization.
*/ */
private final def executeQuery[T](query: => T): T = { protected final def executeQuery[T](query: => T): T = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) { RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare() prepare()
waitForSubqueries() waitForSubqueries()
...@@ -165,7 +160,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ ...@@ -165,7 +160,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** /**
* Blocks the thread until all subqueries finish evaluation and update the results. * Blocks the thread until all subqueries finish evaluation and update the results.
*/ */
protected def waitForSubqueries(): Unit = { protected def waitForSubqueries(): Unit = synchronized {
// fill in the result of subqueries // fill in the result of subqueries
subqueryResults.foreach { case (e, futureResult) => subqueryResults.foreach { case (e, futureResult) =>
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf) val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
...@@ -184,14 +179,23 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ ...@@ -184,14 +179,23 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
subqueryResults.clear() subqueryResults.clear()
} }
/**
* Whether the "prepare" method is called.
*/
private var prepared = false
/** /**
* Prepare a SparkPlan for execution. It's idempotent. * Prepare a SparkPlan for execution. It's idempotent.
*/ */
final def prepare(): Unit = { final def prepare(): Unit = {
if (prepareCalled.compareAndSet(false, true)) { // doPrepare() may depend on it's children, we should call prepare() on all the children first.
doPrepare() children.foreach(_.prepare())
prepareSubqueries() synchronized {
children.foreach(_.prepare()) if (!prepared) {
prepareSubqueries()
doPrepare()
prepared = true
}
} }
} }
...@@ -202,6 +206,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ ...@@ -202,6 +206,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* *
* Note: the prepare method has already walked down the tree, so the implementation doesn't need * Note: the prepare method has already walked down the tree, so the implementation doesn't need
* to call children's prepare methods. * to call children's prepare methods.
*
* This will only be called once, protected by `this`.
*/ */
protected def doPrepare(): Unit = {} protected def doPrepare(): Unit = {}
......
...@@ -79,10 +79,9 @@ trait CodegenSupport extends SparkPlan { ...@@ -79,10 +79,9 @@ trait CodegenSupport extends SparkPlan {
/** /**
* Returns Java source code to process the rows from input RDD. * Returns Java source code to process the rows from input RDD.
*/ */
final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
this.parent = parent this.parent = parent
ctx.freshNamePrefix = variablePrefix ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
s""" s"""
|/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */ |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
|${doProduce(ctx)} |${doProduce(ctx)}
......
...@@ -48,15 +48,21 @@ case class ScalarSubquery( ...@@ -48,15 +48,21 @@ case class ScalarSubquery(
override def toString: String = s"subquery#${exprId.id}" override def toString: String = s"subquery#${exprId.id}"
// the first column in first row from `query`. // the first column in first row from `query`.
private var result: Any = null @volatile private var result: Any = null
@volatile private var updated: Boolean = false
def updateResult(v: Any): Unit = { def updateResult(v: Any): Unit = {
result = v result = v
updated = true
} }
override def eval(input: InternalRow): Any = result override def eval(input: InternalRow): Any = {
require(updated, s"$this has not finished")
result
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
require(updated, s"$this has not finished")
Literal.create(result, dataType).doGenCode(ctx, ev) Literal.create(result, dataType).doGenCode(ctx, ev)
} }
} }
......
...@@ -123,6 +123,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext { ...@@ -123,6 +123,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
) )
} }
// Regression test for SPARK-14791: a race between broadcast preparation and
// subquery evaluation could let an operator read a ScalarSubquery result before
// the subquery had finished, producing wrong (null-based) answers.
test("SPARK-14791: scalar subquery inside broadcast join") {
// Scalar subquery in the HAVING clause; its result must be available before the
// broadcast join side is code-generated.
val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)")
val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil
// The race is timing-dependent, so run the join repeatedly to give it a chance
// to manifest. NOTE(review): `l` and `r` are presumably suite fixture tables —
// defined outside this view.
(1 to 10).foreach { _ =>
checkAnswer(r.join(df, $"c" === $"a"), expected)
}
}
test("EXISTS predicate subquery") { test("EXISTS predicate subquery") {
checkAnswer( checkAnswer(
sql("select * from l where exists (select * from r where l.a = r.c)"), sql("select * from l where exists (select * from r where l.a = r.c)"),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment