From 341b13f8f5eb118f1fb4d4f84418715ac4750a4d Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@163.com>
Date: Thu, 24 Sep 2015 09:54:07 -0700
Subject: [PATCH] [SPARK-10765] [SQL] use new aggregate interface for hive UDAF

Author: Wenchen Fan <cloud0fan@163.com>

Closes #8874 from cloud-fan/hive-agg.
---
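Notes (not part of the commit message): with this change, Hive UDAFs backed by
either AbstractGenericUDAFResolver or the legacy UDAF interface resolve to the
single HiveUDAFFunction, which implements the new AggregateFunction2 interface
but reports supportsPartial = false, so the planner falls back to a single
Complete-mode SortBasedAggregate (see planAggregateWithoutPartial below). A
minimal way to exercise the new path, mirroring the updated HiveUDFSuite test
(sketch only; assumes a Hive test context with the `src` table available):

    sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
    // GenericUDAFAverage is an AbstractGenericUDAFResolver, so it is resolved to a
    // HiveUDAFFunction and planned without partial aggregation.
    sql("SELECT test_avg(1), test_avg(substr(value, 5)) FROM src").collect()
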
 .../expressions/aggregate/interfaces.scala    |   7 +-
 .../spark/sql/execution/SparkStrategies.scala |  14 +-
 .../aggregate/AggregationIterator.scala       |   2 +-
 .../spark/sql/execution/aggregate/utils.scala |  51 +++++++
 .../org/apache/spark/sql/hive/hiveUDFs.scala  | 139 +++++++-----------
 .../sql/hive/execution/HiveUDFSuite.scala     |   2 +-
 6 files changed, 129 insertions(+), 86 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 576d8c7a3a..d8699533cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -169,6 +168,12 @@ abstract class AggregateFunction2
 
   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
     throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+  /**
+   * Indicates if this function supports partial aggregation.
+   * Currently Hive UDAF is the only one that doesn't support partial aggregation.
+   */
+  def supportsPartial: Boolean = true
 }
 
 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 41b215c792..b078c8b6b0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -221,7 +221,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
 
         val aggregateOperator =
-          if (functionsWithDistinct.isEmpty) {
+          if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) {
+            if (functionsWithDistinct.nonEmpty) {
+              sys.error("Distinct columns cannot exist in an Aggregate operator containing " +
+                "aggregate functions which don't support partial aggregation.")
+            } else {
+              aggregate.Utils.planAggregateWithoutPartial(
+                groupingExpressions,
+                aggregateExpressions,
+                aggregateFunctionMap,
+                resultExpressions,
+                planLater(child))
+            }
+          } else if (functionsWithDistinct.isEmpty) {
             aggregate.Utils.planAggregateWithoutDistinct(
               groupingExpressions,
               aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index abca373b0c..62dbc07e88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -26,7 +26,7 @@ import org.apache.spark.unsafe.KVIterator
 import scala.collection.mutable.ArrayBuffer
 
 /**
- * The base class of [[SortBasedAggregationIterator]] and [[UnsafeHybridAggregationIterator]].
+ * The base class of [[SortBasedAggregationIterator]].
  * It mainly contains two parts:
  * 1. It initializes aggregate functions.
  * 2. It creates two functions, `processRow` and `generateOutput` based on [[AggregateMode]] of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 80816a095e..4f5e86cceb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -37,6 +37,57 @@ object Utils {
       UnsafeProjection.canSupport(groupingExpressions)
   }
 
+  def planAggregateWithoutPartial(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
+    val completeAggregateAttributes =
+      completeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2
+      }
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transformDown {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+
+    SortBasedAggregate(
+      requiredChildDistributionExpressions = Some(namedGroupingAttributes),
+      groupingExpressions = namedGroupingExpressions.map(_._2),
+      nonCompleteAggregateExpressions = Nil,
+      nonCompleteAggregateAttributes = Nil,
+      completeAggregateExpressions = completeAggregateExpressions,
+      completeAggregateAttributes = completeAggregateAttributes,
+      initialInputBufferOffset = 0,
+      resultExpressions = rewrittenResultExpressions,
+      child = child
+    ) :: Nil
+  }
+
   def planAggregateWithoutDistinct(
       groupingExpressions: Seq[Expression],
       aggregateExpressions: Seq[AggregateExpression2],
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index cad02373e5..fa9012b96e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -65,9 +66,10 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
       HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children)
     } else if (
       classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children)
+      HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveUDAF(new HiveFunctionWrapper(functionClassName), children)
+      HiveUDAFFunction(
+        new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
     } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
       HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children)
     } else {
@@ -441,70 +443,6 @@ private[hive] case class HiveWindowFunction(
     new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
 }
 
-private[hive] case class HiveGenericUDAF(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression1
-  with HiveInspectors {
-
-  type UDFType = AbstractGenericUDAFResolver
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this)
-}
-
-/** It is used as a wrapper for the hive functions which uses UDAF interface */
-private[hive] case class HiveUDAF(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression1
-  with HiveInspectors {
-
-  type UDFType = UDAF
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver =
-    new GenericUDAFBridge(funcWrapper.createFunction())
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true)
-}
-
 /**
  * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
  * [[Generator]]. Note that the semantics of Generators do not allow
@@ -584,49 +522,86 @@ private[hive] case class HiveGenericUDTF(
   }
 }
 
+/**
+ * Currently we don't support partial aggregation for queries using Hive UDAF, which may
+ * significantly hurt performance.
+ */
 private[hive] case class HiveUDAFFunction(
     funcWrapper: HiveFunctionWrapper,
-    exprs: Seq[Expression],
-    base: AggregateExpression1,
+    children: Seq[Expression],
     isUDAFBridgeRequired: Boolean = false)
-  extends AggregateFunction1
-  with HiveInspectors {
+  extends AggregateFunction2 with HiveInspectors {
 
-  def this() = this(null, null, null)
+  def this() = this(null, null)
 
-  private val resolver =
+  @transient
+  private lazy val resolver =
     if (isUDAFBridgeRequired) {
       new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
     } else {
       funcWrapper.createFunction[AbstractGenericUDAFResolver]()
     }
 
-  private val inspectors = exprs.map(toInspector).toArray
+  @transient
+  private lazy val inspectors = children.map(toInspector).toArray
 
-  private val function = {
+  @transient
+  private lazy val functionAndInspector = {
     val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false)
-    resolver.getEvaluator(parameterInfo)
+    val f = resolver.getEvaluator(parameterInfo)
+    f -> f.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
   }
 
-  private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
+  @transient
+  private lazy val function = functionAndInspector._1
+
+  @transient
+  private lazy val returnInspector = functionAndInspector._2
 
-  private val buffer =
-    function.getNewAggregationBuffer
+  @transient
+  private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
 
   override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector)
 
   @transient
-  val inputProjection = new InterpretedProjection(exprs)
+  private lazy val inputProjection = new InterpretedProjection(children)
 
   @transient
-  protected lazy val cached = new Array[AnyRef](exprs.length)
+  private lazy val cached = new Array[AnyRef](children.length)
 
   @transient
-  private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray
+  private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray
+
+  // Hive UDAF has its own buffer, so we don't need to occupy a slot in the aggregation
+  // buffer for it.
+  override def bufferSchema: StructType = StructType(Nil)
 
-  def update(input: InternalRow): Unit = {
+  override def update(_buffer: MutableRow, input: InternalRow): Unit = {
     val inputs = inputProjection(input)
     function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes))
   }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    throw new UnsupportedOperationException(
+      "Hive UDAF doesn't support partial aggregation")
+  }
+
+  override def cloneBufferAttributes: Seq[Attribute] = Nil
+
+  override def initialize(_buffer: MutableRow): Unit = {
+    buffer = function.getNewAggregationBuffer
+  }
+
+  override def bufferAttributes: Seq[AttributeReference] = Nil
+
+  // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
+  // Catalyst type checking framework.
+  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
+
+  override def nullable: Boolean = true
+
+  override def supportsPartial: Boolean = false
+
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index d9ba895e1e..3c8a0091c8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -131,7 +131,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton {
     hiveContext.setConf(SQLConf.CODEGEN_ENABLED, codegenDefault)
   }
 
-  test("SPARK-6409 UDAFAverage test") {
+  test("SPARK-6409 UDAF Average test") {
     sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
     checkAnswer(
       sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),
-- 
GitLab
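
P.S. A sketch (not part of the patch) of exercising the legacy-UDAF bridge path
end-to-end; assumes a Hive-enabled Spark build with the `src` test table loaded
and Hive's built-in UDAFPercentile on the classpath:

    import org.apache.spark.sql.hive.HiveContext

    val hiveContext = new HiveContext(sc)  // `sc` is an existing SparkContext
    // UDAFPercentile extends the legacy UDAF interface, so HiveFunctionRegistry
    // wraps it in HiveUDAFFunction(..., isUDAFBridgeRequired = true) via
    // GenericUDAFBridge.
    hiveContext.sql(
      "CREATE TEMPORARY FUNCTION hive_percentile AS " +
        "'org.apache.hadoop.hive.ql.udf.UDAFPercentile'")
    // supportsPartial = false, so the planner produces a single Complete-mode
    // SortBasedAggregate via planAggregateWithoutPartial.
    hiveContext.sql("SELECT hive_percentile(key, 0.5) FROM src").show()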