diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 8d8cc152ff29c2287193e833c853a5fb50560042..607c7c877cc14c9c12fff37363f62909421da4c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -69,8 +69,17 @@ class EquivalentExpressions {
    */
   def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
     val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
-    // the children of CodegenFallback will not be used to generate code (call eval() instead)
-    if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) {
+    // There are some special expressions whose children we should not recurse into:
+    // 1. CodegenFallback: its children will not be used to generate code (eval() is called instead)
+    // 2. ReferenceToExpressions: it is effectively an explicit sub-expression elimination.
+    val shouldRecurse = root match {
+      // TODO: some expressions implement `CodegenFallback` but can still do codegen,
+      // e.g. `CaseWhen`; we should support them.
+      case _: CodegenFallback => false
+      case _: ReferenceToExpressions => false
+      case _ => true
+    }
+    if (!skip && !addExpr(root) && shouldRecurse) {
       root.children.foreach(addExprTree(_, ignoreLeaf))
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 38ac13b208aabea1155b4b7ec6f65372eb5edce4..d29c27c14b0c3ff79829644dae176562980c5fa2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -110,13 +110,17 @@ class CodegenContext {
   }
 
   def declareMutableStates(): String = {
-    mutableStates.map { case (javaType, variableName, _) =>
+    // It's possible that the same mutable state is added twice, e.g. by `mergeExpressions` in
+    // `TypedAggregateExpression`, so we call `distinct` here to remove the duplicates.
+    mutableStates.distinct.map { case (javaType, variableName, _) =>
       s"private $javaType $variableName;"
     }.mkString("\n")
   }
 
   def initMutableStates(): String = {
-    mutableStates.map(_._3).mkString("\n")
+    // It's possible that the same mutable state is added twice, e.g. by `mergeExpressions` in
+    // `TypedAggregateExpression`, so we call `distinct` here to remove the duplicates.
+    mutableStates.distinct.map(_._3).mkString("\n")
   }
 
   /**
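
Note: as a standalone illustration of the `distinct` fix above, here is a minimal
sketch. `MutableStateDedup` and its methods are hypothetical stand-ins, not Spark's
actual `CodegenContext` API beyond what the hunk shows. Registering the same
(javaType, variableName, initCode) triple twice would otherwise emit two identical
field declarations in the generated class, which Janino rejects with a
ClassFormatError:

    import scala.collection.mutable.ArrayBuffer

    object MutableStateDedup {
      // (javaType, variableName, initCode) triples, mirroring mutableStates in CodegenContext.
      val mutableStates = ArrayBuffer.empty[(String, String, String)]

      def addMutableState(javaType: String, name: String, init: String): Unit =
        mutableStates += ((javaType, name, init))

      def declareMutableStates(): String =
        mutableStates.distinct.map { case (javaType, name, _) =>
          s"private $javaType $name;"
        }.mkString("\n")

      def main(args: Array[String]): Unit = {
        addMutableState("scala.collection.Seq", "value", "this.value = null;")
        // The same state registered twice, e.g. via mergeExpressions in TypedAggregateExpression.
        addMutableState("scala.collection.Seq", "value", "this.value = null;")
        // Without `distinct` this would print two identical field declarations;
        // with it, exactly one.
        println(declareMutableStates())
      }
    }
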
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 0d84a594f71a9b62dbb9971a8ce4cf24961994bc..6eae3ed7ad6c0df5dcbf3a307c3a6300c5780aec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -72,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] {
 }
 
 
+object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
+  def zero: Seq[Int] = Nil
+  def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
+  def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
+  def finish(r: Seq[Int]): Seq[Int] = r
+  override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+}
+
+
 class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
   extends Aggregator[IN, OUT, OUT] {
 
@@ -212,4 +223,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
     checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
   }
+
+  test("SPARK-14675: ClassFormatError when use Seq as Aggregator buffer type") {
+    val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
+
+    checkDataset(
+      ds.groupByKey(_.b).agg(SeqAgg.toColumn),
+      "a" -> Seq(1, 2)
+    )
+  }
 }
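
Note: a minimal end-to-end sketch of the scenario the new regression test covers,
assuming a Spark 2.x dependency on the classpath. `AggData` and `SeqAgg` are
re-declared from the test suite above so the snippet is self-contained, and the
`SeqAggRepro` object, `local[2]` master, and app name are placeholders. Before
this patch, the action below failed at codegen time with a ClassFormatError,
because the generated class declared the same mutable state field twice (the
CodeGenerator fix above):

    import org.apache.spark.sql.{Encoder, SparkSession}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator

    case class AggData(a: Int, b: String)

    object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
      def zero: Seq[Int] = Nil
      def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
      def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
      def finish(r: Seq[Int]): Seq[Int] = r
      override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
      override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
    }

    object SeqAggRepro {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("SeqAggRepro").getOrCreate()
        import spark.implicits._

        val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
        // A Seq[Int] aggregation buffer is what triggered the duplicated mutable state.
        ds.groupByKey(_.b).agg(SeqAgg.toColumn).show()

        spark.stop()
      }
    }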