diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 8a6cf53782b91fe2a832c15660f9e7df8c976ea8..a3b098afe57288b0ca9f77e71beacb88e54d8c91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,8 +69,11 @@ trait PredicateHelper { protected def replaceAlias( condition: Expression, aliases: AttributeMap[Expression]): Expression = { - condition.transform { - case a: Attribute => aliases.getOrElse(a, a) + // Use transformUp to prevent infinite recursion when the replacement expression + // redefines the same ExprId, + condition.transformUp { + case a: Attribute => + aliases.getOrElse(a, a) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e342c23d19c85f66a9608991033aa679be302ffe..a0a8c442d1d80e99c76ce56cdcb23f220dd57515 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -528,7 +528,8 @@ object CollapseProject extends Rule[LogicalPlan] { // Substitute any attributes that are produced by the lower projection, so that we safely // eliminate it. // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' - val rewrittenUpper = upper.map(_.transform { + // Use transformUp to prevent infinite recursion. + val rewrittenUpper = upper.map(_.transformUp { case a: Attribute => aliases.getOrElse(a, a) }) // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here. @@ -1761,6 +1762,128 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { newExpression.asInstanceOf[E] } + /** + * Statically evaluate an expression containing zero or more placeholders, given a set + * of bindings for placeholder values. + */ + private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { + val rewrittenExpr = expr transform { + case r @ AttributeReference(_, dataType, _, _) => + bindings(r.exprId) match { + case Some(v) => Literal.create(v, dataType) + case None => Literal.default(NullType) + } + } + Option(rewrittenExpr.eval()) + } + + /** + * Statically evaluate an expression containing one or more aggregates on an empty input. + */ + private def evalAggOnZeroTups(expr: Expression) : Option[Any] = { + // AggregateExpressions are Unevaluable, so we need to replace all aggregates + // in the expression with the value they would return for zero input tuples. + // Also replace attribute refs (for example, for grouping columns) with NULL. + val rewrittenExpr = expr transform { + case a @ AggregateExpression(aggFunc, _, _, resultId) => + aggFunc.defaultResult.getOrElse(Literal.default(NullType)) + + case AttributeReference(_, _, _, _) => Literal.default(NullType) + } + Option(rewrittenExpr.eval()) + } + + /** + * Statically evaluate a scalar subquery on an empty input. + * + * <b>WARNING:</b> This method only covers subqueries that pass the checks under + * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in + * CheckAnalysis become less restrictive, this method will need to change. + */ + private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { + // Inputs to this method will start with a chain of zero or more SubqueryAlias + // and Project operators, followed by an optional Filter, followed by an + // Aggregate. Traverse the operators recursively. + def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = { + lp match { + case SubqueryAlias(_, child) => evalPlan(child) + case Filter(condition, child) => + val bindings = evalPlan(child) + if (bindings.isEmpty) bindings + else { + val exprResult = evalExpr(condition, bindings).getOrElse(false) + .asInstanceOf[Boolean] + if (exprResult) bindings else Map.empty + } + + case Project(projectList, child) => + val bindings = evalPlan(child) + if (bindings.isEmpty) { + bindings + } else { + projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap + } + + case Aggregate(_, aggExprs, _) => + // Some of the expressions under the Aggregate node are the join columns + // for joining with the outer query block. Fill those expressions in with + // nulls and statically evaluate the remainder. + aggExprs.map(ne => ne match { + case AttributeReference(_, _, _, _) => (ne.exprId, None) + case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None) + case _ => (ne.exprId, evalAggOnZeroTups(ne)) + }).toMap + + case _ => sys.error(s"Unexpected operator in scalar subquery: $lp") + } + } + + val resultMap = evalPlan(plan) + + // By convention, the scalar subquery result is the leftmost field. + resultMap(plan.output.head.exprId) + } + + /** + * Split the plan for a scalar subquery into the parts above the innermost query block + * (first part of returned value), the HAVING clause of the innermost query block + * (optional second part) and the parts below the HAVING CLAUSE (third part). + */ + private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = { + val topPart = ArrayBuffer.empty[LogicalPlan] + var bottomPart : LogicalPlan = plan + while (true) { + bottomPart match { + case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) => + return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate]) + + case aggPart@Aggregate(_, _, _) => + // No HAVING clause + return (topPart, None, aggPart) + + case p@Project(_, child) => + topPart += p + bottomPart = child + + case s@SubqueryAlias(_, child) => + topPart += s + bottomPart = child + + case Filter(_, op@_) => + sys.error(s"Correlated subquery has unexpected operator $op below filter") + + case op@_ => sys.error(s"Unexpected operator $op in correlated subquery") + } + } + + sys.error("This line should be unreachable") + } + + + + // Name of generated column used in rewrite below + val ALWAYS_TRUE_COLNAME = "alwaysTrue" + /** * Construct a new child plan by left joining the given subqueries to a base plan. */ @@ -1769,9 +1892,76 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { subqueries.foldLeft(child) { case (currentChild, ScalarSubquery(query, conditions, _)) => - Project( - currentChild.output :+ query.output.head, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + val origOutput = query.output.head + + val resultWithZeroTups = evalSubqueryOnZeroTups(query) + if (resultWithZeroTups.isEmpty) { + // CASE 1: Subquery guaranteed not to have the COUNT bug + Project( + currentChild.output :+ origOutput, + Join(currentChild, query, LeftOuter, conditions.reduceOption(And))) + } else { + // Subquery might have the COUNT bug. Add appropriate corrections. + val (topPart, havingNode, aggNode) = splitSubquery(query) + + // The next two cases add a leading column to the outer join input to make it + // possible to distinguish between the case when no tuples join and the case + // when the tuple that joins contains null values. + // The leading column always has the value TRUE. + val alwaysTrueExprId = NamedExpression.newExprId + val alwaysTrueExpr = Alias(Literal.TrueLiteral, + ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId) + val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME, + BooleanType)(exprId = alwaysTrueExprId) + + val aggValRef = query.output.head + + if (!havingNode.isDefined) { + // CASE 2: Subquery with no HAVING clause + Project( + currentChild.output :+ + Alias( + If(IsNull(alwaysTrueRef), + Literal(resultWithZeroTups.get, origOutput.dataType), + aggValRef), origOutput.name)(exprId = origOutput.exprId), + Join(currentChild, + Project(query.output :+ alwaysTrueExpr, query), + LeftOuter, conditions.reduceOption(And))) + + } else { + // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. + // Need to modify any operators below the join to pass through all columns + // referenced in the HAVING clause. + var subqueryRoot : UnaryNode = aggNode + val havingInputs : Seq[NamedExpression] = aggNode.output + + topPart.reverse.foreach( + _ match { + case Project(projList, _) => + subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) + case s@SubqueryAlias(alias, _) => subqueryRoot = SubqueryAlias(alias, subqueryRoot) + case op@_ => sys.error(s"Unexpected operator $op in corelated subquery") + } + ) + + // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups + // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>) + // ELSE (aggregate value) END AS (original column name) + val caseExpr = Alias(CaseWhen( + Seq[(Expression, Expression)] ( + (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, origOutput.dataType)), + (Not(havingNode.get.condition), Literal(null, aggValRef.dataType)) + ), aggValRef + ), origOutput.name) (exprId = origOutput.exprId) + + Project( + currentChild.output :+ caseExpr, + Join(currentChild, + Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), + LeftOuter, conditions.reduceOption(And))) + + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 05491a4a888d474ee15b924e7d5c8f708fa4192d..06ced999740eb3041cadac74070b8b6f457c5916 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -324,4 +324,85 @@ class SubquerySuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(3) :: Nil) } + + test("SPARK-15370: COUNT bug in WHERE clause (Filter)") { + // Case 1: Canonical example of the COUNT bug + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"), + Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) + // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses + // a rewrite that is vulnerable to the COUNT bug + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + // Case 3: COUNT bug without a COUNT aggregate + checkAnswer( + sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"), + Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) + } + + test("SPARK-15370: COUNT bug in SELECT clause (Project)") { + checkAnswer( + sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"), + Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) + :: Row(null, 0) :: Row(6, 1) :: Nil) + } + + test("SPARK-15370: COUNT bug in HAVING clause (Filter)") { + checkAnswer( + sql("select l.a as grp_a from l group by l.a " + + "having (select count(*) from r where grp_a = r.c) = 0 " + + "order by grp_a"), + Row(null) :: Row(1) :: Nil) + } + + test("SPARK-15370: COUNT bug in Aggregate") { + checkAnswer( + sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " + + "from l group by l.a order by aval"), + Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) + } + + test("SPARK-15370: COUNT bug negative examples") { + // Case 1: Potential COUNT bug case that was working correctly prior to the fix + checkAnswer( + sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) + // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. + checkAnswer( + sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"), + Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) + // Case 3: COUNT inside aggregate expression but no COUNT bug. + checkAnswer( + sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"), + Nil) + } + + test("SPARK-15370: COUNT bug in subquery in subquery in subquery") { + checkAnswer( + sql("""select l.a from l + |where ( + | select cntPlusOne + 1 as cntPlusTwo from ( + | select cnt + 1 as cntPlusOne from ( + | select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0 + | ) + | ) + |) = 2""".stripMargin), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-15370: COUNT bug with nasty predicate expr") { + checkAnswer( + sql("select l.a from l where " + + "(select case when count(*) = 1 then null else count(*) end as cnt " + + "from r where l.a = r.c) = 0"), + Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) + } + + test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") { + checkAnswer( + sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"), + Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: + Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) + } }