Skip to content
Snippets Groups Projects
Commit 091f6a78 authored by Sameer Agarwal's avatar Sameer Agarwal Committed by Michael Armbrust
Browse files

[SPARK-13091][SQL] Rewrite/Propagate constraints for Aliases

This PR adds support for rewriting constraints if there are aliases in the query plan. For e.g., if there is a query of form `SELECT a, a AS b`, any constraints on `a` now also apply to `b`.

JIRA: https://issues.apache.org/jira/browse/SPARK-13091

cc marmbrus

Author: Sameer Agarwal <sameer@databricks.com>

Closes #11144 from sameeragarwal/alias.
parent 14844118
No related branches found
No related tags found
No related merge requests found
...@@ -50,6 +50,26 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend ...@@ -50,6 +50,26 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
} }
/**
* Generates an additional set of aliased constraints by replacing the original constraint
* expressions with the corresponding alias
*/
private def getAliasedConstraints: Set[Expression] = {
projectList.flatMap {
case a @ Alias(e, _) =>
child.constraints.map(_ transform {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
}).union(Set(EqualNullSafe(e, a.toAttribute)))
case _ =>
Set.empty[Expression]
}.toSet
}
override def validConstraints: Set[Expression] = {
child.constraints.union(getAliasedConstraints)
}
} }
/** /**
......
...@@ -27,7 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._ ...@@ -27,7 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical._
class ConstraintPropagationSuite extends SparkFunSuite { class ConstraintPropagationSuite extends SparkFunSuite {
private def resolveColumn(tr: LocalRelation, columnName: String): Expression = private def resolveColumn(tr: LocalRelation, columnName: String): Expression =
tr.analyze.resolveQuoted(columnName, caseInsensitiveResolution).get resolveColumn(tr.analyze, columnName)
private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
plan.resolveQuoted(columnName, caseInsensitiveResolution).get
private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
...@@ -69,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { ...@@ -69,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "c")))) IsNotNull(resolveColumn(tr, "c"))))
} }
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty)
val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))
verifyConstraints(aliasedRelation.analyze.constraints,
Set(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))
}
test("propagating constraints in union") { test("propagating constraints in union") {
val tr1 = LocalRelation('a.int, 'b.int, 'c.int) val tr1 = LocalRelation('a.int, 'b.int, 'c.int)
val tr2 = LocalRelation('d.int, 'e.int, 'f.int) val tr2 = LocalRelation('d.int, 'e.int, 'f.int)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment