diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 714e2cdac2b198369f6088aa133cb09899afad20..7f32f6b8bcf465ce4e2fc1c5b2f6fe136ff1a421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,11 +153,13 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetItem(Literal(null, _), _) => Literal(null, e.dataType) case e @ GetItem(_, Literal(null, _)) => Literal(null, e.dataType) case e @ GetField(Literal(null, _), _) => Literal(null, e.dataType) - case e @ Coalesce(children) => { - val newChildren = children.filter(c => c match { + + // For Coalesce, remove null literals. + case e @ Coalesce(children) => + val newChildren = children.filter { case Literal(null, _) => false case _ => true - }) + } if (newChildren.length == 0) { Literal(null, e.dataType) } else if (newChildren.length == 1) { @@ -165,15 +167,11 @@ object NullPropagation extends Rule[LogicalPlan] { } else { Coalesce(newChildren) } - } - case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue - case e @ In(Literal(v, _), list) if (list.exists(c => c match { - case Literal(candidate, _) if candidate == v => true - case _ => false - })) => Literal(true, BooleanType) + case e @ Substring(Literal(null, _), _, _) => Literal(null, e.dataType) case e @ Substring(_, Literal(null, _), _) => Literal(null, e.dataType) case e @ Substring(_, _, Literal(null, _)) => Literal(null, e.dataType) + // Put exceptional cases above if any case e: BinaryArithmetic => e.children match { case Literal(null, _) :: right :: Nil => Literal(null, e.dataType) @@ -201,9 +199,19 @@ object NullPropagation extends Rule[LogicalPlan] { object ConstantFolding extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { - // Skip redundant folding of literals. + // Skip redundant folding of literals. This rule is technically not necessary. Placing this + // here avoids running the next rule for Literal values, which would create a new Literal + // object and running eval unnecessarily. case l: Literal => l + + // Fold expressions that are foldable. case e if e.foldable => Literal(e.eval(null), e.dataType) + + // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. + case In(Literal(v, _), list) if list.exists { + case Literal(candidate, _) if candidate == v => true + case _ => false + } => Literal(true, BooleanType) } } } @@ -233,6 +241,9 @@ object BooleanSimplification extends Rule[LogicalPlan] { case (l, Literal(false, BooleanType)) => l case (_, _) => or } + + // Turn "if (true) a else b" into "a", and if (false) a else b" into "b". + case e @ If(Literal(v, _), trueValue, falseValue) => if (v == true) trueValue else falseValue } } } @@ -254,12 +265,12 @@ object CombineFilters extends Rule[LogicalPlan] { */ object SimplifyFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Filter(Literal(true, BooleanType), child) => - child - case Filter(Literal(null, _), child) => - LocalRelation(child.output) - case Filter(Literal(false, BooleanType), child) => - LocalRelation(child.output) + // If the filter condition always evaluate to true, remove the filter. + case Filter(Literal(true, BooleanType), child) => child + // If the filter condition always evaluate to null or false, + // replace the input with an empty relation. + case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty) + case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty) } } @@ -301,7 +312,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { /** * Splits join condition expressions into three categories based on the attributes required * to evaluate them. - * @returns (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) + * @return (canEvaluateInLeft, canEvaluateInRight, haveToEvaluateInBoth) */ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) =