diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8095083f336e19274c5da56f84aa93aa86aa0acf..31e775d60f950b0363cd1263b448288b96c7d6e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + projectList.flatMap { + case a @ Alias(e, _) => + child.constraints.map(_ transform { + case expr: Expression if expr.semanticEquals(e) => + a.toAttribute + }).union(Set(EqualNullSafe(e, a.toAttribute))) + case _ => + Set.empty[Expression] + }.toSet + } + override protected def validConstraints: Set[Expression] = child.constraints override def statistics: Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5d2a65b716b244b8392cb09acf3bdf6b302c7eac..e81a0f9487469251b357affc51a945be39bd2db8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias - */ - private def getAliasedConstraints: Set[Expression] = { - projectList.flatMap { - case a @ Alias(e, _) => - child.constraints.map(_ transform { - case expr: Expression if expr.semanticEquals(e) => - a.toAttribute - }).union(Set(EqualNullSafe(e, a.toAttribute))) - case _ => - Set.empty[Expression] - }.toSet - } - - override def validConstraints: Set[Expression] = { - child.constraints.union(getAliasedConstraints) - } + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(projectList)) } /** @@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan) override def maxRows: Option[Long] = child.maxRows - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = child.constraints.union(splitConjunctivePredicates(condition).toSet) - } } abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { @@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } - override protected def validConstraints: Set[Expression] = { + override protected def validConstraints: Set[Expression] = leftConstraints.union(rightConstraints) - } // Intersect are only resolved if they don't introduce ambiguous expression ids, // since the Optimizer will convert Intersect to Join. @@ -442,6 +423,9 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows + override def validConstraints: Set[Expression] = + child.constraints.union(getAliasedConstraints(aggregateExpressions)) + override def statistics: Statistics = { if (groupingExpressions.isEmpty) { Statistics(sizeInBytes = 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 373b1ffa83d2379e0edc7ec3517fcdfc7381aba5..b68432b1a128f16352bf564885aa672c34a80af7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "c")))) } + test("propagating constraints in aggregate") { + val tr = LocalRelation('a.int, 'b.string, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) + .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze + + verifyConstraints(aliasedRelation.analyze.constraints, + Set(resolveColumn(aliasedRelation.analyze, "c1") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), + resolveColumn(aliasedRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int)