Skip to content
Snippets Groups Projects
Commit dbf428c8 authored by Wenchen Fan, committed by Michael Armbrust
Browse files

[SPARK-11795][SQL] combine grouping attributes into a single NamedExpression

We use `ExpressionEncoder.tuple` to build the result encoder, which assumes that each input encoder points to a struct-type field if it is non-flat.
However, our keyEncoder always points to a flat field or fields (`groupingAttributes`), so we should combine them into a single `NamedExpression`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9792 from cloud-fan/agg.
parent 33b83733
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
......@@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql](
val namedColumns =
columns.map(
_.withInputType(resolvedTEncoder, dataAttributes).named)
val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
val keyColumn = if (groupingAttributes.length > 1) {
Alias(CreateStruct(groupingAttributes), "key")()
} else {
groupingAttributes.head
}
val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
val execution = new QueryExecution(sqlContext, aggregate)
new Dataset(
......
......@@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 2), ("b", 3), ("c", 4))
}
ignore("Dataset should set the resolved encoders internally for maps") {
// TODO: Enable this once we fix SPARK-11793.
test("map and group by with class data") {
// We inject a group by here to make sure this test case is future proof
// when we implement better pipelining and local execution mode.
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
......@@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(
ds,
(ClassData("one", 1), 1L), (ClassData("two", 2), 1L))
(ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
}
test("select") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment