Skip to content
Snippets Groups Projects
Commit 5cb2e336 authored by Wenchen Fan's avatar Wenchen Fan Committed by Yin Huai
Browse files

[SPARK-14675][SQL] ClassFormatError when use Seq as Aggregator buffer type

## What changes were proposed in this pull request?

After https://github.com/apache/spark/pull/12067, we now use expressions to do the aggregation in `TypedAggregateExpression`. To implement buffer merge, we produce a new buffer deserializer expression by replacing `AttributeReference` with right-side buffer attribute, like other `DeclarativeAggregate`s do, and finally combine the left and right buffer deserializer with `Invoke`.

However, after https://github.com/apache/spark/pull/12338, we will add a loop variable to the class members when we codegen `MapObjects`. If the `Aggregator` buffer type is `Seq`, which is implemented by a `MapObjects` expression, we will add the same loop variable to the class members twice (once by the left and once by the right buffer deserializer), which causes the `ClassFormatError`.

This PR fixes this issue by calling `distinct` before declaring the class members.

## How was this patch tested?

new regression test in `DatasetAggregatorSuite`

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12468 from cloud-fan/bug.
parent 947b9020
No related branches found
No related tags found
No related merge requests found
...@@ -69,8 +69,17 @@ class EquivalentExpressions { ...@@ -69,8 +69,17 @@ class EquivalentExpressions {
*/ */
/**
 * Adds the expression `root` and, where appropriate, all of its children to the
 * equivalence map. Leaf expressions are skipped when `ignoreLeaf` is set, and the
 * recursion stops at expressions whose children never produce generated code.
 */
def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
  val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
  // Certain expressions must not have their children recursed into:
  // 1. CodegenFallback: its children are evaluated via eval(), never codegen'd.
  // 2. ReferenceToExpressions: effectively an explicit sub-expression elimination.
  // TODO: some expressions implement `CodegenFallback` but can still do codegen,
  // e.g. `CaseWhen`; we should support them.
  val recurseIntoChildren = root match {
    case _: CodegenFallback | _: ReferenceToExpressions => false
    case _ => true
  }
  if (!skip && !addExpr(root) && recurseIntoChildren) {
    for (child <- root.children) {
      addExprTree(child, ignoreLeaf)
    }
  }
}
......
...@@ -110,13 +110,17 @@ class CodegenContext { ...@@ -110,13 +110,17 @@ class CodegenContext {
} }
/**
 * Emits the Java field declarations for every registered mutable state,
 * one `private <type> <name>;` line per state.
 */
def declareMutableStates(): String = {
  // The same mutable state may be registered twice, e.g. via the `mergeExpressions`
  // in `TypedAggregateExpression`; deduplicate before declaring the fields so the
  // generated class does not declare a member twice (which raises ClassFormatError).
  val declarations = for ((javaType, variableName, _) <- mutableStates.distinct) yield {
    s"private $javaType $variableName;"
  }
  declarations.mkString("\n")
}
/**
 * Emits the initialization code for every registered mutable state,
 * one statement per line.
 */
def initMutableStates(): String = {
  // The same mutable state may be registered twice, e.g. via the `mergeExpressions`
  // in `TypedAggregateExpression`; deduplicate so each state is initialized once.
  val initCodes = mutableStates.distinct.map { case (_, _, initCode) => initCode }
  initCodes.mkString("\n")
}
/** /**
......
...@@ -19,6 +19,7 @@ package org.apache.spark.sql ...@@ -19,6 +19,7 @@ package org.apache.spark.sql
import scala.language.postfixOps import scala.language.postfixOps
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._ import org.apache.spark.sql.functions._
...@@ -72,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] { ...@@ -72,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] {
} }
/**
 * An aggregator whose buffer type is `Seq[Int]` — collects the `a` field of every
 * input row into a sequence. Regression helper for SPARK-14675, where a `Seq`
 * buffer (implemented via `MapObjects`) triggered a `ClassFormatError` in codegen.
 */
object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
  // Empty buffer to start from.
  def zero: Seq[Int] = Nil
  // Prepend each input's `a` value onto the buffer.
  def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
  // Concatenate partial buffers.
  def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
  // The buffer is already the final result.
  def finish(r: Seq[Int]): Seq[Int] = r
  override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
  override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
}
class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
extends Aggregator[IN, OUT, OUT] { extends Aggregator[IN, OUT, OUT] {
...@@ -212,4 +223,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ...@@ -212,4 +223,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil) checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
} }
test("SPARK-14675: ClassFormatError when use Seq as Aggregator buffer type") {
  // A Seq buffer is deserialized via MapObjects; merging the left and right buffers
  // used to register the same codegen loop variable twice, crashing with
  // ClassFormatError. This verifies the aggregation now succeeds end-to-end.
  val input = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
  val aggregated = input.groupByKey(_.b).agg(SeqAgg.toColumn)
  checkDataset(aggregated, "a" -> Seq(1, 2))
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment