From 9770f6ee60f6834e4e1200234109120427a5cc0d Mon Sep 17 00:00:00 2001
From: frreiss <frreiss@us.ibm.com>
Date: Sun, 12 Jun 2016 14:21:10 -0700
Subject: [PATCH] [SPARK-15370][SQL] Update RewriteCorrelatedScalarSubquery
 rule to fix COUNT bug

## What changes were proposed in this pull request?
This pull request fixes the COUNT bug in the `RewriteCorrelatedScalarSubquery` rule.

After this change, the rule tests the expression at the root of the correlated subquery to determine whether the expression returns NULL on empty input. If the expression does not return NULL, the rule generates additional logic in the Project operator above the rewritten subquery.  This additional logic intercepts NULL values coming from the outer join and replaces them with the value that the subquery's expression would return on empty input.

## How was this patch tested?
Added regression tests to cover all branches of the updated rule (see changes to `SubquerySuite.scala`).
Ran all existing automated regression tests after merging with latest trunk.

Author: frreiss <frreiss@us.ibm.com>

Closes #13155 from frreiss/master.
---
 .../sql/catalyst/expressions/predicates.scala |   7 +-
 .../sql/catalyst/optimizer/Optimizer.scala    | 198 +++++++++++++++++-
 .../org/apache/spark/sql/SubquerySuite.scala  |  81 +++++++
 3 files changed, 280 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 8a6cf53782..a3b098afe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -69,8 +69,11 @@ trait PredicateHelper {
   protected def replaceAlias(
       condition: Expression,
       aliases: AttributeMap[Expression]): Expression = {
-    condition.transform {
-      case a: Attribute => aliases.getOrElse(a, a)
+    // Use transformUp to prevent infinite recursion when the replacement expression
+    // redefines the same ExprId,
+    condition.transformUp {
+      case a: Attribute =>
+        aliases.getOrElse(a, a)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e342c23d19..a0a8c442d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -528,7 +528,8 @@ object CollapseProject extends Rule[LogicalPlan] {
     // Substitute any attributes that are produced by the lower projection, so that we safely
     // eliminate it.
     // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
-    val rewrittenUpper = upper.map(_.transform {
+    // Use transformUp to prevent infinite recursion.
+    val rewrittenUpper = upper.map(_.transformUp {
       case a: Attribute => aliases.getOrElse(a, a)
     })
     // collapse upper and lower Projects may introduce unnecessary Aliases, trim them here.
@@ -1761,6 +1762,128 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
     newExpression.asInstanceOf[E]
   }
 
+  /**
+   * Statically evaluate an expression containing zero or more placeholders, given a set
+   * of bindings for placeholder values.
+   */
+  private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
+    val rewrittenExpr = expr transform {
+      case r @ AttributeReference(_, dataType, _, _) =>
+        bindings(r.exprId) match {
+          case Some(v) => Literal.create(v, dataType)
+          case None => Literal.default(NullType)
+        }
+    }
+    Option(rewrittenExpr.eval())
+  }
+
+  /**
+   * Statically evaluate an expression containing one or more aggregates on an empty input.
+   */
+  private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
+    // AggregateExpressions are Unevaluable, so we need to replace all aggregates
+    // in the expression with the value they would return for zero input tuples.
+    // Also replace attribute refs (for example, for grouping columns) with NULL.
+    val rewrittenExpr = expr transform {
+      case a @ AggregateExpression(aggFunc, _, _, resultId) =>
+        aggFunc.defaultResult.getOrElse(Literal.default(NullType))
+
+      case AttributeReference(_, _, _, _) => Literal.default(NullType)
+    }
+    Option(rewrittenExpr.eval())
+  }
+
+  /**
+   * Statically evaluate a scalar subquery on an empty input.
+   *
+   * <b>WARNING:</b> This method only covers subqueries that pass the checks under
+   * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
+   * CheckAnalysis become less restrictive, this method will need to change.
+   */
+  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
+    // Inputs to this method will start with a chain of zero or more SubqueryAlias
+    // and Project operators, followed by an optional Filter, followed by an
+    // Aggregate. Traverse the operators recursively.
+    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = {
+      lp match {
+        case SubqueryAlias(_, child) => evalPlan(child)
+        case Filter(condition, child) =>
+          val bindings = evalPlan(child)
+          if (bindings.isEmpty) bindings
+          else {
+            val exprResult = evalExpr(condition, bindings).getOrElse(false)
+              .asInstanceOf[Boolean]
+            if (exprResult) bindings else Map.empty
+          }
+
+        case Project(projectList, child) =>
+          val bindings = evalPlan(child)
+          if (bindings.isEmpty) {
+            bindings
+          } else {
+            projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
+          }
+
+        case Aggregate(_, aggExprs, _) =>
+          // Some of the expressions under the Aggregate node are the join columns
+          // for joining with the outer query block. Fill those expressions in with
+          // nulls and statically evaluate the remainder.
+          aggExprs.map(ne => ne match {
+            case AttributeReference(_, _, _, _) => (ne.exprId, None)
+            case Alias(AttributeReference(_, _, _, _), _) => (ne.exprId, None)
+            case _ => (ne.exprId, evalAggOnZeroTups(ne))
+          }).toMap
+
+        case _ => sys.error(s"Unexpected operator in scalar subquery: $lp")
+      }
+    }
+
+    val resultMap = evalPlan(plan)
+
+    // By convention, the scalar subquery result is the leftmost field.
+    resultMap(plan.output.head.exprId)
+  }
+
+  /**
+   * Split the plan for a scalar subquery into the parts above the innermost query block
+   * (first part of returned value), the HAVING clause of the innermost query block
+   * (optional second part) and the parts below the HAVING CLAUSE (third part).
+   */
+  private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
+    val topPart = ArrayBuffer.empty[LogicalPlan]
+    var bottomPart : LogicalPlan = plan
+    while (true) {
+      bottomPart match {
+        case havingPart@Filter(_, aggPart@Aggregate(_, _, _)) =>
+          return (topPart, Option(havingPart), aggPart.asInstanceOf[Aggregate])
+
+        case aggPart@Aggregate(_, _, _) =>
+          // No HAVING clause
+          return (topPart, None, aggPart)
+
+        case p@Project(_, child) =>
+          topPart += p
+          bottomPart = child
+
+        case s@SubqueryAlias(_, child) =>
+          topPart += s
+          bottomPart = child
+
+        case Filter(_, op@_) =>
+          sys.error(s"Correlated subquery has unexpected operator $op below filter")
+
+        case op@_ => sys.error(s"Unexpected operator $op in correlated subquery")
+      }
+    }
+
+    sys.error("This line should be unreachable")
+  }
+
+
+
+  // Name of generated column used in rewrite below
+  val ALWAYS_TRUE_COLNAME = "alwaysTrue"
+
   /**
    * Construct a new child plan by left joining the given subqueries to a base plan.
    */
@@ -1769,9 +1892,76 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
     subqueries.foldLeft(child) {
       case (currentChild, ScalarSubquery(query, conditions, _)) =>
-        Project(
-          currentChild.output :+ query.output.head,
-          Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+        val origOutput = query.output.head
+
+        val resultWithZeroTups = evalSubqueryOnZeroTups(query)
+        if (resultWithZeroTups.isEmpty) {
+          // CASE 1: Subquery guaranteed not to have the COUNT bug
+          Project(
+            currentChild.output :+ origOutput,
+            Join(currentChild, query, LeftOuter, conditions.reduceOption(And)))
+        } else {
+          // Subquery might have the COUNT bug. Add appropriate corrections.
+          val (topPart, havingNode, aggNode) = splitSubquery(query)
+
+          // The next two cases add a leading column to the outer join input to make it
+          // possible to distinguish between the case when no tuples join and the case
+          // when the tuple that joins contains null values.
+          // The leading column always has the value TRUE.
+          val alwaysTrueExprId = NamedExpression.newExprId
+          val alwaysTrueExpr = Alias(Literal.TrueLiteral,
+            ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
+          val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
+            BooleanType)(exprId = alwaysTrueExprId)
+
+          val aggValRef = query.output.head
+
+          if (!havingNode.isDefined) {
+            // CASE 2: Subquery with no HAVING clause
+            Project(
+              currentChild.output :+
+                Alias(
+                  If(IsNull(alwaysTrueRef),
+                    Literal(resultWithZeroTups.get, origOutput.dataType),
+                    aggValRef), origOutput.name)(exprId = origOutput.exprId),
+              Join(currentChild,
+                Project(query.output :+ alwaysTrueExpr, query),
+                LeftOuter, conditions.reduceOption(And)))
+
+          } else {
+            // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
+            // Need to modify any operators below the join to pass through all columns
+            // referenced in the HAVING clause.
+            var subqueryRoot : UnaryNode = aggNode
+            val havingInputs : Seq[NamedExpression] = aggNode.output
+
+            topPart.reverse.foreach(
+              _ match {
+                case Project(projList, _) =>
+                  subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
+                case s@SubqueryAlias(alias, _) => subqueryRoot = SubqueryAlias(alias, subqueryRoot)
+                case op@_ => sys.error(s"Unexpected operator $op in corelated subquery")
+              }
+            )
+
+            // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
+            //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
+            //      ELSE (aggregate value) END AS (original column name)
+            val caseExpr = Alias(CaseWhen(
+              Seq[(Expression, Expression)] (
+                (IsNull(alwaysTrueRef), Literal(resultWithZeroTups.get, origOutput.dataType)),
+                (Not(havingNode.get.condition), Literal(null, aggValRef.dataType))
+              ), aggValRef
+            ), origOutput.name) (exprId = origOutput.exprId)
+
+            Project(
+              currentChild.output :+ caseExpr,
+              Join(currentChild,
+                Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
+                LeftOuter, conditions.reduceOption(And)))
+
+          }
+        }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 05491a4a88..06ced99974 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -324,4 +324,85 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
         """.stripMargin),
       Row(3) :: Nil)
   }
+
+  test("SPARK-15370: COUNT bug in WHERE clause (Filter)") {
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("select l.a from l where (select count(*) from r where l.a = r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("select l.a from l where (select count(*) from r where l.a = r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("select l.a from l where (select sum(r.d) is null from r where l.a = r.c)"),
+      Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-15370: COUNT bug in SELECT clause (Project)") {
+    checkAnswer(
+      sql("select a, (select count(*) from r where l.a = r.c) as cnt from l"),
+      Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
+        :: Row(null, 0) :: Row(6, 1) :: Nil)
+  }
+
+  test("SPARK-15370: COUNT bug in HAVING clause (Filter)") {
+    checkAnswer(
+      sql("select l.a as grp_a from l group by l.a " +
+        "having (select count(*) from r where grp_a = r.c) = 0 " +
+        "order by grp_a"),
+      Row(null) :: Row(1) :: Nil)
+  }
+
+  test("SPARK-15370: COUNT bug in Aggregate") {
+    checkAnswer(
+      sql("select l.a as aval, sum((select count(*) from r where l.a = r.c)) as cnt " +
+        "from l group by l.a order by aval"),
+      Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1)  :: Nil)
+  }
+
+  test("SPARK-15370: COUNT bug negative examples") {
+    // Case 1: Potential COUNT bug case that was working correctly prior to the fix
+    checkAnswer(
+      sql("select l.a from l where (select sum(r.d) from r where l.a = r.c) is null"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+    // Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
+    checkAnswer(
+      sql("select l.a from l where (select count(*) from r where l.a = r.c) > 0"),
+      Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
+    // Case 3: COUNT inside aggregate expression but no COUNT bug.
+    checkAnswer(
+      sql("select l.a from l where (select count(*) + sum(r.d) from r where l.a = r.c) = 0"),
+      Nil)
+  }
+
+  test("SPARK-15370: COUNT bug in subquery in subquery in subquery") {
+    checkAnswer(
+      sql("""select l.a from l
+            |where (
+            |    select cntPlusOne + 1 as cntPlusTwo from (
+            |        select cnt + 1 as cntPlusOne from (
+            |            select sum(r.c) s, count(*) cnt from r where l.a = r.c having cnt = 0
+            |        )
+            |    )
+            |) = 2""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-15370: COUNT bug with nasty predicate expr") {
+    checkAnswer(
+      sql("select l.a from l where " +
+        "(select case when count(*) = 1 then null else count(*) end as cnt " +
+        "from r where l.a = r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-15370: COUNT bug with attribute ref in subquery input and output ") {
+    checkAnswer(
+      sql("select l.b, (select (r.c + count(*)) is null from r where l.a = r.c) from l"),
+      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
+        Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
+  }
 }
-- 
GitLab