Skip to content
Snippets Groups Projects
Commit 85e68b4b authored by Davies Liu's avatar Davies Liu Committed by Davies Liu
Browse files

[SPARK-14562] [SQL] improve constraints propagation in Union

## What changes were proposed in this pull request?

Currently, Union only takes the intersection of the constraints from its children and drops all the others; we should try to merge them together instead.

This PR tries to merge the constraints that have the same reference but come from different children; for example, `a > 10` and `a < 100` can be merged as `a > 10 || a < 100`.

## How was this patch tested?

Added more cases in existing test.

Author: Davies Liu <davies@databricks.com>

Closes #12328 from davies/union_const.
parent 852bbc6c
No related branches found
No related tags found
No related merge requests found
...@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { ...@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
}) })
} }
/**
 * Combines the constraint sets of two children into one set.
 *
 * Constraints that appear in both `a` and `b` are kept unchanged. Of the remaining
 * constraints, only those that reference exactly one attribute are considered: for every
 * attribute that has such constraints on both sides, the conjunction of each side's
 * constraints is joined with `Or`. All other constraints are dropped.
 *
 * @param a constraints from the first side
 * @param b constraints from the second side
 * @return the shared constraints plus one `Or` per attribute constrained on both sides
 */
private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
  val shared = a.intersect(b)

  // A constraint with a single reference can easily be treated as a predicate on that
  // attribute, so index the side-specific constraints by their one referenced attribute.
  def singleRefByAttr(side: Set[Expression]) =
    side.diff(shared).filter(_.references.size == 1).groupBy(_.references.head)

  val leftByAttr = singleRefByAttr(a)
  val rightByAttr = singleRefByAttr(b)

  // Loosen the constraints: (A1 && B1) || (A2 && B2)  ->  (A1 || A2) && (B1 || B2),
  // where A* and B* are the per-attribute conjunctions from each side.
  val loosened = for {
    attr <- leftByAttr.keySet intersect rightByAttr.keySet
  } yield Or(leftByAttr(attr).reduceLeft(And), rightByAttr(attr).reduceLeft(And))

  shared ++ loosened
}
override protected def validConstraints: Set[Expression] = { override protected def validConstraints: Set[Expression] = {
children children
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) .map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
.reduce(_ intersect _) .reduce(merge(_, _))
} }
} }
......
...@@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite { ...@@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.analyze.constraints, .analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))) IsNotNull(resolveColumn(tr1, "a")))))
val a = resolveColumn(tr1, "a")
verifyConstraints(tr1
.where('a.attr > 10)
.union(tr2.where('d.attr > 11))
.analyze.constraints,
ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
val b = resolveColumn(tr1, "b")
verifyConstraints(tr1
.where('a.attr > 10 && 'b.attr < 10)
.union(tr2.where('d.attr > 11 && 'e.attr < 11))
.analyze.constraints,
ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
} }
test("propagating constraints in intersect") { test("propagating constraints in intersect") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment