diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d0eb9c2c90bdf85d8a30a2c3a473edbe80db5390..1a5de15c61f8673e2603498ff3cafa2e8b812e8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -78,7 +78,7 @@ class Analyzer( ResolveAliases :: ExtractWindowExpressions :: GlobalAggregates :: - UnresolvedHavingClauseAttributes :: + ResolveAggregateFunctions :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -452,37 +452,6 @@ class Analyzer( logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") s // Nothing we can do here. Return original plan. } - case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child)) - if !s.resolved && a.resolved => - // A small hack to create an object that will allow us to resolve any references that - // refer to named expressions that are present in the grouping expressions. - val groupingRelation = LocalRelation( - grouping.collect { case ne: NamedExpression => ne.toAttribute } - ) - - // Find sort attributes that are projected away so we can temporarily add them back in. - val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation) - - // Find aggregate expressions and evaluate them early, since they can't be evaluated in a - // Sort. - val (withAggsRemoved, aliasedAggregateList) = newOrdering.map { - case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => - val aliased = Alias(aggOrdering.child, "_aggOrdering")() - (aggOrdering.copy(child = aliased.toAttribute), Some(aliased)) - - case other => (other, None) - }.unzip - - val missing = missingAttr ++ aliasedAggregateList.flatten - - if (missing.nonEmpty) { - // Add missing grouping exprs and then project them away after the sort. - Project(a.output, - Sort(withAggsRemoved, global, - Aggregate(grouping, aggs ++ missing, child))) - } else { - s // Nothing we can do here. Return original plan. - } } /** @@ -515,6 +484,7 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case q: LogicalPlan => q transformExpressions { + case u if !u.childrenResolved => u // Skip until children are resolved. case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { @@ -559,21 +529,79 @@ class Analyzer( } /** - * This rule finds expressions in HAVING clause filters that depend on - * unresolved attributes. It pushes these expressions down to the underlying - * aggregates and then projects them away above the filter. + * This rule finds aggregate expressions that are not in an aggregate operator. For example, + * those in a HAVING clause or ORDER BY clause. These expressions are pushed down to the + * underlying aggregate operator and then projected away after the original operator. */ - object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { + object ResolveAggregateFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) - if aggregate.resolved && containsAggregate(havingCondition) => - - val evaluatedCondition = Alias(havingCondition, "havingCondition")() - val aggExprsWithHaving = evaluatedCondition +: originalAggExprs + case filter @ Filter(havingCondition, + aggregate @ Aggregate(grouping, originalAggExprs, child)) + if aggregate.resolved && !filter.resolved => + + // Try resolving the condition of the filter as though it is in the aggregate clause + val aggregatedCondition = + Aggregate(grouping, Alias(havingCondition, "havingCondition")() :: Nil, child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } - Project(aggregate.output, - Filter(evaluatedCondition.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + case sort @ Sort(sortOrder, global, + aggregate @ Aggregate(grouping, originalAggExprs, child)) + if aggregate.resolved && !sort.resolved => + + // Try resolving the ordering as though it is in the aggregate clause. + try { + val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")()) + val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child) + val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions + + // Expressions that have an aggregate can be pushed down. + val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate) + + // Attribute references, that are missing from the order but are present in the grouping + // expressions can also be pushed down. + val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _) + val missingAttributes = requiredAttributes -- aggregate.outputSet + val validPushdownAttributes = + missingAttributes.filter(a => grouping.exists(a.semanticEquals)) + + // If resolution was successful and we see the ordering either has an aggregate in it or + // it is missing something that is projected away by the aggregate, add the ordering + // the original aggregate operator. + if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) { + val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map { + case (order, evaluated) => order.copy(child = evaluated.toAttribute) + } + val aggExprsWithOrdering: Seq[NamedExpression] = + resolvedAggregateOrdering ++ originalAggExprs + + Project(aggregate.output, + Sort(evaluatedOrderings, global, + aggregate.copy(aggregateExpressions = aggExprsWithOrdering))) + } else { + sort + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. + case ae: AnalysisException => sort + } } protected def containsAggregate(condition: Expression): Boolean = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 10f2902e5eef0b2c489cde74f02d06fbe911fb9a..b03a35132325d35bc2769faf1998da58fbd45f85 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -276,6 +276,11 @@ class HiveUDFSuite extends QueryTest { checkAnswer( sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) + + checkAnswer( + sql("SELECT testStringStringUDF(\"\", testStringStringUDF(\"hello\", s)) FROM stringTable"), + Seq(Row(" hello world"), Row(" hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset()