diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index ae1f6006135bbee03688e3c89b7b118eeef2fe01..07ba7d5e4a8493607cfd5c51a434b06015072a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -71,6 +71,15 @@ object Canonicalize extends { case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + case Not(GreaterThan(l, r)) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + case Not(GreaterThan(l, r)) => LessThanOrEqual(l, r) + case Not(LessThan(l, r)) if l.hashCode() > r.hashCode() => LessThan(r, l) + case Not(LessThan(l, r)) => GreaterThanOrEqual(l, r) + case Not(GreaterThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) + case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) + case Not(LessThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) + case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) + case _ => e } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index 0b350c6a9825570c10a1e60479d7954f3f30b879..60939ee0eda5d10e215133661d221b64930b4fcf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -74,6 +74,12 @@ class ExpressionSetSuite extends SparkFunSuite { setTest(1, aUpper > bUpper, bUpper < aUpper) setTest(1, aUpper >= bUpper, bUpper <= aUpper) + // `Not` canonicalization + setTest(1, Not(aUpper > 1), aUpper <= 1, Not(Literal(1) < aUpper), Literal(1) >= aUpper) + setTest(1, Not(aUpper < 1), aUpper >= 1, Not(Literal(1) > aUpper), Literal(1) <= aUpper) + setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper) + setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper) + test("add to / remove from set") { val initialSet = ExpressionSet(aUpper + 1 :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 47b79fe4624572ea8110ff10e8cc671c63df6c99..2ab31eea8ab3865af7dddc23ab6181c78ff7f45f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -99,6 +99,34 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition(('b || !'a ) && 'a, 'b && 'a) } + test("a < 1 && (!(a < 1) || b)") { + checkCondition('a < 1 && (!('a < 1) || 'b), ('a < 1) && 'b) + checkCondition('a < 1 && ('b || !('a < 1)), ('a < 1) && 'b) + + checkCondition('a <= 1 && (!('a <= 1) || 'b), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('b || !('a <= 1)), ('a <= 1) && 'b) + + checkCondition('a > 1 && (!('a > 1) || 'b), ('a > 1) && 'b) + checkCondition('a > 1 && ('b || !('a > 1)), ('a > 1) && 'b) + + checkCondition('a >= 1 && (!('a >= 1) || 'b), ('a >= 1) && 'b) + checkCondition('a >= 1 && ('b || !('a >= 1)), ('a >= 1) && 'b) + } + + test("a < 1 && ((a >= 1) || b)") { + checkCondition('a < 1 && ('a >= 1 || 'b ), ('a < 1) && 'b) + checkCondition('a < 1 && ('b || 'a >= 1), ('a < 1) && 'b) + + checkCondition('a <= 1 && ('a > 1 || 'b ), ('a <= 1) && 'b) + checkCondition('a <= 1 && ('b || 'a > 1), ('a <= 1) && 'b) + + checkCondition('a > 1 && (('a <= 1) || 'b), ('a > 1) && 'b) + checkCondition('a > 1 && ('b || ('a <= 1)), ('a > 1) && 'b) + + checkCondition('a >= 1 && (('a < 1) || 'b), ('a >= 1) && 'b) + checkCondition('a >= 1 && ('b || ('a < 1)), ('a >= 1) && 'b) + } + test("DeMorgan's law") { checkCondition(!('a && 'b), !'a || !'b)