From dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Wed, 18 Nov 2015 10:33:17 -0800
Subject: [PATCH] [SPARK-11795][SQL] combine grouping attributes into a single
 NamedExpression
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

We use `ExpressionEncoder.tuple` to build the result encoder, which assumes that a non-flat input encoder points to a struct-type field.
However, our `keyEncoder` always points to flat field(s), the `groupingAttributes`, so we should combine them into a single `NamedExpression`.
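
For illustration, a minimal, self-contained sketch of the key-column construction; the attribute names and types below are made up for the example, and only the `CreateStruct`/`Alias` combination mirrors the actual change:

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, NamedExpression}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // Hypothetical flat grouping attributes standing in for groupingAttributes.
    val groupingAttributes = Seq(
      AttributeReference("a", StringType)(),
      AttributeReference("b", IntegerType)())

    // Combine multiple flat grouping attributes into one struct-typed key column,
    // so the tuple encoder sees a single named field for the key; a single
    // grouping attribute is already a NamedExpression and is used as-is.
    val keyColumn: NamedExpression =
      if (groupingAttributes.length > 1) Alias(CreateStruct(groupingAttributes), "key")()
      else groupingAttributes.head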

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9792 from cloud-fan/agg.
---
 .../main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 +++++++--
 .../test/scala/org/apache/spark/sql/DatasetSuite.scala   | 5 ++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index c66162ee21..3f84e22a10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.function._
 import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 
@@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql](
     val namedColumns =
       columns.map(
         _.withInputType(resolvedTEncoder, dataAttributes).named)
-    val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
+    val keyColumn = if (groupingAttributes.length > 1) {
+      Alias(CreateStruct(groupingAttributes), "key")()
+    } else {
+      groupingAttributes.head
+    }
+    val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
     val execution = new QueryExecution(sqlContext, aggregate)
 
     new Dataset(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 198962b8fb..b6db583dfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 2), ("b", 3), ("c", 4))
   }
 
-  ignore("Dataset should set the resolved encoders internally for maps") {
-    // TODO: Enable this once we fix SPARK-11793.
+  test("map and group by with class data") {
     // We inject a group by here to make sure this test case is future proof
     // when we implement better pipelining and local execution mode.
     val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
@@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
 
     checkAnswer(
       ds,
-      (ClassData("one", 1), 1L), (ClassData("two", 2), 1L))
+      (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
   }
 
   test("select") {
-- 
GitLab