diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 2f4d68d17943db7ca92c83ced8afef957cf6d606..eaeb010b0e4fa38c0d2f0a6bba8592027c8db2d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -33,10 +33,9 @@ import org.apache.spark.util.collection.OpenHashMap
  * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
  * the given percentage(s) with value range in [0.0, 1.0].
  *
- * The operator is bound to the slower sort based aggregation path because the number of elements
- * and their partial order cannot be determined in advance. Therefore we have to store all the
- * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory
- * Errors.
+ * Because the number of elements and their partial order cannot be determined in advance,
+ * we have to store all the elements in memory, and too many elements can cause GC pauses and
+ * eventually OutOfMemory errors.
  *
  * @param child child expression that produce numeric column value with `child.eval(inputRow)`
  * @param percentageExpression Expression that represents a single percentage value or an array of
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index b176e2a128f43923e580e7b60de35f920fcf4fcf..411f058510ca7226202458593f775ea9671413be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
 import scala.collection.generic.Growable
 import scala.collection.mutable
 
@@ -27,14 +29,12 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
 /**
- * The Collect aggregate function collects all seen expression values into a list of values.
+ * A base class for collect_list and collect_set aggregate functions.
  *
- * The operator is bound to the slower sort based aggregation path because the number of
- * elements (and their memory usage) can not be determined in advance. This also means that the
- * collected elements are stored on heap, and that too many elements can cause GC pauses and
- * eventually Out of Memory Errors.
+ * We have to store all the collected elements in memory, and too many elements can cause GC
+ * pauses and eventually OutOfMemory errors.
  */
-abstract class Collect extends ImperativeAggregate {
+abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImperativeAggregate[T] {
 
   val child: Expression
 
@@ -44,40 +44,44 @@ abstract class Collect extends ImperativeAggregate {
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def supportsPartial: Boolean = false
-
-  override def aggBufferAttributes: Seq[AttributeReference] = Nil
-
-  override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
-
-  override def inputAggBufferAttributes: Seq[AttributeReference] = Nil
-
   // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the
   // actual order of input rows.
   override def deterministic: Boolean = false
 
-  protected[this] val buffer: Growable[Any] with Iterable[Any]
-
-  override def initialize(b: InternalRow): Unit = {
-    buffer.clear()
-  }
+  override def update(buffer: T, input: InternalRow): T = {
+    val value = child.eval(input)
 
-  override def update(b: InternalRow, input: InternalRow): Unit = {
     // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
     // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
-    val value = child.eval(input)
     if (value != null) {
       buffer += value
     }
+    buffer
   }
 
-  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
-    sys.error("Collect cannot be used in partial aggregations.")
+  override def merge(buffer: T, other: T): T = {
+    buffer ++= other
   }
 
-  override def eval(input: InternalRow): Any = {
+  override def eval(buffer: T): Any = {
     new GenericArrayData(buffer.toArray)
   }
+
+  private lazy val projection = UnsafeProjection.create(
+    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))
+  private lazy val row = new UnsafeRow(1)
+
+  override def serialize(obj: T): Array[Byte] = {
+    val array = new GenericArrayData(obj.toArray)
+    projection.apply(InternalRow.apply(array)).getBytes()
+  }
+
+  override def deserialize(bytes: Array[Byte]): T = {
+    val buffer = createAggregationBuffer()
+    row.pointTo(bytes, bytes.length)
+    row.getArray(0).foreach(child.dataType, (_, x: Any) => buffer += x)
+    buffer
+  }
 }
 
 /**
@@ -88,7 +92,7 @@ abstract class Collect extends ImperativeAggregate {
 case class CollectList(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -98,9 +102,9 @@ case class CollectList(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  override def prettyName: String = "collect_list"
+  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
 
-  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
+  override def prettyName: String = "collect_list"
 }
 
 /**
@@ -111,7 +115,7 @@ case class CollectList(
 case class CollectSet(
     child: Expression,
     mutableAggBufferOffset: Int = 0,
-    inputAggBufferOffset: Int = 0) extends Collect {
+    inputAggBufferOffset: Int = 0) extends Collect[mutable.HashSet[Any]] {
 
   def this(child: Expression) = this(child, 0, 0)
 
@@ -131,5 +135,5 @@ case class CollectSet(
 
   override def prettyName: String = "collect_set"
 
-  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty
+  override def createAggregationBuffer(): mutable.HashSet[Any] = mutable.HashSet.empty
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 8e63fba14ce541923b5efd46fb79406b81633bca..ccd4ae6c2d845c6899d9524a99682290ed0b554e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -458,7 +458,9 @@ abstract class DeclarativeAggregate
  * instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
  * buffer's storage format, which is not supported by hash based aggregation. Hash based
  * aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
- * fixed length and can be mutated in place in UnsafeRow)
+ * fixed length and can be mutated in place in UnsafeRow).
+ * NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
+ * hash based aggregation under some constraints.
  */
 abstract class TypedImperativeAggregate[T] extends ImperativeAggregate {
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 0b973c3b659cf3d74b02253a276673820ada2e18..5c1faaecdb548e627e4d3ba8ff9dd71d3c0982fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -59,15 +59,6 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     comparePlans(input, rewrite)
   }
 
-  test("single distinct group with non-partial aggregates") {
-    val input = testRelation
-      .groupBy('a, 'd)(
-        countDistinct('e, 'c).as('agg1),
-        CollectSet('b).toAggregateExpression().as('agg2))
-      .analyze
-    checkRewrite(RewriteDistinctAggregates(input))
-  }
-
   test("multiple distinct groups") {
     val input = testRelation
       .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
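
For context on the serialize/deserialize methods the patch adds to Collect: the aggregation buffer is packed into a single-column UnsafeRow holding an array, and unpacked by pointing an UnsafeRow at the serialized bytes. Below is a minimal standalone sketch of that round trip for a fixed element type; it is not part of the patch, the object name is hypothetical, and it assumes the Spark 2.1-era catalyst internals shown in the diff are on the classpath.

```scala
// Illustrative only, not part of the patch. Mirrors Collect.serialize/deserialize above.
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType}

// Hypothetical demo object, not part of Spark.
object CollectSerdeSketch {
  def main(args: Array[String]): Unit = {
    val elementType: DataType = IntegerType

    // Same setup the patch adds to Collect: a projection over one array-typed column.
    val projection = UnsafeProjection.create(
      Array[DataType](ArrayType(elementType = elementType, containsNull = false)))
    val row = new UnsafeRow(1)

    // serialize: wrap the buffered values in a GenericArrayData and project to bytes.
    val buffer = mutable.ArrayBuffer[Any](1, 2, 3)
    val array = new GenericArrayData(buffer.toArray)
    val bytes = projection.apply(InternalRow.apply(array)).getBytes()

    // deserialize: point an UnsafeRow at the bytes and copy the elements back out.
    val restored = mutable.ArrayBuffer.empty[Any]
    row.pointTo(bytes, bytes.length)
    row.getArray(0).foreach(elementType, (_, x: Any) => restored += x)

    println(restored) // ArrayBuffer(1, 2, 3)
  }
}
```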
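
As a user-facing illustration (also not part of the patch, and purely a hypothetical example): with serialize/deserialize in place, collect_list and collect_set no longer force the non-partial aggregation path, so an ordinary DataFrame aggregation like the one below can be partially aggregated per partition and merged, and, per the interfaces.scala note, can run through the new ObjectHashAggregateExec under some constraints.

```scala
// Hypothetical usage example, not part of the patch. Assumes a local SparkSession.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, collect_set}

object CollectUsageExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("collect-example").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", 1), ("a", 2), ("b", 2), ("b", 2)).toDF("key", "value")

    // collect_list keeps duplicates, collect_set drops them; both are backed by the
    // Collect subclasses changed in this patch.
    df.groupBy($"key")
      .agg(collect_list($"value").as("values"), collect_set($"value").as("distinct_values"))
      .show()

    spark.stop()
  }
}
```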