diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 83379ae90f703b4f72f76ee8e242bafdeb3c707c..1e113ccd4e137438579c09c762dc27eba24209f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -33,15 +33,14 @@ object Utils {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
-    val groupingAttributes = groupingExpressions.map(_.toAttribute)
     val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
     val completeAggregateAttributes = completeAggregateExpressions.map {
       expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct)
     }
 
     SortBasedAggregate(
-      requiredChildDistributionExpressions = Some(groupingAttributes),
-      groupingExpressions = groupingAttributes,
+      requiredChildDistributionExpressions = Some(groupingExpressions),
+      groupingExpressions = groupingExpressions,
       aggregateExpressions = completeAggregateExpressions,
       aggregateAttributes = completeAggregateAttributes,
       initialInputBufferOffset = 0,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 3e4cf3f79e57c83c1cc3a3280bdef5b63b88f90b..7a9ed1eaf3dbc0936ed4718d2201684c9cd9180c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -193,6 +193,14 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     sqlContext.dropTempTable("emptyTable")
   }
 
+  test("group by function") {
+    Seq((1, 2)).toDF("a", "b").registerTempTable("data")
+
+    checkAnswer(
+      sql("SELECT floor(a) AS a, collect_set(b) FROM data GROUP BY floor(a) ORDER BY a"),
+      Row(1, Array(2)) :: Nil)
+  }
+
   test("empty table") {
     // If there is no GROUP BY clause and the table is empty, we will generate a single row.
     checkAnswer(