diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 397eff05686b6e4a7e23ede2be5c7b9813c3f0bc..c0c960471a61add3d27185f2d74e57af7d989224 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP } // Setup unique distinct aggregate children. - val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq - val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap - val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) + val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2) // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { case ((group, expressions), i) => val id = Literal(i + 1) @@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP val operators = expressions.map { e => val af = e.aggregateFunction val naf = patchAggregateFunctionChildren(af) { x => - evalWithinGroup(id, distinctAggChildAttrMap(x)) + evalWithinGroup(id, distinctAggChildAttrLookup(x)) } (e, e.copy(aggregateFunction = naf, isDistinct = false)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 6bf2c53440baf5739cee4c3449ddbdd18a4ec640..8253921563b3acac704cdab4cc3942d3473a176b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun } } +class LongProductSum extends UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) +} + abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ @@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te // Register UDAFs sqlContext.udf.register("mydoublesum", new MyDoubleSum) sqlContext.udf.register("mydoubleavg", new MyDoubleAvg) + sqlContext.udf.register("longProductSum", new LongProductSum) } override def afterAll(): Unit = { @@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te | count(distinct value2), | sum(distinct value2), | count(distinct value1, value2), + | longProductSum(distinct value1, value2), | count(value1), | sum(value1), | count(value2), | sum(value2), + | longProductSum(value1, value2), | count(*), | count(1) |FROM agg2 |GROUP BY key """.stripMargin), - Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) :: - Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) :: - Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) :: - Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil) + Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil) } test("test count") {