diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 9154e96e34e9c452c46f52aa2a4eb0d07983348d..9ead571c5374a3872fb43dc17f38d8b28f661888 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -141,12 +141,12 @@ case class PivotFirst( copy(mutableAggBufferOffset = newMutableAggBufferOffset) - override lazy val aggBufferAttributes: Seq[AttributeReference] = + override val aggBufferAttributes: Seq[AttributeReference] = pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)()) - override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) + override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - override lazy val inputAggBufferAttributes: Seq[AttributeReference] = + override val inputAggBufferAttributes: Seq[AttributeReference] = aggBufferAttributes.map(_.newInstance()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b17284aa94d2f5ff10ff51e260bfcc213838ee4b..c6d67519b0e9486f3075a78ae1fef2c422203011 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -180,4 +180,21 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ ) } + test("pivot with datatype not supported by PivotFirst") { + checkAnswer( + complexData.groupBy().pivot("b", Seq(true, false)).agg(max("a")), + Row(Seq(1, 1, 1), Seq(2, 2, 2)) :: Nil + ) + } + + test("pivot with datatype not supported by PivotFirst 2") { + checkAnswer( + courseSales.withColumn("e", expr("array(earnings, 7.0d)")) + .groupBy("year") + .pivot("course", Seq("dotNET", "Java")) + .agg(min($"e")), + Row(2012, Seq(5000.0, 7.0), Seq(20000.0, 7.0)) :: + Row(2013, Seq(48000.0, 7.0), Seq(30000.0, 7.0)) :: Nil + ) + } }