diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 502d898fea86cf0a7b33162bc7bbea579741d7ce..7d155ac183d5877bfed2255e54ea792a729dc510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -50,6 +50,26 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
 
     !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
   }
+
+  /**
+   * Generates an additional set of aliased constraints by replacing the original constraint
+   * expressions with the corresponding alias
+   */
+  private def getAliasedConstraints: Set[Expression] = {
+    projectList.flatMap {
+      case a @ Alias(e, _) =>
+        child.constraints.map(_ transform {
+          case expr: Expression if expr.semanticEquals(e) =>
+            a.toAttribute
+        }).union(Set(EqualNullSafe(e, a.toAttribute)))
+      case _ =>
+        Set.empty[Expression]
+    }.toSet
+  }
+
+  override def validConstraints: Set[Expression] = {
+    child.constraints.union(getAliasedConstraints)
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index b5cf91394d910ea31106a2e5c6fc5a6e861d1af1..373b1ffa83d2379e0edc7ec3517fcdfc7381aba5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -27,7 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
 class ConstraintPropagationSuite extends SparkFunSuite {
 
   private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
-    tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get
+    resolveColumn(tr.analyze, columnName)
+
+  private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
+    plan.resolveQuoted(columnName, caseInsensitiveResolution).get
 
   private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
     val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
@@ -69,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       IsNotNull(resolveColumn(tr, "c"))))
   }
 
+  test("propagating constraints in aliases") {
+    val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+    assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty)
+
+    val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))
+
+    verifyConstraints(aliasedRelation.analyze.constraints,
+      Set(resolveColumn(aliasedRelation.analyze, "x") > 10,
+        IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
+        resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
+        resolveColumn(aliasedRelation.analyze, "z") > 10,
+        IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))
+  }
+
   test("propagating constraints in union") {
     val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
     val tr2 = LocalRelation('d.int, 'e.int, 'f.int)