diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 619514e8aacbe5e4efd47972c25f75cf320d1f1c..bad115d22f1ae5ace96096eaa5b75c8fbbd94a80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -86,6 +86,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyConditionals, RemoveDispensableExpressions, + BinaryComparisonSimplification, PruneFilters, EliminateSorts, SimplifyCasts, @@ -786,6 +787,29 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Simplifies binary comparisons with semantically-equal expressions: + * 1) Replace '<=>' with 'true' literal. + * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. + * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable. + */ +object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + // True with equality + case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral + case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => + TrueLiteral + case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + + // False with inequality + case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + } + } +} + /** * Simplifies conditional expressions (if / case). */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..7cd038570bbdf8b3edb8603205fae0811fd2279e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("Constant Folding", FixedPoint(50), + NullPropagation, + ConstantFolding, + BooleanSimplification, + BinaryComparisonSimplification, + PruneFilters) :: Nil + } + + val nullableRelation = LocalRelation('a.int.withNullability(true)) + val nonNullableRelation = LocalRelation('a.int.withNullability(false)) + + test("Preserve nullable exprs in general") { + for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) { + val plan = nullableRelation.where(e).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + } + + test("Preserve non-deterministic exprs") { + val plan = nonNullableRelation + .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + + test("Nullable Simplification Primitive: <=>") { + val plan = nullableRelation.select('a <=> 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze + comparePlans(actual, correctAnswer) + } + + test("Non-Nullable Simplification Primitive") { + val plan = nonNullableRelation + .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation + .select( + Alias(TrueLiteral, "(a = a)")(), + Alias(TrueLiteral, "(a <=> a)")(), + Alias(TrueLiteral, "(a <= a)")(), + Alias(TrueLiteral, "(a >= a)")(), + Alias(FalseLiteral, "(a < a)")(), + Alias(FalseLiteral, "(a > a)")()) + .analyze + comparePlans(actual, correctAnswer) + } + + test("Expression Normalization") { + val plan = nonNullableRelation.where( + 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a && + DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a)) + .analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation.analyze + comparePlans(actual, correctAnswer) + } +}