From 7687415c2578b5bdc79c9646c246e52da9a4dd4a Mon Sep 17 00:00:00 2001 From: ravipesala <ravindra.pesala@huawei.com> Date: Thu, 18 Dec 2014 20:19:10 -0800 Subject: [PATCH] [SPARK-2554][SQL] Supporting SumDistinct partial aggregation Adding support to the partial aggregation of SumDistinct Author: ravipesala <ravindra.pesala@huawei.com> Closes #3348 from ravipesala/SPARK-2554 and squashes the following commits: fd28e4d [ravipesala] Fixed review comments e60e67f [ravipesala] Fixed test cases and made it as nullable 32fe234 [ravipesala] Supporting SumDistinct partial aggregation Conflicts: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala --- .../sql/catalyst/expressions/aggregates.scala | 53 +++++++++++++++++-- .../sql/hive/execution/SQLQuerySuite.scala | 13 +++-- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 0cd90866e1..5ea9868e9e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -361,10 +361,10 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ } case class SumDistinct(child: Expression) - extends AggregateExpression with trees.UnaryNode[Expression] { + extends PartialAggregate with trees.UnaryNode[Expression] { + def this() = this(null) override def nullable = true - override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive @@ -373,10 +373,55 @@ case class SumDistinct(child: Expression) case _ => child.dataType } + override def toString = s"SUM(DISTINCT ${child})" + override def newInstance() = new SumDistinctFunction(child, this) + + override def asPartial = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } +} - override def toString = s"SUM(DISTINCT $child)" +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) - override def newInstance() = new SumDistinctFunction(child, this) + override def children = inputSet :: Nil + override def nullable = true + override def dataType = base.dataType + override def toString = s"CombineAndSum($inputSet)" + override def newInstance() = new CombineSetsAndSumFunction(inputSet, this) +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + override def update(input: Row): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } + + override def eval(input: Row): Any = { + val casted = seen.asInstanceOf[OpenHashSet[Row]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) + } + } } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 96f3430207..f57f31af15 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -185,9 +185,14 @@ class SQLQuerySuite extends QueryTest { sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"), sql("SELECT 1 FROM src").collect().toSeq) } - - test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { - checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), - sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + + test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { + checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), + sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq) + } + + test("SPARK-2554 SumDistinct partial aggregation") { + checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"), + sql("SELECT distinct key FROM src order by key").collect().toSeq) } } -- GitLab