From afd0debe075e9ea8466e384932a513ef0188273c Mon Sep 17 00:00:00 2001
From: Sameer Agarwal <sameer@databricks.com>
Date: Fri, 25 Mar 2016 12:57:26 -0700
Subject: [PATCH] [SPARK-14137] [SPARK-14150] [SQL] Infer IsNotNull constraints
 from non-nullable attributes

## What changes were proposed in this pull request?

This PR adds support for automatically inferring `IsNotNull` constraints from any non-nullable attributes that are part of an operator's output. This also fixes the issue that causes the optimizer to hit the maximum number of iterations for certain queries in https://github.com/apache/spark/pull/11828.

## How was this patch tested?

Unit test in `ConstraintPropagationSuite`. This patch also re-enables the `union20` test in `HiveCompatibilitySuite`.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #11953 from sameeragarwal/infer-isnotnull.
---
 .../spark/sql/catalyst/plans/QueryPlan.scala  | 38 +++++++++++--------
 .../plans/ConstraintPropagationSuite.scala    |  9 +++++
 .../execution/HiveCompatibilitySuite.scala    |  4 +-
 3 files changed, 33 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index e9bfa09b7d..d31164fe94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -39,29 +39,37 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   }
 
   /**
-   * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions.
-   * For e.g., if an expression is of the form (`a > 5`), this returns a constraint of the form
-   * `isNotNull(a)`
+   * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions as
+   * well as non-nullable attributes. For example, if an expression is of the form (`a > 5`), this
+   * returns a constraint of the form `isNotNull(a)`.
    */
   private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
-    // Currently we only propagate constraints if the condition consists of equality
-    // and ranges. For all other cases, we return an empty set of constraints
-    constraints.map {
+    var isNotNullConstraints = Set.empty[Expression]
+
+    // First, we infer constraints from each condition that is an equality, inequality or range
+    // comparison. For all other conditions, no constraints are inferred
+    constraints.foreach {
       case EqualTo(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
+        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
       case GreaterThan(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
+        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
       case GreaterThanOrEqual(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
+        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
       case LessThan(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
+        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
       case LessThanOrEqual(l, r) =>
-        Set(IsNotNull(l), IsNotNull(r))
+        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
       case Not(EqualTo(l, r)) =>
-        Set(IsNotNull(l), IsNotNull(r))
-      case _ =>
-        Set.empty[Expression]
-    }.foldLeft(Set.empty[Expression])(_ union _.toSet)
+        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
+      case _ => // No inference
+    }
+
+    // Second, we infer additional constraints from non-nullable attributes that are part of the
+    // operator's output
+    val nonNullableAttributes = output.filterNot(_.nullable)
+    isNotNullConstraints ++= nonNullableAttributes.map(IsNotNull).toSet
+
+    isNotNullConstraints -- constraints
   }
 
   /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index f3ab026192..e5063599a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.{IntegerType, StringType}
 
 class ConstraintPropagationSuite extends SparkFunSuite {
 
@@ -217,4 +218,12 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         IsNotNull(resolveColumn(tr, "a")),
         IsNotNull(resolveColumn(tr, "b")))))
   }
+
+  test("infer IsNotNull constraints from non-nullable attributes") {
+    val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
+      AttributeReference("c", StringType, nullable = false)())
+
+    verifyConstraints(tr.analyze.constraints,
+      ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c")))))
+  }
 }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 8bd731dda2..650797f768 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -341,9 +341,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "udf_round_3",
     "view_cast",
 
-    // enable this after fixing SPARK-14137
-    "union20",
-
     // These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
     // generates different View Expanded Text.
     "alter_view_as_select",
@@ -1046,6 +1043,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "union18",
     "union19",
     "union2",
+    "union20",
     "union22",
     "union23",
     "union24",
-- 
GitLab