diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ed6e17a8eb46545e54e2b5b5823a19d95fd61431..58f98d529ab588d308f3455d878de59f37a57e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -463,14 +463,15 @@ class Analyzer( .toAggregateExpression() , "__pivot_" + a.sql)() } - val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg) + val groupByExprsAttr = groupByExprs.map(_.toAttribute) + val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg) val pivotAggAttribute = pivotAggs.map(_.toAttribute) val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) => aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) => Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))() } } - Project(groupByExprs ++ pivotOutputs, secondAgg) + Project(groupByExprsAttr ++ pivotOutputs, secondAgg) } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 1bbe1354d55f48715eaa459f866595d1b8366047..a8d854ccbc94362aed2cae64006cecc814b17f8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -208,4 +208,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ ) } + test("pivot with column definition in groupby") { + checkAnswer( + courseSales.groupBy(substring(col("course"), 0, 1).as("foo")) + .pivot("year", Seq(2012, 2013)) + .sum("earnings"), + Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil + ) + } }