diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 0cd90866e14a2dcb29f33f92177e25eaf201160a..5ea9868e9e8466b1d1e74422c67cd15aeb67dea8 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -361,10 +361,10 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 }
 
 case class SumDistinct(child: Expression)
-  extends AggregateExpression with trees.UnaryNode[Expression] {
+  extends PartialAggregate with trees.UnaryNode[Expression] {
 
+  def this() = this(null)
   override def nullable = true
-
   override def dataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
@@ -373,10 +373,55 @@ case class SumDistinct(child: Expression)
     case _ =>
       child.dataType
   }
 
+  override def toString = s"SUM(DISTINCT ${child})"
+  override def newInstance() = new SumDistinctFunction(child, this)
+
+  override def asPartial = {
+    val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
+    SplitEvaluation(
+      CombineSetsAndSum(partialSet.toAttribute, this),
+      partialSet :: Nil)
+  }
+}
-  override def toString = s"SUM(DISTINCT $child)"
 
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+  def this() = this(null, null)
-  override def newInstance() = new SumDistinctFunction(child, this)
+
+  override def children = inputSet :: Nil
+  override def nullable = true
+  override def dataType = base.dataType
+  override def toString = s"CombineAndSum($inputSet)"
+  override def newInstance() = new CombineSetsAndSumFunction(inputSet, this)
+}
+
+case class CombineSetsAndSumFunction(
+    @transient inputSet: Expression,
+    @transient base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  val seen = new OpenHashSet[Any]()
+
+  override def update(input: Row): Unit = {
+    val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+    val inputIterator = inputSetEval.iterator
+    while (inputIterator.hasNext) {
+      seen.add(inputIterator.next)
+    }
+  }
+
+  override def eval(input: Row): Any = {
+    val casted = seen.asInstanceOf[OpenHashSet[Row]]
+    if (casted.size == 0) {
+      null
+    } else {
+      Cast(Literal(
+        casted.iterator.map(f => f.apply(0)).reduceLeft(
+          base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+        base.dataType).eval(null)
+    }
+  }
 }
 
 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 96f34302079826c076c295bf857f085db7380731..f57f31af1556696a3f90a061b4d279d356fe70ea 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -185,9 +185,14 @@ class SQLQuerySuite extends QueryTest {
       sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"),
       sql("SELECT 1 FROM src").collect().toSeq)
   }
-
-  test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") {
-    checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"),
-      sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq)
+
+  test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") {
+    checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"),
+      sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq)
+  }
+
+  test("SPARK-2554 SumDistinct partial aggregation") {
+    checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"),
+      sql("SELECT distinct key FROM src order by key").collect().toSeq)
   }
 }
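For context, a plain-Scala sketch (not part of the patch) of the two-phase evaluation the new asPartial path sets up: each map-side partition collects a set of distinct values (CollectHashSet), the reduce side unions those sets (CombineSetsAndSumFunction.update) and then sums the union exactly once (CombineSetsAndSumFunction.eval, which returns null when no rows were seen). SumDistinctSketch and its method names are illustrative only, not Spark APIs.

// Illustrative sketch of SUM(DISTINCT ...) as a partial aggregate; names are hypothetical.
object SumDistinctSketch {
  // Map side: each partition contributes its set of distinct values.
  def partial(partition: Seq[Int]): Set[Int] = partition.toSet

  // Reduce side: union the partial sets, then sum each value exactly once;
  // None models SQL NULL for empty input.
  def combine(partials: Seq[Set[Int]]): Option[Int] = {
    val seen = partials.foldLeft(Set.empty[Int])(_ union _)
    if (seen.isEmpty) None else Some(seen.sum)
  }

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 2), Seq(2, 3), Seq.empty)
    // SUM(DISTINCT) over {1, 2, 2, 2, 3}: 1 + 2 + 3 = 6
    println(combine(partitions.map(partial)))  // prints Some(6)
  }
}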