From 489641117651d11806d2773b7ded7c163d0260e5 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Mon, 7 Mar 2016 10:32:34 -0800
Subject: [PATCH] [SPARK-13694][SQL] QueryPlan.expressions should always
 include all expressions

## What changes were proposed in this pull request?

It's weird that `QueryPlan.expressions` doesn't always contain all of the expressions in the plan. This PR marks `QueryPlan.expressions` as final to forbid subclasses from overriding it to exclude some expressions. Currently only `Generate` overrides it; we can use `producedAttributes` to fix the unresolved attribute problem for it instead.

Note that this PR doesn't fix the problem in #11497.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11532 from cloud-fan/generate.
---
 .../scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +-
 .../spark/sql/catalyst/plans/logical/basicOperators.scala     | 4 +---
 .../org/apache/spark/sql/catalyst/plans/logical/object.scala  | 2 --
 .../main/scala/org/apache/spark/sql/execution/Generate.scala  | 2 +-
 4 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0e0453b517..c62d5ead86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -194,7 +194,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
   }
 
   /** Returns all of the expressions present in this query plan operator. */
-  def expressions: Seq[Expression] = {
+  final def expressions: Seq[Expression] = {
     // Recursively find all expressions from a traversable.
     def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap {
       case e: Expression => e :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 522348735a..411594c951 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -89,9 +89,7 @@ case class Generate(
       generatorOutput.forall(_.resolved)
   }
 
-  // we don't want the gOutput to be taken as part of the expressions
-  // as that will cause exceptions like unresolved attributes etc.
-  override def expressions: Seq[Expression] = generator :: Nil
+  override def producedAttributes: AttributeSet = AttributeSet(generatorOutput)
 
   def output: Seq[Attribute] = {
     val qualified = qualifier.map(q =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 3f97662957..da7f81c785 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -208,8 +208,6 @@ case class CoGroup(
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode with ObjectOperator {
 
-  override def producedAttributes: AttributeSet = outputSet
-
   override def deserializers: Seq[(Expression, Seq[Attribute])] =
     // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve
     // the `keyDeserializer` based on either of them, here we pick the left one.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 6bc4649d43..9938d2169f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -58,7 +58,7 @@ case class Generate(
   private[sql] override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def expressions: Seq[Expression] = generator :: Nil
+  override def producedAttributes: AttributeSet = AttributeSet(output)
 
   val boundGenerator = BindReferences.bindReference(generator, child.output)
 
-- 
GitLab