diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b59f800e7cc0f6faa4b64d3a8ae9013adfa42f2b..813c62009666c52be41098d45e618fb6a7a927ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -36,8 +36,9 @@ object DefaultOptimizer extends Optimizer {
     // SubQueries are only needed for analysis and can be removed before execution.
     Batch("Remove SubQueries", FixedPoint(100),
       EliminateSubQueries) ::
-    Batch("Distinct", FixedPoint(100),
-      ReplaceDistinctWithAggregate) ::
+    Batch("Aggregate", FixedPoint(100),
+      ReplaceDistinctWithAggregate,
+      RemoveLiteralFromGroupExpressions) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
       SetOperationPushDown,
@@ -799,3 +800,15 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
     case Distinct(child) => Aggregate(child.output, child.output, child)
   }
 }
+
+/**
+ * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
+ * but only makes the grouping key bigger.
+ */
+object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case a @ Aggregate(grouping, _, _) =>
+      val newGrouping = grouping.filter(!_.foldable)
+      a.copy(groupingExpressions = newGrouping)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 1e7b2a536ac128c8512f847a8f3ea32f903590e8..b9ca712c1ee1cc26539787490f20bf24b899fc8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -144,14 +144,14 @@ object PartialAggregation {
         // time. However some of them might be unnamed so we alias them allowing them to be
         // referenced in the second aggregation.
         val namedGroupingExpressions: Seq[(Expression, NamedExpression)] =
-          groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
+          groupingExpressions.map {
             case n: NamedExpression => (n, n)
             case other => (other, Alias(other, "PartialGroup")())
           }
 
         // Replace aggregations with a new expression that computes the result from the already
         // computed partial evaluations and grouping values.
-        val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
+        val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown {
           case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
             partialEvaluations(new TreeNodeRef(e)).finalEvaluation
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
similarity index 72%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index df29a62ff0e1564a5f0117d7ca4a025c4c216dd1..2d080b95b1292646faaadefd02d6401675b98b5a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
-class ReplaceDistinctWithAggregateSuite extends PlanTest {
+class AggregateOptimizeSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil
+    val batches = Batch("Aggregate", FixedPoint(100),
+      ReplaceDistinctWithAggregate,
+      RemoveLiteralFromGroupExpressions) :: Nil
   }
 
   test("replace distinct with aggregate") {
@@ -39,4 +42,16 @@ class ReplaceDistinctWithAggregateSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("remove literals in grouping expression") {
+    val input = LocalRelation('a.int, 'b.int)
+
+    val query =
+      input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer = input.groupBy('a)(sum('b))
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8cef0b39f87dc98b414667bb6082429b0125f590..358e319476e83d35cbf80a7386a51b42085f6ee1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -463,12 +463,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
   }
 
   test("literal in agg grouping expressions") {
-    checkAnswer(
-      sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
-    checkAnswer(
-      sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
-      Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+    def literalInAggTest(): Unit = {
+      checkAnswer(
+        sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
+        Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+      checkAnswer(
+        sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
+        Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+
+      checkAnswer(
+        sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
+        sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+      checkAnswer(
+        sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
+        sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+      checkAnswer(
+        sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
+        sql("SELECT 1, 2, sum(b) FROM testData2"))
+    }
+
+    literalInAggTest()
+    withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+      literalInAggTest()
+    }
   }
 
   test("aggregates with nulls") {