Skip to content
Snippets Groups Projects
Commit bc65f60e authored by gatorsmile's avatar gatorsmile Committed by Michael Armbrust
Browse files

[SPARK-13544][SQL] Rewrite/Propagate Constraints for Aliases in Aggregate

#### What changes were proposed in this pull request?

After analysis by the Analyzer, two operators can contain aliases: `Project` and `Aggregate`. So far, we only rewrite and propagate constraints when an `Alias` is defined in a `Project`. This PR applies the same rewriting and propagation to aliases defined in an `Aggregate`.

#### How was this patch tested?

Added a test case for `Aggregate` in `ConstraintPropagationSuite`.

marmbrus sameeragarwal

Author: gatorsmile <gatorsmile@gmail.com>

Closes #11422 from gatorsmile/validConstraintsInUnaryNodes.
parent 02aa499d
No related branches found
No related tags found
No related merge requests found
......@@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan {
// The children of this node are just its single `child`.
override def children: Seq[LogicalPlan] = child :: Nil
/**
 * Builds the extra constraints implied by the aliases found in `projectList`.
 *
 * For each `Alias(e, _)` entry, every constraint of the child is transformed so that
 * expressions matching `semanticEquals(e)` are replaced by the alias's attribute, and an
 * `EqualNullSafe(e, attribute)` constraint is added. Entries that are not aliases
 * contribute nothing.
 *
 * @param projectList the named expressions to scan for aliases
 * @return the set of constraints derived from the aliases
 */
protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
  val perAlias = projectList.collect {
    case alias @ Alias(aliasedExpr, _) =>
      val aliasAttr = alias.toAttribute
      // Re-express each child constraint in terms of the alias's attribute.
      val rewritten = child.constraints.map { constraint =>
        constraint transform {
          case expr: Expression if expr.semanticEquals(aliasedExpr) => aliasAttr
        }
      }
      // Also record that the aliased expression and its attribute are (null-safe) equal.
      rewritten.union(Set(EqualNullSafe(aliasedExpr, aliasAttr)))
  }
  perAlias.flatten.toSet
}
// The valid constraints here are exactly the child's constraints.
override protected def validConstraints: Set[Expression] = child.constraints
override def statistics: Statistics = {
......
......@@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}
/**
 * Generates an additional set of aliased constraints by replacing the original constraint
 * expressions with the corresponding alias.
 *
 * For every `Alias(e, _)` in `projectList`, each child constraint is transformed so that
 * expressions matching `semanticEquals(e)` become the alias's attribute, and
 * `EqualNullSafe(e, attribute)` is added. Non-alias entries yield an empty set.
 */
private def getAliasedConstraints: Set[Expression] = {
projectList.flatMap {
case a @ Alias(e, _) =>
child.constraints.map(_ transform {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
}).union(Set(EqualNullSafe(e, a.toAttribute)))
case _ =>
Set.empty[Expression]
}.toSet
}
// Child constraints combined with the alias-derived constraints from getAliasedConstraints.
override def validConstraints: Set[Expression] = {
child.constraints.union(getAliasedConstraints)
}
// Child constraints combined with the constraints derived from aliases in `projectList`.
override def validConstraints: Set[Expression] =
child.constraints.union(getAliasedConstraints(projectList))
}
/**
......@@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
// Delegates to the child's `maxRows`.
override def maxRows: Option[Long] = child.maxRows
// Child constraints combined with the predicates returned by
// splitConjunctivePredicates(condition).
override protected def validConstraints: Set[Expression] = {
override protected def validConstraints: Set[Expression] =
child.constraints.union(splitConjunctivePredicates(condition).toSet)
}
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
......@@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
// The union of `leftConstraints` and `rightConstraints`.
override protected def validConstraints: Set[Expression] = {
override protected def validConstraints: Set[Expression] =
leftConstraints.union(rightConstraints)
}
// Intersect are only resolved if they don't introduce ambiguous expression ids,
// since the Optimizer will convert Intersect to Join.
......@@ -442,6 +423,9 @@ case class Aggregate(
// The output attributes are the attributes of the aggregate expressions, in order.
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
// Delegates to the child's `maxRows`.
override def maxRows: Option[Long] = child.maxRows
// Child constraints combined with the constraints derived from aliases in
// `aggregateExpressions`.
// NOTE(review): `aggregateExpressions` can also contain aliases over aggregate functions
// (e.g. `count('a).as("a3")` in ConstraintPropagationSuite), so constraints are derived
// for those aliases too -- confirm that this is intended.
override def validConstraints: Set[Expression] =
child.constraints.union(getAliasedConstraints(aggregateExpressions))
override def statistics: Statistics = {
if (groupingExpressions.isEmpty) {
Statistics(sizeInBytes = 1)
......
......@@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "c"))))
}
test("propagating constraints in aggregate") {
  val tr = LocalRelation('a.int, 'b.string, 'c.int)
  // The bare relation is expected to carry no constraints.
  assert(tr.analyze.constraints.isEmpty)

  // Filter, aggregate with an aliased grouping column ('c AS c1) and an aliased aggregate
  // (count('a) AS a3), then keep only c1 and a.
  val filtered = tr.where('c.attr > 10 && 'a.attr < 5)
  val aggregated = filtered.groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3"))
  val aliasedRelation = aggregated.select('c1, 'a).analyze

  // Resolves a column by name against the analyzed plan.
  def resolved(name: String) = resolveColumn(aliasedRelation.analyze, name)

  verifyConstraints(
    aliasedRelation.analyze.constraints,
    Set(
      resolved("c1") > 10,
      IsNotNull(resolved("c1")),
      resolved("a") < 5,
      IsNotNull(resolved("a"))))
}
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment