diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 45ee2964d4db0c0c098f6eabb84d115f39d3324a..b108017c4c482720419c91a8ddba6f1b447769b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -40,14 +40,13 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT } /** - * Infers a set of `isNotNull` constraints from a given set of equality/comparison expressions as - * well as non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this + * Infers a set of `isNotNull` constraints from null intolerant expressions as well as + * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this * returns a constraint of the form `isNotNull(a)` */ private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { // First, we propagate constraints from the null intolerant expressions. - var isNotNullConstraints: Set[Expression] = - constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_)) + var isNotNullConstraints: Set[Expression] = constraints.flatMap(inferIsNotNullConstraints) // Second, we infer additional constraints from non-nullable attributes that are part of the // operator's output @@ -57,14 +56,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT isNotNullConstraints -- constraints } + /** + * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions + * of constraints. + */ + private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = + constraint match { + // When the root is IsNotNull, we can push IsNotNull through the child null intolerant + // expressions + case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) + // Constraints always return true for all the inputs. That means, null will never be returned. + // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child + // null intolerant expressions. + case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) + } + /** * Recursively explores the expressions which are null intolerant and returns all attributes * in these expressions. */ - private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match { + private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { case a: Attribute => Seq(a) - case _: NullIntolerant | IsNotNull(_: NullIntolerant) => - expr.children.flatMap(scanNullIntolerantExpr) + case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 8068ce922e63631633eeb6844a287437d88691e8..a191aa8fee70278e83f504504efa0183f4fd83ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -351,6 +351,15 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(IsNotNull(resolveColumn(tr, "b"))), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "c"))))) + + verifyConstraints( + tr.where('a.attr === 1 && IsNotNull(resolveColumn(tr, "b")) && + IsNotNull(resolveColumn(tr, "c"))).analyze.constraints, + ExpressionSet(Seq( + resolveColumn(tr, "a") === 1, + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b"))))) } test("infer IsNotNull constraints from non-nullable attributes") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f5bc8785d5a2ceccdf4694b2eeeca51e406d4a70..312cd17c26d609b7392d8b05908ff68039166c6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1697,6 +1697,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { expr = "cast((_1 + _2) as boolean)", expectedNonNullableColumns = Seq("_1", "_2")) } + test("SPARK-17897: Fixed IsNotNull Constraint Inference Rule") { + val data = Seq[java.lang.Integer](1, null).toDF("key") + checkAnswer(data.filter(!$"key".isNotNull), Row(null)) + checkAnswer(data.filter(!(- $"key").isNotNull), Row(null)) + } + test("SPARK-17957: outer join + na.fill") { val df1 = Seq((1, 2), (2, 3)).toDF("a", "b") val df2 = Seq((2, 5), (3, 4)).toDF("a", "c")