Skip to content
Snippets Groups Projects
Commit c03299a1 authored by Yin Huai's avatar Yin Huai Committed by Reynold Xin
Browse files

[SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056] [SQL] Aggregation Improvement

This is the first PR for the aggregation improvement, which is tracked by https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR contains work for its subtasks, SPARK-3056, SPARK-3947, SPARK-4233, and SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This code path is guarded by `spark.sql.useAggregate2` and by default the value of this flag is true.

This new code path contains:
* A new aggregate function interface (`AggregateFunction2`) and 7 built-int aggregate functions based on this new interface (`AVG`, `COUNT`, `FIRST`, `LAST`, `MAX`, `MIN`, `SUM`)
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code path and two example UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate function interface .
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for distinct aggregations (for distinct aggregations the query plan will use `Aggregate2Sort` and `FinalAndCompleteAggregate2Sort` together).

With this change, `spark.sql.useAggregate2` is `true`, the flow of compiling an aggregation query is:
1. Our analyzer looks up functions and returns aggregate functions built based on the old aggregate function interface.
2. When our planner is compiling the physical plan, it tries try to convert all aggregate functions to the ones built based on the new interface. The planner will fallback to the old code path if any of the following two conditions is true:
* code-gen is disabled.
* there is any function that cannot be converted (right now, Hive UDAFs).
* the schema of grouping expressions contain any complex data type.
* There are multiple distinct columns.

