Skip to content
Snippets Groups Projects
Commit fc483077 authored by Wenchen Fan, committed by Michael Armbrust
Browse files

[SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate

For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8548 from cloud-fan/support-order-by-non-attribute.
parent 56c4c172
No related branches found
No related tags found
No related merge requests found
......@@ -560,43 +560,47 @@ class Analyzer(
filter
}
case sort @ Sort(sortOrder, global,
aggregate @ Aggregate(grouping, originalAggExprs, child))
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
if aggregate.resolved && !sort.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
try {
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
// Expressions that have an aggregate can be pushed down.
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
// Attribute references, that are missing from the order but are present in the grouping
// expressions can also be pushed down.
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
val missingAttributes = requiredAttributes -- aggregate.outputSet
val validPushdownAttributes =
missingAttributes.filter(a => grouping.exists(a.semanticEquals))
// If resolution was successful and we see the ordering either has an aggregate in it or
// it is missing something that is projected away by the aggregate, add the ordering to
// the original aggregate operator.
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
}
val aggExprsWithOrdering: Seq[NamedExpression] =
resolvedAggregateOrdering ++ originalAggExprs
Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
} else {
sort
val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
val resolvedAliasedOrdering: Seq[Alias] =
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
// If we pass the analysis check, then the ordering expressions should only refer to
// aggregate expressions or grouping expressions, and it's safe to push them down to
// Aggregate.
checkAnalysis(resolvedAggregate)
val originalAggExprs = aggregate.aggregateExpressions.map(
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
// If the ordering expression is the same as an original aggregate expression, we don't
// need to push down this ordering expression and can reference the original aggregate
// expression instead.
val needsPushDown = ArrayBuffer.empty[NamedExpression]
val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
case (evaluated, order) =>
val index = originalAggExprs.indexWhere {
case Alias(child, _) => child semanticEquals evaluated.child
case other => other semanticEquals evaluated.child
}
if (index == -1) {
needsPushDown += evaluated
order.copy(child = evaluated.toAttribute)
} else {
order.copy(child = originalAggExprs(index).toAttribute)
}
}
Project(aggregate.output,
Sort(evaluatedOrderings, global,
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
......@@ -605,9 +609,7 @@ class Analyzer(
}
/**
 * Returns true when searching `condition` turns up at least one
 * `AggregateExpression`.
 *
 * NOTE(review): this listing looks like a diff hunk with its +/- markers
 * stripped. The `collect ... nonEmpty` chain appears to be the removed
 * implementation and the `find ... isDefined` line its replacement; confirm
 * against the commit before treating both as live code.
 *
 * @param condition the expression to inspect
 * @return whether an `AggregateExpression` was found in `condition`
 */
protected def containsAggregate(condition: Expression): Boolean = {
// Gathers every AggregateExpression match and checks that the result is non-empty.
condition
.collect { case ae: AggregateExpression => ae }
.nonEmpty
// Asks only for the first AggregateExpression match; this is the value returned.
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
......
......@@ -1722,9 +1722,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
// Regression test for SPARK-10130: IF(a > 0, a, 0) over a sub-query alias must
// evaluate to 1 for key = 1 and 0 for key = -1.
//
// NOTE(review): this listing looks like a diff hunk with its +/- markers
// stripped. The first four body lines (build `df`, register it, check the
// answer) appear to be the removed version, and the `withTempTable("src")`
// block the added one; confirm against the commit.
test("SPARK-10130 type coercion for IF should have children resolved first") {
val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
df.registerTempTable("src")
checkAnswer(
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
// Same scenario, but scoped inside withTempTable("src") -- presumably so the
// temp table is cleaned up after the test; helper is not visible here, verify.
withTempTable("src") {
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
checkAnswer(
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
}
}
// Regression test for SPARK-10389: ORDER BY may use a grouping expression that
// is not a plain attribute and is not part of the SELECT list.
test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
// withTempTable presumably drops "src" when the block finishes -- helper is
// defined outside this listing, verify.
withTempTable("src") {
// Two rows with keys 1 and -1 and the same value 1, so grouping by `key + 1`
// yields two groups whose MAX(value) is 1 each.
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
// Ordering by exactly the grouping expression.
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
Seq(Row(1), Row(1)))
// Ordering by an expression built on top of the grouping expression.
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
Seq(Row(1), Row(1)))
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment