From 091f6a7830bbee01fa580fbb0336b9f4fcac0dfa Mon Sep 17 00:00:00 2001 From: Sameer Agarwal <sameer@databricks.com> Date: Fri, 19 Feb 2016 14:48:34 -0800 Subject: [PATCH] [SPARK-13091][SQL] Rewrite/Propagate constraints for Aliases This PR adds support for rewriting constraints if there are aliases in the query plan. For e.g., if there is a query of form `SELECT a, a AS b`, any constraints on `a` now also apply to `b`. JIRA: https://issues.apache.org/jira/browse/SPARK-13091 cc marmbrus Author: Sameer Agarwal <sameer@databricks.com> Closes #11144 from sameeragarwal/alias. --- .../plans/logical/basicOperators.scala | 20 +++++++++++++++++++ .../plans/ConstraintPropagationSuite.scala | 20 ++++++++++++++++++- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 502d898fea..7d155ac183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -50,6 +50,26 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } + + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + private def getAliasedConstraints: Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + + override def validConstraints: Set[Expression] = { + child.constraints.union(getAliasedConstraints) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index b5cf91394d..373b1ffa83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -27,7 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ class ConstraintPropagationSuite extends SparkFunSuite { private def resolveColumn(tr: LocalRelation, columnName: String): Expression = - tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get + resolveColumn(tr.analyze, columnName) + + private def resolveColumn(plan: LogicalPlan, columnName: String): Expression = + plan.resolveQuoted(columnName, caseInsensitiveResolution).get private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) @@ -69,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aliases") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z)) + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "x") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), + resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))) + } + test("propagating constraints in union") { val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr2 = LocalRelation('d.int, 'e.int, 'f.int) -- GitLab