diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3a7004ef297f6d5e1f1c1ae42bad7eaf19e45591..6958398e03f7072539b78dce57256017e5ee04ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -442,12 +442,9 @@ object FoldablePropagation extends Rule[LogicalPlan] { case l: LeafNode => l - // Whitelist of all nodes we are allowed to apply this rule to. - case p @ (_: Project | _: Filter | _: SubqueryAlias | _: Aggregate | _: Window | - _: Sample | _: GlobalLimit | _: LocalLimit | _: Generate | _: Distinct | - _: AppendColumns | _: AppendColumnsWithObject | _: BroadcastHint | - _: RedistributeData | _: Repartition | _: Sort | _: TypedFilter) if !stop => - p.transformExpressions(replaceFoldable) + // We can only propagate foldables for a subset of unary nodes. + case u: UnaryNode if !stop && canPropagateFoldables(u) => + u.transformExpressions(replaceFoldable) // Allow inner joins. We do not allow outer join, although its output attributes are // derived from its children, they are actually different attributes: the output of outer @@ -474,6 +471,30 @@ object FoldablePropagation extends Rule[LogicalPlan] { }) } } + + /** + * Whitelist of all [[UnaryNode]]s for which allow foldable propagation. + */ + private def canPropagateFoldables(u: UnaryNode): Boolean = u match { + case _: Project => true + case _: Filter => true + case _: SubqueryAlias => true + case _: Aggregate => true + case _: Window => true + case _: Sample => true + case _: GlobalLimit => true + case _: LocalLimit => true + case _: Generate => true + case _: Distinct => true + case _: AppendColumns => true + case _: AppendColumnsWithObject => true + case _: BroadcastHint => true + case _: RedistributeData => true + case _: Repartition => true + case _: Sort => true + case _: TypedFilter => true + case _ => false + } }