From c3a6269ca994a977303a450043a577f435565f4e Mon Sep 17 00:00:00 2001 From: Sameer Agarwal <sameer@databricks.com> Date: Thu, 10 Mar 2016 17:29:45 -0800 Subject: [PATCH] [SPARK-13789] Infer additional constraints from attribute equality ## What changes were proposed in this pull request? This PR adds support for inferring an additional set of data constraints based on attribute equality. For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), we can now automatically infer an additional constraint of the form `b = 5` ## How was this patch tested? Tested that new constraints are properly inferred for filters (by adding a new test) and equi-joins (by modifying an existing test) Author: Sameer Agarwal <sameer@databricks.com> Closes #11618 from sameeragarwal/infer-isequal-constraints. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 21 +++++++++++++++++++ .../plans/ConstraintPropagationSuite.scala | 14 +++++++++++++ 2 files changed, 35 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 40c06ed6d4..c222571a34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -32,6 +32,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def getRelevantConstraints(constraints: Set[Expression]): Set[Expression] = { constraints + .union(inferAdditionalConstraints(constraints)) .union(constructIsNotNullConstraints(constraints)) .filter(constraint => constraint.references.nonEmpty && constraint.references.subsetOf(outputSet)) @@ -63,6 +64,26 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }.foldLeft(Set.empty[Expression])(_ union _.toSet) } + /** + * Infers an additional set of constraints from a given set of equality constraints. + * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an + * additional constraint of the form `b = 5` + */ + private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { + var inferredConstraints = Set.empty[Expression] + constraints.foreach { + case eq @ EqualTo(l: Attribute, r: Attribute) => + inferredConstraints ++= (constraints - eq).map(_ transform { + case a: Attribute if a.semanticEquals(l) => r + }) + inferredConstraints ++= (constraints - eq).map(_ transform { + case a: Attribute if a.semanticEquals(r) => l + }) + case _ => // No inference + } + inferredConstraints -- constraints + } + /** * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For * example, if this set contains the expression `a = 2` then that expression is guaranteed to diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index e70d3794ab..a9375a740d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -158,6 +158,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, tr1.resolveQuoted("a", caseInsensitiveResolution).get === tr2.resolveQuoted("a", caseInsensitiveResolution).get, + tr2.resolveQuoted("a", caseInsensitiveResolution).get > 10, IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) @@ -203,4 +204,17 @@ class ConstraintPropagationSuite extends SparkFunSuite { .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints.isEmpty) } + + test("infer additional constraints in filters") { + val tr = LocalRelation('a.int, 'b.int, 'c.int) + + verifyConstraints(tr + .where('a.attr > 10 && 'a.attr === 'b.attr) + .analyze.constraints, + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, + resolveColumn(tr, "b") > 10, + resolveColumn(tr, "a") === resolveColumn(tr, "b"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b"))))) + } } -- GitLab