diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index e876450c73fde0c77fc7671e326b30dcf8ae69fe..65e497afc12cd9adce35d4fc69abec3879ffd521 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -37,7 +37,7 @@ object Canonicalize extends { } /** Remove names and nullability from types. */ - private def ignoreNamesTypes(e: Expression): Expression = e match { + private[expressions] def ignoreNamesTypes(e: Expression): Expression = e match { case a: AttributeReference => AttributeReference("none", a.dataType.asNullable)(exprId = a.exprId) case _ => e @@ -78,13 +78,11 @@ object Canonicalize extends { case GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) case LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) - case Not(GreaterThan(l, r)) if l.hashCode() > r.hashCode() => GreaterThan(r, l) + // Note in the following `NOT` cases, `l.hashCode() <= r.hashCode()` holds. The reason is that + // canonicalization is conducted bottom-up -- see [[Expression.canonicalized]]. 
case Not(GreaterThan(l, r)) => LessThanOrEqual(l, r) - case Not(LessThan(l, r)) if l.hashCode() > r.hashCode() => LessThan(r, l) case Not(LessThan(l, r)) => GreaterThanOrEqual(l, r) - case Not(GreaterThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => GreaterThanOrEqual(r, l) case Not(GreaterThanOrEqual(l, r)) => LessThan(l, r) - case Not(LessThanOrEqual(l, r)) if l.hashCode() > r.hashCode() => LessThanOrEqual(r, l) case Not(LessThanOrEqual(l, r)) => GreaterThan(l, r) case _ => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index c587d4f6325313c8335e6fa3ea438702a4f0b6f2..d617ad540d5ff3ad7b88f47645c9daecb9ba8d7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -32,6 +32,38 @@ class ExpressionSetSuite extends SparkFunSuite { val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil) + // An [AttributeReference] with the maximum hashcode, to make testing canonicalize rules + // like `case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)` easier + val maxHash = + Canonicalize.ignoreNamesTypes( + AttributeReference("maxHash", IntegerType)(exprId = + new ExprId(4, NamedExpression.jvmId) { + // maxHash's hashcode is calculated based on this exprId's hashcode, so we set this + // exprId's hashCode to this specific value to make sure maxHash's hashcode is + // `Int.MaxValue` + override def hashCode: Int = -1030353449 + // We are implementing this equals() only because the style-checking rule "you should + // implement equals and hashCode together" requires us to + override def equals(obj: Any): Boolean = super.equals(obj) + })).asInstanceOf[AttributeReference] + assert(maxHash.hashCode() == Int.MaxValue) + + // An [AttributeReference] 
with the minimum hashcode, to make testing canonicalize rules + // like `case GreaterThan(l, r) if l.hashCode() > r.hashCode() => LessThan(r, l)` easier + val minHash = + Canonicalize.ignoreNamesTypes( + AttributeReference("minHash", IntegerType)(exprId = + new ExprId(5, NamedExpression.jvmId) { + // minHash's hashcode is calculated based on this exprId's hashcode, so we set this + // exprId's hashCode to this specific value to make sure minHash's hashcode is + // `Int.MinValue` + override def hashCode: Int = 1407330692 + // We are implementing this equals() only because the style-checking rule "you should + // implement equals and hashCode together" requires us to + override def equals(obj: Any): Boolean = super.equals(obj) + })).asInstanceOf[AttributeReference] + assert(minHash.hashCode() == Int.MinValue) + def setTest(size: Int, exprs: Expression*): Unit = { test(s"expect $size: ${exprs.mkString(", ")}") { val set = ExpressionSet(exprs) @@ -75,10 +107,14 @@ class ExpressionSetSuite extends SparkFunSuite { setTest(1, aUpper >= bUpper, bUpper <= aUpper) // `Not` canonicalization - setTest(1, Not(aUpper > 1), aUpper <= 1, Not(Literal(1) < aUpper), Literal(1) >= aUpper) - setTest(1, Not(aUpper < 1), aUpper >= 1, Not(Literal(1) > aUpper), Literal(1) <= aUpper) - setTest(1, Not(aUpper >= 1), aUpper < 1, Not(Literal(1) <= aUpper), Literal(1) > aUpper) - setTest(1, Not(aUpper <= 1), aUpper > 1, Not(Literal(1) >= aUpper), Literal(1) < aUpper) + setTest(1, Not(maxHash > 1), maxHash <= 1, Not(Literal(1) < maxHash), Literal(1) >= maxHash) + setTest(1, Not(minHash > 1), minHash <= 1, Not(Literal(1) < minHash), Literal(1) >= minHash) + setTest(1, Not(maxHash < 1), maxHash >= 1, Not(Literal(1) > maxHash), Literal(1) <= maxHash) + setTest(1, Not(minHash < 1), minHash >= 1, Not(Literal(1) > minHash), Literal(1) <= minHash) + setTest(1, Not(maxHash >= 1), maxHash < 1, Not(Literal(1) <= maxHash), Literal(1) > maxHash) + setTest(1, Not(minHash >= 1), minHash < 1, 
Not(Literal(1) <= minHash), Literal(1) > minHash) + setTest(1, Not(maxHash <= 1), maxHash > 1, Not(Literal(1) >= maxHash), Literal(1) < maxHash) + setTest(1, Not(minHash <= 1), minHash > 1, Not(Literal(1) >= minHash), Literal(1) < minHash) // Reordering AND/OR expressions setTest(1, aUpper > bUpper && aUpper <= 10, aUpper <= 10 && aUpper > bUpper)