diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8095083f336e19274c5da56f84aa93aa86aa0acf..31e775d60f950b0363cd1263b448288b96c7d6e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -315,6 +315,22 @@ abstract class UnaryNode extends LogicalPlan {
 
   override def children: Seq[LogicalPlan] = child :: Nil
 
+  /**
+   * Generates an additional set of aliased constraints by replacing the original constraint
+   * expressions with the corresponding alias
+   */
+  protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
+    projectList.flatMap {
+      case a @ Alias(e, _) =>
+        child.constraints.map(_ transform {
+          case expr: Expression if expr.semanticEquals(e) =>
+            a.toAttribute
+        }).union(Set(EqualNullSafe(e, a.toAttribute)))
+      case _ =>
+        Set.empty[Expression]
+    }.toSet
+  }
+
   override protected def validConstraints: Set[Expression] = child.constraints
 
   override def statistics: Statistics = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5d2a65b716b244b8392cb09acf3bdf6b302c7eac..e81a0f9487469251b357affc51a945be39bd2db8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -51,25 +51,8 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
     !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
   }
 
-  /**
-   * Generates an additional set of aliased constraints by replacing the original constraint
-   * expressions with the corresponding alias
-   */
-  private def getAliasedConstraints: Set[Expression] = {
-    projectList.flatMap {
-      case a @ Alias(e, _) =>
-        child.constraints.map(_ transform {
-          case expr: Expression if expr.semanticEquals(e) =>
-            a.toAttribute
-        }).union(Set(EqualNullSafe(e, a.toAttribute)))
-      case _ =>
-        Set.empty[Expression]
-    }.toSet
-  }
-
-  override def validConstraints: Set[Expression] = {
-    child.constraints.union(getAliasedConstraints)
-  }
+  override def validConstraints: Set[Expression] =
+    child.constraints.union(getAliasedConstraints(projectList))
 }
 
 /**
@@ -126,9 +109,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
 
   override def maxRows: Option[Long] = child.maxRows
 
-  override protected def validConstraints: Set[Expression] = {
+  override protected def validConstraints: Set[Expression] =
     child.constraints.union(splitConjunctivePredicates(condition).toSet)
-  }
 }
 
 abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
@@ -157,9 +139,8 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
       leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
     }
 
-  override protected def validConstraints: Set[Expression] = {
+  override protected def validConstraints: Set[Expression] =
     leftConstraints.union(rightConstraints)
-  }
 
   // Intersect are only resolved if they don't introduce ambiguous expression ids,
   // since the Optimizer will convert Intersect to Join.
@@ -442,6 +423,9 @@ case class Aggregate(
   override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
   override def maxRows: Option[Long] = child.maxRows
 
+  override def validConstraints: Set[Expression] =
+    child.constraints.union(getAliasedConstraints(aggregateExpressions))
+
   override def statistics: Statistics = {
     if (groupingExpressions.isEmpty) {
       Statistics(sizeInBytes = 1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 373b1ffa83d2379e0edc7ec3517fcdfc7381aba5..b68432b1a128f16352bf564885aa672c34a80af7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -72,6 +72,21 @@ class ConstraintPropagationSuite extends SparkFunSuite {
         IsNotNull(resolveColumn(tr, "c"))))
   }
 
+  test("propagating constraints in aggregate") {
+    val tr = LocalRelation('a.int, 'b.string, 'c.int)
+
+    assert(tr.analyze.constraints.isEmpty)
+
+    val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5)
+      .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze
+
+    verifyConstraints(aliasedRelation.analyze.constraints,
+      Set(resolveColumn(aliasedRelation.analyze, "c1") > 10,
+        IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")),
+        resolveColumn(aliasedRelation.analyze, "a") < 5,
+        IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))
+  }
+
   test("propagating constraints in aliases") {
     val tr = LocalRelation('a.int, 'b.string, 'c.int)