Skip to content
Snippets Groups Projects
Commit 7f05b1fe authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-7067] [SQL] fix bug when use complex nested fields in ORDER BY

This PR is an improvement for https://github.com/apache/spark/pull/5189.

The resolution rule for ORDER BY is: first resolve based on what comes from the select clause and then fall back on its child only when this fails.

There are 2 steps. First, try to resolve `Sort` in `ResolveReferences` based on the select clause, ignoring exceptions. Second, try to resolve `Sort` in `ResolveSortReferences` and add the missing projection.

However, the way we resolve `SortOrder` is wrong. We only resolve `UnresolvedAttribute` and use the result to decide whether we can resolve `SortOrder`. But `UnresolvedAttribute` is only part of the `GetField` chain (which can be broken by `GetItem`), so we need to go through the whole chain to determine whether we can resolve `SortOrder`.

With this change, we can also avoid re-throwing the GetField exception in `CheckAnalysis`, which is a little ugly.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5659 from cloud-fan/order-by and squashes the following commits:

cfa79f8 [Wenchen Fan] update test
3245d28 [Wenchen Fan] minor improve
465ee07 [Wenchen Fan] address comment
1fc41a2 [Wenchen Fan] fix SPARK-7067
parent a411a40d
No related branches found
No related tags found
No related merge requests found
...@@ -336,9 +336,15 @@ class Analyzer( ...@@ -336,9 +336,15 @@ class Analyzer(
} }
j.copy(right = newRight) j.copy(right = newRight)
// When resolve `SortOrder`s in Sort based on child, don't report errors as
// we still have chance to resolve it based on grandchild
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
val newOrdering = resolveSortOrders(ordering, child, throws = false)
Sort(newOrdering, global, child)
case q: LogicalPlan => case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}") logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp { q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 && case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
resolver(nameParts(0), VirtualColumn.groupingIdName) && resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] => q.isInstanceOf[GroupingAnalytics] =>
...@@ -373,6 +379,26 @@ class Analyzer( ...@@ -373,6 +379,26 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty) exprs.exists(_.collect { case _: Star => true }.nonEmpty)
} }
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
// If throws == false or the desired attribute doesn't exist
// (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
// Else, throw exception.
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
plan.resolve(nameParts, resolver).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
newOrder.asInstanceOf[SortOrder]
} catch {
case a: AnalysisException if !throws => order
}
}
}
/** /**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original * clause. This rule detects such queries and adds the required attributes to the original
...@@ -383,13 +409,13 @@ class Analyzer( ...@@ -383,13 +409,13 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, global, p @ Project(projectList, child)) case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved => if !s.resolved && p.resolved =>
val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child) val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
// If this rule was not a no-op, return the transformed plan, otherwise return the original. // If this rule was not a no-op, return the transformed plan, otherwise return the original.
if (missing.nonEmpty) { if (missing.nonEmpty) {
// Add missing attributes and then project them away after the sort. // Add missing attributes and then project them away after the sort.
Project(p.output, Project(p.output,
Sort(resolvedOrdering, global, Sort(newOrdering, global,
Project(projectList ++ missing, child))) Project(projectList ++ missing, child)))
} else { } else {
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}") logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
...@@ -404,19 +430,19 @@ class Analyzer( ...@@ -404,19 +430,19 @@ class Analyzer(
) )
// Find sort attributes that are projected away so we can temporarily add them back in. // Find sort attributes that are projected away so we can temporarily add them back in.
val (resolvedOrdering, unresolved) = resolveAndFindMissing(ordering, a, groupingRelation) val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation)
// Find aggregate expressions and evaluate them early, since they can't be evaluated in a // Find aggregate expressions and evaluate them early, since they can't be evaluated in a
// Sort. // Sort.
val (withAggsRemoved, aliasedAggregateList) = resolvedOrdering.map { val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty => case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
val aliased = Alias(aggOrdering.child, "_aggOrdering")() val aliased = Alias(aggOrdering.child, "_aggOrdering")()
(aggOrdering.copy(child = aliased.toAttribute), aliased :: Nil) (aggOrdering.copy(child = aliased.toAttribute), Some(aliased))
case other => (other, Nil) case other => (other, None)
}.unzip }.unzip
val missing = unresolved ++ aliasedAggregateList.flatten val missing = missingAttr ++ aliasedAggregateList.flatten
if (missing.nonEmpty) { if (missing.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort. // Add missing grouping exprs and then project them away after the sort.
...@@ -429,40 +455,25 @@ class Analyzer( ...@@ -429,40 +455,25 @@ class Analyzer(
} }
/** /**
* Given a child and a grandchild that are present beneath a sort operator, returns * Given a child and a grandchild that are present beneath a sort operator, try to resolve
* a resolved sort ordering and a list of attributes that are missing from the child * the sort ordering and returns it with a list of attributes that are missing from the
* but are present in the grandchild. * child but are present in the grandchild.
*/ */
def resolveAndFindMissing( def resolveAndFindMissing(
ordering: Seq[SortOrder], ordering: Seq[SortOrder],
child: LogicalPlan, child: LogicalPlan,
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = { grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
// Find any attributes that remain unresolved in the sort. val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
val unresolved: Seq[Seq[String]] =
ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })
// Create a map from name, to resolved attributes, when the desired name can be found
// prior to the projection.
val resolved: Map[Seq[String], NamedExpression] =
unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
// Construct a set that contains all of the attributes that we need to evaluate the // Construct a set that contains all of the attributes that we need to evaluate the
// ordering. // ordering.
val requiredAttributes = AttributeSet(resolved.values) val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved))
// Figure out which ones are missing from the projection, so that we can add them and // Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort. // remove them after the sort.
val missingInProject = requiredAttributes -- child.output val missingInProject = requiredAttributes -- child.output
// It is important to return the new SortOrders here, instead of waiting for the standard
// Now that we have all the attributes we need, reconstruct a resolved ordering. // resolving process as adding attributes to the project below can actually introduce
// It is important to do it here, instead of waiting for the standard resolved as adding // ambiguity that was not present before.
// attributes to the project below can actually introduce ambiquity that was not present (newOrdering, missingInProject.toSeq)
// before.
val resolvedOrdering = ordering.map(_ transform {
case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
}).asInstanceOf[Seq[SortOrder]]
(resolvedOrdering, missingInProject.toSeq)
} }
} }
......
...@@ -51,14 +51,6 @@ trait CheckAnalysis { ...@@ -51,14 +51,6 @@ trait CheckAnalysis {
case operator: LogicalPlan => case operator: LogicalPlan =>
operator transformExpressionsUp { operator transformExpressionsUp {
case a: Attribute if !a.resolved => case a: Attribute if !a.resolved =>
if (operator.childrenResolved) {
a match {
case UnresolvedAttribute(nameParts) =>
// Throw errors for specific problems with get field.
operator.resolveChildren(nameParts, resolver, throwErrors = true)
}
}
val from = operator.inputSet.map(_.name).mkString(", ") val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
......
...@@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { ...@@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]] * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
* should return `false`). * should return `false`).
*/ */
lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved
override protected def statePrefix = if (!resolved) "'" else super.statePrefix override protected def statePrefix = if (!resolved) "'" else super.statePrefix
/** /**
* Returns true if all its children of this query plan have been resolved. * Returns true if all its children of this query plan have been resolved.
*/ */
def childrenResolved: Boolean = !children.exists(!_.resolved) def childrenResolved: Boolean = children.forall(_.resolved)
/** /**
* Returns true when the given logical plan will return the same results as this logical plan. * Returns true when the given logical plan will return the same results as this logical plan.
* *
* Since its likely undecideable to generally determine if two given plans will produce the same * Since its likely undecidable to generally determine if two given plans will produce the same
* results, it is okay for this function to return false, even if the results are actually * results, it is okay for this function to return false, even if the results are actually
* the same. Such behavior will not affect correctness, only the application of performance * the same. Such behavior will not affect correctness, only the application of performance
* enhancements like caching. However, it is not acceptable to return true if the results could * enhancements like caching. However, it is not acceptable to return true if the results could
...@@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { ...@@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/ */
def resolveChildren( def resolveChildren(
nameParts: Seq[String], nameParts: Seq[String],
resolver: Resolver, resolver: Resolver): Option[NamedExpression] =
throwErrors: Boolean = false): Option[NamedExpression] = resolve(nameParts, children.flatMap(_.output), resolver)
resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
/** /**
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
...@@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { ...@@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/ */
def resolve( def resolve(
nameParts: Seq[String], nameParts: Seq[String],
resolver: Resolver, resolver: Resolver): Option[NamedExpression] =
throwErrors: Boolean = false): Option[NamedExpression] = resolve(nameParts, output, resolver)
resolve(nameParts, output, resolver, throwErrors)
/** /**
* Given an attribute name, split it to name parts by dot, but * Given an attribute name, split it to name parts by dot, but
...@@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { ...@@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolveQuoted( def resolveQuoted(
name: String, name: String,
resolver: Resolver): Option[NamedExpression] = { resolver: Resolver): Option[NamedExpression] = {
resolve(parseAttributeName(name), resolver, true) resolve(parseAttributeName(name), output, resolver)
} }
/** /**
...@@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { ...@@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
protected def resolve( protected def resolve(
nameParts: Seq[String], nameParts: Seq[String],
input: Seq[Attribute], input: Seq[Attribute],
resolver: Resolver, resolver: Resolver): Option[NamedExpression] = {
throwErrors: Boolean): Option[NamedExpression] = {
// A sequence of possible candidate matches. // A sequence of possible candidate matches.
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list // Each candidate is a tuple. The first element is a resolved attribute, followed by a list
...@@ -254,19 +251,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { ...@@ -254,19 +251,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field. // One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) => case Seq((a, nestedFields)) =>
try { // The foldLeft adds ExtractValues for every remaining parts of the identifier,
// The foldLeft adds GetFields for every remaining parts of the identifier, // and aliases it with the last part of the identifier.
// and aliases it with the last part of the identifier. // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute. // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
// Then this will add GetField("c", GetField("b", a)), and alias // the final expression as "c".
// the final expression as "c". val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver))
ExtractValue(expr, Literal(fieldName), resolver)) val aliasName = nestedFields.last
val aliasName = nestedFields.last Some(Alias(fieldExprs, aliasName)())
Some(Alias(fieldExprs, aliasName)())
} catch {
case a: AnalysisException if !throwErrors => None
}
// No matches. // No matches.
case Seq() => case Seq() =>
......
...@@ -285,7 +285,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { ...@@ -285,7 +285,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param rule the function use to transform this nodes children * @param rule the function use to transform this nodes children
*/ */
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildrenUp(rule); val afterRuleOnChildren = transformChildrenUp(rule)
if (this fastEquals afterRuleOnChildren) { if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) { CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType]) rule.applyOrElse(this, identity[BaseType])
......
...@@ -1440,4 +1440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ...@@ -1440,4 +1440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t")) checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
} }
} }
test("SPARK-7067: order by queries for complex ExtractValue chain") {
withTempTable("t") {
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
"""{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
}
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment