Skip to content
Snippets Groups Projects
Commit 21c562fa authored by Herman van Hovell's avatar Herman van Hovell Committed by Yin Huai
Browse files

[SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up (3)

This PR is a 2nd follow-up for [SPARK-9241](https://issues.apache.org/jira/browse/SPARK-9241). It contains the following improvements:
* Fix for a potential bug in distinct child expression and attribute alignment.
* Improved handling of duplicate distinct child expressions.
* Added test for distinct UDAF with multiple children.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #9566 from hvanhovell/SPARK-9241-followup-2.
parent 3121e781
No related branches found
No related tags found
No related merge requests found
......@@ -151,11 +151,12 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
}
// Setup unique distinct aggregate children.
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap
val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
// Setup expand & aggregate operators for distinct aggregate expressions.
val distinctAggChildAttrLookup = distinctAggChildAttrMap.toMap
val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map {
case ((group, expressions), i) =>
val id = Literal(i + 1)
......@@ -170,7 +171,7 @@ case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalP
val operators = expressions.map { e =>
val af = e.aggregateFunction
val naf = patchAggregateFunctionChildren(af) { x =>
evalWithinGroup(id, distinctAggChildAttrMap(x))
evalWithinGroup(id, distinctAggChildAttrLookup(x))
}
(e, e.copy(aggregateFunction = naf, isDistinct = false))
}
......
......@@ -66,6 +66,36 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
}
}
class LongProductSum extends UserDefinedAggregateFunction {
def inputSchema: StructType = new StructType()
.add("a", LongType)
.add("b", LongType)
def bufferSchema: StructType = new StructType()
.add("product", LongType)
def dataType: DataType = LongType
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!(input.isNullAt(0) || input.isNullAt(1))) {
buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1)
}
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
}
def evaluate(buffer: Row): Any =
buffer.getLong(0)
}
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
......@@ -110,6 +140,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
// Register UDAFs
sqlContext.udf.register("mydoublesum", new MyDoubleSum)
sqlContext.udf.register("mydoubleavg", new MyDoubleAvg)
sqlContext.udf.register("longProductSum", new LongProductSum)
}
override def afterAll(): Unit = {
......@@ -545,19 +576,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
| count(distinct value2),
| sum(distinct value2),
| count(distinct value1, value2),
| longProductSum(distinct value1, value2),
| count(value1),
| sum(value1),
| count(value2),
| sum(value2),
| longProductSum(value1, value2),
| count(*),
| count(1)
|FROM agg2
|GROUP BY key
""".stripMargin),
Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) ::
Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) ::
Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) ::
Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil)
Row(null, 3, 30, 3, 60, 3, -4700, 3, 30, 3, 60, -4700, 4, 4) ::
Row(1, 2, 40, 3, -10, 3, -100, 3, 70, 3, -10, -100, 3, 3) ::
Row(2, 2, 0, 1, 1, 1, 1, 3, 1, 3, 3, 2, 4, 4) ::
Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
}
test("test count") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment