diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 04639219a365065ee8fafb5da7a2d1d3d5daafd9..ea9bb3978691a4532b00ac23306ca69446fbfff7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -58,7 +58,6 @@ class Analyzer(catalog: Catalog,
       ResolveSortReferences ::
       NewRelationInstances ::
       ImplicitGenerate ::
-      StarExpansion ::
       ResolveFunctions ::
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
@@ -153,7 +152,35 @@ class Analyzer(catalog: Catalog,
    */
   object ResolveReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-      case q: LogicalPlan if q.childrenResolved =>
+      // Wait until children are resolved
+      case p: LogicalPlan if !p.childrenResolved => p
+
+      // If the projection list contains Stars, expand it.
+      case p @ Project(projectList, child) if containsStar(projectList) =>
+        Project(
+          projectList.flatMap {
+            case s: Star => s.expand(child.output, resolver)
+            case o => o :: Nil
+          },
+          child)
+      case t: ScriptTransformation if containsStar(t.input) =>
+        t.copy(
+          input = t.input.flatMap {
+            case s: Star => s.expand(t.child.output, resolver)
+            case o => o :: Nil
+          }
+        )
+
+      // If the aggregate function argument contains Stars, expand it.
+      case a: Aggregate if containsStar(a.aggregateExpressions) =>
+        a.copy(
+          aggregateExpressions = a.aggregateExpressions.flatMap {
+            case s: Star => s.expand(a.child.output, resolver)
+            case o => o :: Nil
+          }
+        )
+
+      case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressions {
           case u @ UnresolvedAttribute(name) =>
@@ -163,6 +190,12 @@ class Analyzer(catalog: Catalog,
             result
         }
     }
+
+    /**
+     * Returns true if `exprs` contains a [[Star]].
+     */
+    protected def containsStar(exprs: Seq[Expression]): Boolean =
+      exprs.collect { case _: Star => true }.nonEmpty
   }
 
   /**
@@ -277,45 +310,6 @@ class Analyzer(catalog: Catalog,
         Generate(g, join = false, outer = false, None, child)
     }
   }
-
-  /**
-   * Expands any references to [[Star]] (*) in project operators.
-   */
-  object StarExpansion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      // Wait until children are resolved
-      case p: LogicalPlan if !p.childrenResolved => p
-      // If the projection list contains Stars, expand it.
-      case p @ Project(projectList, child) if containsStar(projectList) =>
-        Project(
-          projectList.flatMap {
-            case s: Star => s.expand(child.output, resolver)
-            case o => o :: Nil
-          },
-          child)
-      case t: ScriptTransformation if containsStar(t.input) =>
-        t.copy(
-          input = t.input.flatMap {
-            case s: Star => s.expand(t.child.output, resolver)
-            case o => o :: Nil
-          }
-        )
-      // If the aggregate function argument contains Stars, expand it.
-      case a: Aggregate if containsStar(a.aggregateExpressions) =>
-        a.copy(
-          aggregateExpressions = a.aggregateExpressions.flatMap {
-            case s: Star => s.expand(a.child.output, resolver)
-            case o => o :: Nil
-          }
-        )
-    }
-
-    /**
-     * Returns true if `exprs` contains a [[Star]].
-     */
-    protected def containsStar(exprs: Seq[Expression]): Boolean =
-      exprs.collect { case _: Star => true }.nonEmpty
-  }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 33a3cba3d4c0e1e6c697afc3ff58c862848e33c9..82f2101d8ce1700e2c5a2cdd7a2755f350dfb92c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
-import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.types._
 
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+
 class AnalysisSuite extends FunSuite with BeforeAndAfter {
   val caseSensitiveCatalog = new SimpleCatalog(true)
   val caseInsensitiveCatalog = new SimpleCatalog(false)
@@ -46,6 +48,18 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
   }
 
+  test("union project *") {
+    // Nest 100 levels of `SELECT *` projections joined by UNION ALL; the plan must still
+    // resolve completely.
+    val plan = (1 to 100)
+      .map(_ => testRelation)
+      .fold[LogicalPlan](testRelation) { (a, b) =>
+        a.select(Star(None)).select('a).unionAll(b.select(Star(None)))
+      }
+
+    assert(caseInsensitiveAnalyze(plan).resolved)
+  }
+
   test("analyze project") {
     assert(
       caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===