Right now, the new code path handles a single distinct column in the query (you can have multiple aggregate functions using that distinct column). For a query having a aggregate function with DISTINCT and regular aggregate functions, the generated plan will do partial aggregations for those regular aggregate function.

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai <yhuai@databricks.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fallback to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing purpose.
dc96fd1 [Yin Huai] Many updates:
85c9c4b [Yin Huai] newline.
43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
c3614d7 [Yin Huai] Handle single distinct column.
68b8ee9 [Yin Huai] Support single distinct column set. WIP
3013579 [Yin Huai] Format.
d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in aggregate functions will be based on AlgebraicAggregate and we need to have another way to test it.
e243ca6 [Yin Huai] Add aggregation iterators.
a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum.
594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2.
380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places.
a19fea6 [Yin Huai] Add UDAF interface.
262d4c4 [Yin Huai] Make it compile.
b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
6edb5ac [Yin Huai] Format update.
70b169c [Yin Huai] Remove groupOrdering.
4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules.
d821a34 [Yin Huai] Cleanup.
32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
5b46d41 [Yin Huai] Bug fix.
aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions.
2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer.
1b490ed [Michael Armbrust] make hive test
8cfa6a9 [Michael Armbrust] add test
1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code gen for all places.
072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not attribute references.
f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into UDAF
39ee975 [Yin Huai] Code cleanup: Remove unnecesary AttributeReferences.
b7720ba [Yin Huai] Add an analysis rule to convert aggregate function to the new version.
5c00f3f [Michael Armbrust] First draft of codegen
6bbc6ba [Michael Armbrust] now with correct answers\!
f7996d0 [Michael Armbrust] Add AlgebraicAggregate
dded1c5 [Yin Huai] wip
parent f4785f5b
No related branches found
No related tags found
No related merge requests found
Showing
with 921 additions and 83 deletions
......@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
}
}
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
{ case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
| ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
lexical.normalizeKeyword(udfName) match {
case "sum" => SumDistinct(exprs.head)
case "count" => CountDistinct(exprs)
case name => UnresolvedFunction(name, exprs, isDistinct = true)
case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
}
}
......
......@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
......@@ -277,7 +278,7 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
......@@ -517,9 +518,26 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children) =>
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
registry.lookupFunction(name, children)
registry.lookupFunction(name, children) match {
// We get an aggregate function built based on AggregateFunction2 interface.
// So, we wrap it in AggregateExpression2.
case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
// Currently, our old aggregate function interface supports SUM(DISTINCT ...)
// and COUTN(DISTINCT ...).
case sumDistinct: SumDistinct => sumDistinct
case countDistinct: CountDistinct => countDistinct
// DISTINCT is not meaningful with Max and Min.
case max: Max if isDistinct => max
case min: Min if isDistinct => min
// For other aggregate functions, DISTINCT keyword is not supported for now.
// Once we converted to the new code path, we will allow using DISTINCT keyword.
case other if isDistinct =>
failAnalysis(s"$name does not support DISTINCT keyword.")
// If it does not have DISTINCT keyword, we will return it as is.
case other => other
}
}
}
}
......
......@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
......
......@@ -73,7 +73,10 @@ object UnresolvedAttribute {
def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name))
}
case class UnresolvedFunction(name: String, children: Seq[Expression])
case class UnresolvedFunction(
name: String,
children: Seq[Expression],
isDistinct: Boolean)
extends Expression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
......
......@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends LeafExpression with NamedExpression {
override def toString: String = s"input[$ordinal]"
override def toString: String = s"input[$ordinal, $dataType]"
override def eval(input: InternalRow): Any = input(ordinal)
......
......@@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] {
val primitive = ctx.freshName("primitive")
val ve = GeneratedExpressionCode("", isNull, primitive)
ve.code = genCode(ctx, ve)
ve
// Add `this` in the comment.
ve.copy(s"/* $this */\n" + ve.code)
}
/**
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
case class Average(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = resultType
// Expected input data type.
// TODO: Once we remove the old code path, we can use our analyzer to cast NullType
// to the default data type of the NumericType.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4)
case DecimalType.Unlimited => DecimalType.Unlimited
case _ => DoubleType
}
private val sumDataType = child.dataType match {
case _ @ DecimalType() => DecimalType.Unlimited
case _ => DoubleType
}
private val currentSum = AttributeReference("currentSum", sumDataType)()
private val currentCount = AttributeReference("currentCount", LongType)()
override val bufferAttributes = currentSum :: currentCount :: Nil
override val initialValues = Seq(
/* currentSum = */ Cast(Literal(0), sumDataType),
/* currentCount = */ Literal(0L)
)
override val updateExpressions = Seq(
/* currentSum = */
Add(
currentSum,
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
)
override val mergeExpressions = Seq(
/* currentSum = */ currentSum.left + currentSum.right,
/* currentCount = */ currentCount.left + currentCount.right
)
// If all input are nulls, currentCount will be 0 and we will get null after the division.
override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType)
}
case class Count(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = false
// Return data type.
override def dataType: DataType = LongType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val currentCount = AttributeReference("currentCount", LongType)()
override val bufferAttributes = currentCount :: Nil
override val initialValues = Seq(
/* currentCount = */ Literal(0L)
)
override val updateExpressions = Seq(
/* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
)
override val mergeExpressions = Seq(
/* currentCount = */ currentCount.left + currentCount.right
)
override val evaluateExpression = Cast(currentCount, LongType)
}
case class First(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// First is not a deterministic function.
override def deterministic: Boolean = false
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val first = AttributeReference("first", child.dataType)()
override val bufferAttributes = first :: Nil
override val initialValues = Seq(
/* first = */ Literal.create(null, child.dataType)
)
override val updateExpressions = Seq(
/* first = */ If(IsNull(first), child, first)
)
override val mergeExpressions = Seq(
/* first = */ If(IsNull(first.left), first.right, first.left)
)
override val evaluateExpression = first
}
case class Last(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Last is not a deterministic function.
override def deterministic: Boolean = false
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val last = AttributeReference("last", child.dataType)()
override val bufferAttributes = last :: Nil
override val initialValues = Seq(
/* last = */ Literal.create(null, child.dataType)
)
override val updateExpressions = Seq(
/* last = */ If(IsNull(child), last, child)
)
override val mergeExpressions = Seq(
/* last = */ If(IsNull(last.right), last.left, last.right)
)
override val evaluateExpression = last
}
case class Max(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val max = AttributeReference("max", child.dataType)()
override val bufferAttributes = max :: Nil
override val initialValues = Seq(
/* max = */ Literal.create(null, child.dataType)
)
override val updateExpressions = Seq(
/* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child))))
)
override val mergeExpressions = {
val greatest = Greatest(Seq(max.left, max.right))
Seq(
/* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest))
)
}
override val evaluateExpression = max
}
case class Min(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = child.dataType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
private val min = AttributeReference("min", child.dataType)()
override val bufferAttributes = min :: Nil
override val initialValues = Seq(
/* min = */ Literal.create(null, child.dataType)
)
override val updateExpressions = Seq(
/* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child))))
)
override val mergeExpressions = {
val least = Least(Seq(min.left, min.right))
Seq(
/* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least))
)
}
override val evaluateExpression = min
}
case class Sum(child: Expression) extends AlgebraicAggregate {
override def children: Seq[Expression] = child :: Nil
override def nullable: Boolean = true
// Return data type.
override def dataType: DataType = resultType
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType(precision + 4, scale + 4)
case DecimalType.Unlimited => DecimalType.Unlimited
case _ => child.dataType
}
private val sumDataType = child.dataType match {
case _ @ DecimalType() => DecimalType.Unlimited
case _ => child.dataType
}
private val currentSum = AttributeReference("currentSum", sumDataType)()
private val zero = Cast(Literal(0), sumDataType)
override val bufferAttributes = currentSum :: Nil
override val initialValues = Seq(
/* currentSum = */ Literal.create(null, sumDataType)
)
override val updateExpressions = Seq(
/* currentSum = */
Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
)
override val mergeExpressions = {
val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
Seq(
/* currentSum = */
Coalesce(Seq(add, currentSum.left))
)
}
override val evaluateExpression = Cast(currentSum, resultType)
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
/** The mode of an [[AggregateFunction1]]. */
private[sql] sealed trait AggregateMode
/**
* An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object Partial extends AggregateMode
/**
* An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object PartialMerge extends AggregateMode
/**
* An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function and the generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
*/
private[sql] case object Final extends AggregateMode
/**
* An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
*/
private[sql] case object Complete extends AggregateMode
/**
* A place holder expressions used in code-gen, it does not change the corresponding value
* in the row.
*/
private[sql] case object NoOp extends Expression with Unevaluable {
override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
throw new TreeNodeException(
this, s"No function to evaluate expression. type: ${this.nodeName}")
}
override def dataType: DataType = NullType
override def children: Seq[Expression] = Nil
}
/**
* A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
* @param aggregateFunction
* @param mode
* @param isDistinct
*/
private[sql] case class AggregateExpression2(
aggregateFunction: AggregateFunction2,
mode: AggregateMode,
isDistinct: Boolean) extends AggregateExpression {
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
override def nullable: Boolean = aggregateFunction.nullable
override def references: AttributeSet = {
val childReferemces = mode match {
case Partial | Complete => aggregateFunction.references.toSeq
case PartialMerge | Final => aggregateFunction.bufferAttributes
}
AttributeSet(childReferemces)
}
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
}
abstract class AggregateFunction2
extends Expression with ImplicitCastInputTypes {
self: Product =>
/** An aggregate function is not foldable. */
override def foldable: Boolean = false
/**
* The offset of this function's buffer in the underlying buffer shared with other functions.
*/
var bufferOffset: Int = 0
/** The schema of the aggregation buffer. */
def bufferSchema: StructType
/** Attributes of fields in bufferSchema. */
def bufferAttributes: Seq[AttributeReference]
/** Clones bufferAttributes. */
def cloneBufferAttributes: Seq[Attribute]
/**
* Initializes its aggregation buffer located in `buffer`.
* It will use bufferOffset to find the starting point of
* its buffer in the given `buffer` shared with other functions.
*/
def initialize(buffer: MutableRow): Unit
/**
* Updates its aggregation buffer located in `buffer` based on the given `input`.
* It will use bufferOffset to find the starting point of its buffer in the given `buffer`
* shared with other functions.
*/
def update(buffer: MutableRow, input: InternalRow): Unit
/**
* Updates its aggregation buffer located in `buffer1` by combining intermediate results
* in the current buffer and intermediate results from another buffer `buffer2`.
* It will use bufferOffset to find the starting point of its buffer in the given `buffer1`
* and `buffer2`.
*/
def merge(buffer1: MutableRow, buffer2: InternalRow): Unit
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
}
/**
* A helper class for aggregate functions that can be implemented in terms of catalyst expressions.
*/
abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable {
self: Product =>
val initialValues: Seq[Expression]
val updateExpressions: Seq[Expression]
val mergeExpressions: Seq[Expression]
val evaluateExpression: Expression
override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
/**
* A helper class for representing an attribute used in merging two
* aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`,
* we merge buffer values and then update bufferLeft. A [[RichAttribute]]
* of an [[AttributeReference]] `a` has two functions `left` and `right`,
* which represent `a` in `bufferLeft` and `bufferRight`, respectively.
* @param a
*/
implicit class RichAttribute(a: AttributeReference) {
/** Represents this attribute at the mutable buffer side. */
def left: AttributeReference = a
/** Represents this attribute at the input buffer side (the data value is read-only). */
def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a))
}
/** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */
override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes)
override def initialize(buffer: MutableRow): Unit = {
var i = 0
while (i < bufferAttributes.size) {
buffer(i + bufferOffset) = initialValues(i).eval()
i += 1
}
}
override def update(buffer: MutableRow, input: InternalRow): Unit = {
throw new UnsupportedOperationException(
"AlgebraicAggregate's update should not be called directly")
}
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
throw new UnsupportedOperationException(
"AlgebraicAggregate's merge should not be called directly")
}
override def eval(buffer: InternalRow): Any = {
throw new UnsupportedOperationException(
"AlgebraicAggregate's eval should not be called directly")
}
}
......@@ -27,7 +27,9 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
trait AggregateExpression extends Expression with Unevaluable {
trait AggregateExpression extends Expression with Unevaluable
trait AggregateExpression1 extends AggregateExpression {
/**
* Aggregate expressions should not be foldable.
......@@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable {
* Creates a new instance that can be used to compute this aggregate expression for a group
* of input rows/
*/
def newInstance(): AggregateFunction
def newInstance(): AggregateFunction1
}
/**
......@@ -54,10 +56,10 @@ case class SplitEvaluation(
partialEvaluations: Seq[NamedExpression])
/**
* An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
* An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples.
* These partial evaluations can then be combined to compute the actual answer.
*/
trait PartialAggregate extends AggregateExpression {
trait PartialAggregate1 extends AggregateExpression1 {
/**
* Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation.
......@@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression {
/**
* A specific implementation of an aggregate function. Used to wrap a generic
* [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
* [[AggregateExpression1]] with an algorithm that will be used to compute one specific result.
*/
abstract class AggregateFunction
extends LeafExpression with AggregateExpression with Serializable {
abstract class AggregateFunction1
extends LeafExpression with AggregateExpression1 with Serializable {
/** Base should return the generic aggregate expression that this function is computing */
val base: AggregateExpression
val base: AggregateExpression1
override def nullable: Boolean = base.nullable
override def dataType: DataType = base.dataType
......@@ -81,12 +83,12 @@ abstract class AggregateFunction
def update(input: InternalRow): Unit
// Do we really need this?
override def newInstance(): AggregateFunction = {
override def newInstance(): AggregateFunction1 = {
makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}
}
case class Min(child: Expression) extends UnaryExpression with PartialAggregate {
case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
......@@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForOrderingExpr(child.dataType, "function min")
}
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
......@@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMin.value
}
case class Max(child: Expression) extends UnaryExpression with PartialAggregate {
case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
......@@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForOrderingExpr(child.dataType, "function max")
}
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
......@@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMax.value
}
case class Count(child: Expression) extends UnaryExpression with PartialAggregate {
case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
......@@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat
override def newInstance(): CountFunction = new CountFunction(child, this)
}
case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var count: Long = _
......@@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = count
}
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 {
def this() = this(null)
override def children: Seq[Expression] = expressions
......@@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
case class CountDistinctFunction(
@transient expr: Seq[Expression],
@transient base: AggregateExpression)
extends AggregateFunction {
@transient base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -220,7 +222,7 @@ case class CountDistinctFunction(
override def eval(input: InternalRow): Any = seen.size.toLong
}
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = expressions
......@@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress
case class CollectHashSetFunction(
@transient expr: Seq[Expression],
@transient base: AggregateExpression)
extends AggregateFunction {
@transient base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -255,7 +257,7 @@ case class CollectHashSetFunction(
}
}
case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression {
case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = inputSet :: Nil
......@@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression
case class CombineSetsAndCountFunction(
@transient inputSet: Expression,
@transient base: AggregateExpression)
extends AggregateFunction {
@transient base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
}
case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends UnaryExpression with AggregateExpression {
extends UnaryExpression with AggregateExpression1 {
override def nullable: Boolean = false
override def dataType: DataType = HyperLogLogUDT
......@@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
case class ApproxCountDistinctPartitionFunction(
expr: Expression,
base: AggregateExpression,
base: AggregateExpression1,
relativeSD: Double)
extends AggregateFunction {
extends AggregateFunction1 {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
......@@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction(
}
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends UnaryExpression with AggregateExpression {
extends UnaryExpression with AggregateExpression1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
......@@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
case class ApproxCountDistinctMergeFunction(
expr: Expression,
base: AggregateExpression,
base: AggregateExpression1,
relativeSD: Double)
extends AggregateFunction {
extends AggregateFunction1 {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
......@@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction(
}
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends UnaryExpression with PartialAggregate {
extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = false
override def dataType: LongType.type = LongType
......@@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
}
case class Average(child: Expression) extends UnaryExpression with PartialAggregate {
case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def prettyName: String = "avg"
......@@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg
TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
case class AverageFunction(expr: Expression, base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
}
}
case class Sum(child: Expression) extends UnaryExpression with PartialAggregate {
case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
......@@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
}
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
private val calcType =
......@@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
* <-- null <-- no data
* null <-- null <-- no data
*/
case class CombineSum(child: Expression) extends AggregateExpression {
case class CombineSum(child: Expression) extends AggregateExpression1 {
def this() = this(null)
override def children: Seq[Expression] = child :: Nil
......@@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression {
override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
case class CombineSumFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
case class CombineSumFunction(expr: Expression, base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
}
}
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate {
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 {
def this() = this(null)
override def nullable: Boolean = true
......@@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg
TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
case class SumDistinctFunction(expr: Expression, base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}
case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 {
def this() = this(null, null)
override def children: Seq[Expression] = inputSet :: Nil
......@@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg
case class CombineSetsAndSumFunction(
@transient inputSet: Expression,
@transient base: AggregateExpression)
extends AggregateFunction {
@transient base: AggregateExpression1)
extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
......@@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction(
}
}
case class First(child: Expression) extends UnaryExpression with PartialAggregate {
case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"FIRST($child)"
......@@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat
override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var result: Any = null
......@@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = result
}
case class Last(child: Expression) extends UnaryExpression with PartialAggregate {
case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
......@@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate
override def newInstance(): LastFunction = new LastFunction(child, this)
}
case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
def this() = this(null, null) // Required for serialization.
var result: Any = null
......
......@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import scala.collection.mutable.ArrayBuffer
......@@ -38,15 +39,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
val ctx = newCodeGenContext()
val projectionCode = expressions.zipWithIndex.map { case (e, i) =>
val evaluationCode = e.gen(ctx)
evaluationCode.code +
s"""
if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
"""
val projectionCode = expressions.zipWithIndex.map {
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
evaluationCode.code +
s"""
if(${evaluationCode.isNull})
mutableRow.setNullAt($i);
else
${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)};
"""
}
// collect projections into blocks as function has 64kb codesize limit in JVM
val projectionBlocks = new ArrayBuffer[String]()
......
......@@ -129,10 +129,10 @@ object PartialAggregation {
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
// Collect all aggregate expressions.
val allAggregates =
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a})
// Collect all aggregate expressions that can be computed partially.
val partialAggregates =
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p})
// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
......
......@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
......
......@@ -402,6 +402,9 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2",
defaultValue = Some(true), doc = "<TODO>")
val USE_SQL_SERIALIZER2 = booleanConf(
"spark.sql.useSerializer2",
defaultValue = Some(true), isPublic = false)
......@@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED)
private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2)
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2)
private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD)
......
......@@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
val udf: UDFRegistration = new UDFRegistration(this)
@transient
val udaf: UDAFRegistration = new UDAFRegistration(this)
/**
* Returns true if the table is currently cached in-memory.
* @group cachemgmt
......@@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
DDLStrategy ::
TakeOrderedAndProject ::
HashAggregation ::
Aggregation ::
LeftSemiJoin ::
HashJoin ::
InMemoryScans ::
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression}
import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction}
class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
private val functionRegistry = sqlContext.functionRegistry
def register(
name: String,
func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
def builder(children: Seq[Expression]) = ScalaUDAF(children, func)
functionRegistry.registerFunction(name, builder)
func
}
}
......@@ -68,14 +68,14 @@ case class Aggregate(
* output.
*/
case class ComputedAggregate(
unbound: AggregateExpression,
aggregate: AggregateExpression,
unbound: AggregateExpression1,
aggregate: AggregateExpression1,
resultAttribute: AttributeReference)
/** A list of aggregates that need to be computed for each group. */
private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
agg.collect {
case a: AggregateExpression =>
case a: AggregateExpression1 =>
ComputedAggregate(
a,
BindReferences.bindReference(a, child.output),
......@@ -87,8 +87,8 @@ case class Aggregate(
private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
/** Creates a new aggregate buffer for a group. */
private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
val buffer = new Array[AggregateFunction](computedAggregates.length)
private[this] def newAggregateBuffer(): Array[AggregateFunction1] = {
val buffer = new Array[AggregateFunction1](computedAggregates.length)
var i = 0
while (i < computedAggregates.length) {
buffer(i) = computedAggregates(i).aggregate.newInstance()
......@@ -146,7 +146,7 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[InternalRow, Array[AggregateFunction]]
val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]]
val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output)
var currentRow: InternalRow = null
......
......@@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
def addSortIfNecessary(child: SparkPlan): SparkPlan = {
if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) {
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
if (rowOrdering.nonEmpty) {
// If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
} else {
child
}
} else {
child
}
......
......@@ -69,7 +69,7 @@ case class GeneratedAggregate(
protected override def doExecute(): RDD[InternalRow] = {
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
a.collect { case agg: AggregateExpression => agg}
a.collect { case agg: AggregateExpression1 => agg}
}
// If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite
......
......@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{SQLContext, Strategy, execution}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2}
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
......@@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
if canBeCodeGened(
allAggregates(partialComputation) ++
allAggregates(rewrittenAggregateExpressions)) &&
codegenEnabled =>
codegenEnabled &&
!canBeConvertedToNewAggregation(plan) =>
execution.GeneratedAggregate(
partial = false,
namedGroupingAttributes,
......@@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
rewrittenAggregateExpressions,
groupingExpressions,
partialComputation,
child) =>
child) if !canBeConvertedToNewAggregation(plan) =>
execution.Aggregate(
partial = false,
namedGroupingAttributes,
......@@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = {
aggregate.Utils.tryConvert(
plan,
sqlContext.conf.useSqlAggregate2,
sqlContext.conf.codegenEnabled).isDefined
}
def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
......@@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => true
}
def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] =
exprs.flatMap(_.collect { case a: AggregateExpression => a })
def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
exprs.flatMap(_.collect { case a: AggregateExpression1 => a })
}
/**
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
*/
object Aggregation extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case p: logical.Aggregate =>
val converted =
aggregate.Utils.tryConvert(
p,
sqlContext.conf.useSqlAggregate2,
sqlContext.conf.codegenEnabled)
converted match {
case None => Nil // Cannot convert to new aggregation code path.
case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) =>
// Extracts all distinct aggregate expressions from the resultExpressions.
val aggregateExpressions = resultExpressions.flatMap { expr =>
expr.collect {
case agg: AggregateExpression2 => agg
}
}.toSet.toSeq
// For those distinct aggregate expressions, we create a map from the
// aggregate function to the corresponding attribute of the function.
val aggregateFunctionMap = aggregateExpressions.map { agg =>
val aggregateFunction = agg.aggregateFunction
(aggregateFunction, agg.isDistinct) ->
Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
}.toMap
val (functionsWithDistinct, functionsWithoutDistinct) =
aggregateExpressions.partition(_.isDistinct)
if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) {
// This is a sanity check. We should not reach here when we have multiple distinct
// column sets (aggregate.NewAggregation will not match).
sys.error(
"Multiple distinct column sets are not supported by the new aggregation" +
"code path.")
}
val aggregateOperator =
if (functionsWithDistinct.isEmpty) {
aggregate.Utils.planAggregateWithoutDistinct(
groupingExpressions,
aggregateExpressions,
aggregateFunctionMap,
resultExpressions,
planLater(child))
} else {
aggregate.Utils.planAggregateWithOneDistinct(
groupingExpressions,
functionsWithDistinct,
functionsWithoutDistinct,
aggregateFunctionMap,
resultExpressions,
planLater(child))
}
aggregateOperator
}
case _ => Nil
}
}
object BroadcastNestedLoopJoin extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Join(left, right, joinType, condition) =>
......@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Filter(condition, planLater(child)) :: Nil
case e @ logical.Expand(_, _, _, child) =>
execution.Expand(e.projections, e.output, planLater(child)) :: Nil
case logical.Aggregate(group, agg, child) =>
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
case a @ logical.Aggregate(group, agg, child) => {
val useNewAggregation =
aggregate.Utils.tryConvert(
a,
sqlContext.conf.useSqlAggregate2,
sqlContext.conf.codegenEnabled).isDefined
if (useNewAggregation) {
// If this logical.Aggregate can be planned to use new aggregation code path
// (i.e. it can be planned by the Strategy Aggregation), we will not use the old
// aggregation code path.
Nil
} else {
execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
}
}
case logical.Window(projectList, windowExpressions, spec, child) =>
execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.aggregate
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
case class Aggregate2Sort(
requiredChildDistributionExpressions: Option[Seq[Expression]],
groupingExpressions: Seq[NamedExpression],
aggregateExpressions: Seq[AggregateExpression2],
aggregateAttributes: Seq[Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
override def canProcessUnsafeRows: Boolean = true
override def references: AttributeSet = {
val referencesInResults =
AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
AttributeSet(
groupingExpressions.flatMap(_.references) ++
aggregateExpressions.flatMap(_.references) ++
referencesInResults)
}
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
// TODO: We should not sort the input rows if they are just in reversed order.
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
}
override def outputOrdering: Seq[SortOrder] = {
// It is possible that the child.outputOrdering starts with the required
// ordering expressions (e.g. we require [a] as the sort expression and the
// child's outputOrdering is [a, b]). We can only guarantee the output rows
// are sorted by values of groupingExpressions.
groupingExpressions.map(SortOrder(_, Ascending))
}
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
if (aggregateExpressions.length == 0) {
new GroupingIterator(
groupingExpressions,
resultExpressions,
newMutableProjection,
child.output,
iter)
} else {
val aggregationIterator: SortAggregationIterator = {
aggregateExpressions.map(_.mode).distinct.toList match {
case Partial :: Nil =>
new PartialSortAggregationIterator(
groupingExpressions,
aggregateExpressions,
newMutableProjection,
child.output,
iter)
case PartialMerge :: Nil =>
new PartialMergeSortAggregationIterator(
groupingExpressions,
aggregateExpressions,
newMutableProjection,
child.output,
iter)
case Final :: Nil =>
new FinalSortAggregationIterator(
groupingExpressions,
aggregateExpressions,
aggregateAttributes,
resultExpressions,
newMutableProjection,
child.output,
iter)
case other =>
sys.error(
s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " +
s"modes $other in this operator.")
}
}
aggregationIterator
}
}
}
}
case class FinalAndCompleteAggregate2Sort(
previousGroupingExpressions: Seq[NamedExpression],
groupingExpressions: Seq[NamedExpression],
finalAggregateExpressions: Seq[AggregateExpression2],
finalAggregateAttributes: Seq[Attribute],
completeAggregateExpressions: Seq[AggregateExpression2],
completeAggregateAttributes: Seq[Attribute],
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
override def references: AttributeSet = {
val referencesInResults =
AttributeSet(resultExpressions.flatMap(_.references)) --
AttributeSet(finalAggregateExpressions) --
AttributeSet(completeAggregateExpressions)
AttributeSet(
groupingExpressions.flatMap(_.references) ++
finalAggregateExpressions.flatMap(_.references) ++
completeAggregateExpressions.flatMap(_.references) ++
referencesInResults)
}
override def requiredChildDistribution: List[Distribution] = {
if (groupingExpressions.isEmpty) {
AllTuples :: Nil
} else {
ClusteredDistribution(groupingExpressions) :: Nil
}
}
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
child.execute().mapPartitions { iter =>
new FinalAndCompleteSortAggregationIterator(
previousGroupingExpressions.length,
groupingExpressions,
finalAggregateExpressions,
finalAggregateAttributes,
completeAggregateExpressions,
completeAggregateAttributes,
resultExpressions,
newMutableProjection,
child.output,
iter)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment