From cd48ca50129e8952f487051796244e7569275416 Mon Sep 17 00:00:00 2001
From: Michael Armbrust <michael@databricks.com>
Date: Tue, 31 Mar 2015 11:23:18 -0700
Subject: [PATCH] [SPARK-6145][SQL] fix ORDER BY on nested fields

This PR is based on work by cloud-fan in #4904, but with two differences:
 - We isolate the logic for Sort's special handling into `ResolveSortReferences`
 - We avoid creating UnresolvedGetField expressions during resolution. Instead we either return a resolved GetField or None, which keeps us from going down the wrong resolution path early on.

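As an illustration, the new test cases in SQLQuerySuite exercise ORDER BY on nested
struct fields and on array elements; queries such as the following now analyze and
return the expected results:

    SELECT 1   FROM nestedOrder ORDER BY a.b
    SELECT a.b FROM nestedOrder ORDER BY a.b
    SELECT 1   FROM nestedOrder ORDER BY a.a.a
    SELECT 1   FROM nestedOrder ORDER BY c[0].d
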
Author: Michael Armbrust <michael@databricks.com>

Closes #5189 from marmbrus/nestedOrderBy and squashes the following commits:

b8cae45 [Michael Armbrust] fix another test
0f36a11 [Michael Armbrust] WIP
91820cd [Michael Armbrust] Fix bug.
---
 .../sql/catalyst/analysis/Analyzer.scala      | 76 ++++++++++++++-----
 .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++-
 .../catalyst/expressions/AttributeSet.scala   |  2 +-
 .../catalyst/plans/logical/LogicalPlan.scala  | 76 +++++++++++++++----
 .../sql/catalyst/analysis/AnalysisSuite.scala | 39 +++++++++-
 .../org/apache/spark/sql/SQLContext.scala     | 14 ++--
 .../org/apache/spark/sql/SQLQuerySuite.scala  | 19 +++--
 .../spark/sql/sources/DataSourceTest.scala    |  4 +
 8 files changed, 185 insertions(+), 57 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dc14f49e6e..c578d084a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
  * [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and
  * a [[FunctionRegistry]].
  */
-class Analyzer(catalog: Catalog,
-               registry: FunctionRegistry,
-               caseSensitive: Boolean,
-               maxIterations: Int = 100)
-  extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
+class Analyzer(
+    catalog: Catalog,
+    registry: FunctionRegistry,
+    caseSensitive: Boolean,
+    maxIterations: Int = 100)
+  extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis {
 
   val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
 
@@ -354,19 +355,16 @@ class Analyzer(catalog: Catalog,
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case s @ Sort(ordering, global, p @ Project(projectList, child))
           if !s.resolved && p.resolved =>
-        val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
-        val resolved = unresolved.flatMap(child.resolve(_, resolver))
-        val requiredAttributes =
-          AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a }))
+        val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)
 
-        val missingInProject = requiredAttributes -- p.output
-        if (missingInProject.nonEmpty) {
+        // If the sort references attributes missing from the project, rewrite the plan; otherwise return the original.
+        if (missing.nonEmpty) {
           // Add missing attributes and then project them away after the sort.
-          Project(projectList.map(_.toAttribute),
-            Sort(ordering, global,
-              Project(projectList ++ missingInProject, child)))
+          Project(p.output,
+            Sort(resolvedOrdering, global,
+              Project(projectList ++ missing, child)))
         } else {
-          logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
+          logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
         }
       case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
@@ -378,18 +376,54 @@ class Analyzer(catalog: Catalog,
           grouping.collect { case ne: NamedExpression => ne.toAttribute }
         )
 
-        logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
-        val missingInAggs = resolved.filterNot(a.outputSet.contains)
-        logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
-        if (missingInAggs.nonEmpty) {
+        val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation)
+
+        if (missing.nonEmpty) {
           // Add missing grouping exprs and then project them away after the sort.
           Project(a.output,
-            Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child)))
+            Sort(resolvedOrdering, global,
+              Aggregate(grouping, aggs ++ missing, child)))
         } else {
           s // Nothing we can do here. Return original plan.
         }
     }
+
+    /**
+     * Given a child and a grandchild that are present beneath a sort operator, returns
+     * a resolved sort ordering and a list of attributes that are missing from the child
+     * but are present in the grandchild.
+     */
+    def resolveAndFindMissing(
+        ordering: Seq[SortOrder],
+        child: LogicalPlan,
+        grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+      // Find any attributes that remain unresolved in the sort.
+      val unresolved: Seq[String] =
+        ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+
+      // Create a map from each name to its resolved attribute, for those names that can be
+      // resolved prior to the projection.
+      val resolved: Map[String, NamedExpression] =
+        unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
+
+      // Construct a set that contains all of the attributes that we need to evaluate the
+      // ordering.
+      val requiredAttributes = AttributeSet(resolved.values)
+
+      // Figure out which ones are missing from the projection, so that we can add them and
+      // remove them after the sort.
+      val missingInProject = requiredAttributes -- child.output
+
+      // Now that we have all the attributes we need, reconstruct a resolved ordering.
+      // It is important to do this here, instead of waiting for standard resolution, because
+      // adding attributes to the project below can introduce ambiguity that was not present
+      // before.
+      val resolvedOrdering = ordering.map(_ transform {
+        case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
+      }).asInstanceOf[Seq[SortOrder]]
+
+      (resolvedOrdering, missingInProject.toSeq)
+    }
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 40472a1cbb..fa02111385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.types._
 /**
  * Throws user facing errors when passed invalid queries that fail to analyze.
  */
-class CheckAnalysis {
+trait CheckAnalysis {
+  self: Analyzer =>
 
   /**
    * Override to provide additional checks for correct analysis.
@@ -33,17 +34,22 @@ class CheckAnalysis {
    */
   val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
 
-  def failAnalysis(msg: String): Nothing = {
+  protected def failAnalysis(msg: String): Nothing = {
     throw new AnalysisException(msg)
   }
 
-  def apply(plan: LogicalPlan): Unit = {
+  def checkAnalysis(plan: LogicalPlan): Unit = {
     // We transform up and order the rules so as to catch the first possible failure instead
     // of the result of cascading resolution failures.
     plan.foreachUp {
       case operator: LogicalPlan =>
         operator transformExpressionsUp {
           case a: Attribute if !a.resolved =>
+            if (operator.childrenResolved) {
+              // Throw more specific errors for problems resolving nested fields (GetField).
+              operator.resolveChildren(a.name, resolver, throwErrors = true)
+            }
+
             val from = operator.inputSet.map(_.name).mkString(", ")
             a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 11b4eb5c88..5345696570 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -34,7 +34,7 @@ object AttributeSet {
   def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
 
   /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
-  def apply(baseSet: Seq[Expression]): AttributeSet = {
+  def apply(baseSet: Iterable[Expression]): AttributeSet = {
     new AttributeSet(
       baseSet
         .flatMap(_.references)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index b01a61d7bf..2e9f3aa4ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
 
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -109,16 +110,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    * nodes of this LogicalPlan. The attribute is expressed as
    * as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
-    resolve(name, children.flatMap(_.output), resolver)
+  def resolveChildren(
+      name: String,
+      resolver: Resolver,
+      throwErrors: Boolean = false): Option[NamedExpression] =
+    resolve(name, children.flatMap(_.output), resolver, throwErrors)
 
   /**
    * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
    * LogicalPlan. The attribute is expressed as string in the following form:
    * `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
-    resolve(name, output, resolver)
+  def resolve(
+      name: String,
+      resolver: Resolver,
+      throwErrors: Boolean = false): Option[NamedExpression] =
+    resolve(name, output, resolver, throwErrors)
 
   /**
    * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
@@ -162,7 +169,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
   protected def resolve(
       name: String,
       input: Seq[Attribute],
-      resolver: Resolver): Option[NamedExpression] = {
+      resolver: Resolver,
+      throwErrors: Boolean): Option[NamedExpression] = {
 
     val parts = name.split("\\.")
 
@@ -196,14 +204,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
 
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
-        // The foldLeft adds UnresolvedGetField for every remaining parts of the name,
-        // and aliased it with the last part of the name.
-        // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
-        // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
-        // the final expression as "c".
-        val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
-        val aliasName = nestedFields.last
-        Some(Alias(fieldExprs, aliasName)())
+        try {
+
+          // The foldLeft adds a resolved GetField for every remaining part of the name,
+          // and aliases the final expression with the last part of the name.
+          // For example, consider the name "a.b.c", where "a" is resolved to an existing
+          // attribute. Then this will add GetField("b") and GetField("c"), and alias the
+          // final expression as "c".
+          val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver))
+          val aliasName = nestedFields.last
+          Some(Alias(fieldExprs, aliasName)())
+        } catch {
+          case a: AnalysisException if !throwErrors => None
+        }
 
       // No matches.
       case Seq() =>
@@ -212,11 +225,46 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
 
       // More than one match.
       case ambiguousReferences =>
-        val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
+        val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
         throw new AnalysisException(
           s"Reference '$name' is ambiguous, could be: $referenceNames.")
     }
   }
+
+  /**
+   * Returns the resolved `GetField`, and reports an error if the requested field does not
+   * exist or if more than one matching field is found.
+   *
+   * TODO: this code is duplicated from Analyzer and should be refactored to avoid this.
+   */
+  protected def resolveGetField(
+      expr: Expression,
+      fieldName: String,
+      resolver: Resolver): Expression = {
+    def findField(fields: Array[StructField]): Int = {
+      val checkField = (f: StructField) => resolver(f.name, fieldName)
+      val ordinal = fields.indexWhere(checkField)
+      if (ordinal == -1) {
+        throw new AnalysisException(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+        throw new AnalysisException(
+          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+      } else {
+        ordinal
+      }
+    }
+    expr.dataType match {
+      case StructType(fields) =>
+        val ordinal = findField(fields)
+        StructGetField(expr, fields(ordinal), ordinal)
+      case ArrayType(StructType(fields), containsNull) =>
+        val ordinal = findField(fields)
+        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
+      case otherType =>
+        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
+    }
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 756cd36f05..ee7b14c7a1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -40,14 +40,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
       override val extendedResolutionRules = EliminateSubQueries :: Nil
     }
 
-  val checkAnalysis = new CheckAnalysis
-
 
   def caseSensitiveAnalyze(plan: LogicalPlan) =
-    checkAnalysis(caseSensitiveAnalyzer(plan))
+    caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan))
 
   def caseInsensitiveAnalyze(plan: LogicalPlan) =
-    checkAnalysis(caseInsensitiveAnalyzer(plan))
+    caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan))
 
   val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
   val testRelation2 = LocalRelation(
@@ -57,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     AttributeReference("d", DecimalType.Unlimited)(),
     AttributeReference("e", ShortType)())
 
+  val nestedRelation = LocalRelation(
+    AttributeReference("top", StructType(
+      StructField("duplicateField", StringType) ::
+      StructField("duplicateField", StringType) ::
+      StructField("differentCase", StringType) ::
+      StructField("differentcase", StringType) :: Nil
+    ))())
+
+  val nestedRelation2 = LocalRelation(
+    AttributeReference("top", StructType(
+      StructField("aField", StringType) ::
+      StructField("bField", StringType) ::
+      StructField("cField", StringType) :: Nil
+    ))())
+
   before {
     caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
     caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -169,6 +182,24 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     "'b'" :: "group by" :: Nil
   )
 
+  errorTest(
+    "ambiguous field",
+    nestedRelation.select($"top.duplicateField"),
+    "Ambiguous reference to fields" :: "duplicateField" :: Nil,
+    caseSensitive = false)
+
+  errorTest(
+    "ambiguous field due to case insensitivity",
+    nestedRelation.select($"top.differentCase"),
+    "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
+    caseSensitive = false)
+
+  errorTest(
+    "missing field",
+    nestedRelation2.select($"top.c"),
+    "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
+    caseSensitive = false)
+
   case class UnresolvedTestPlan() extends LeafNode {
     override lazy val resolved = false
     override def output = Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b8100782ec..1794936a52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -120,6 +120,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
         ExtractPythonUdfs ::
         sources.PreInsertCastAndRename ::
         Nil
+
+      override val extendedCheckRules = Seq(
+        sources.PreWriteCheck(catalog)
+      )
     }
 
   @transient
@@ -1065,14 +1069,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
       Batch("Add exchange", Once, AddExchange(self)) :: Nil
   }
 
-  @transient
-  protected[sql] lazy val checkAnalysis = new CheckAnalysis {
-    override val extendedCheckRules = Seq(
-      sources.PreWriteCheck(catalog)
-    )
-  }
-
-
   protected[sql] def openSession(): SQLSession = {
     detachSession()
     val session = createSession()
@@ -1105,7 +1101,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   @DeveloperApi
   protected[sql] class QueryExecution(val logical: LogicalPlan) {
-    def assertAnalyzed(): Unit = checkAnalysis(analyzed)
+    def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed)
 
     lazy val analyzed: LogicalPlan = analyzer(logical)
     lazy val withCachedData: LogicalPlan = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a3c0076e16..87e7cf8c8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1084,10 +1084,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   test("SPARK-6145: ORDER BY test for nested fields") {
     jsonRDD(sparkContext.makeRDD(
       """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder")
-    // These should be successfully analyzed
-    sql("SELECT 1 FROM nestedOrder ORDER BY a.b").queryExecution.analyzed
-    sql("SELECT a.b FROM nestedOrder ORDER BY a.b").queryExecution.analyzed
-    sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").queryExecution.analyzed
-    sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").queryExecution.analyzed
+
+    checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1))
+    checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1))
+    checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1))
+    checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1))
+    checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1))
+    checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1))
+  }
+
+  test("SPARK-6145: special cases") {
+    jsonRDD(sparkContext.makeRDD(
+      """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+    checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
+    checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 91c6367371..33c6735596 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -32,6 +32,10 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
         override val extendedResolutionRules =
           PreInsertCastAndRename ::
           Nil
+
+        override val extendedCheckRules = Seq(
+          sources.PreWriteCheck(catalog)
+        )
       }
   }
 }
-- 
GitLab