diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9ab0a20a5276e8217e8f84161e455a7729e5e682..b654827b8a28ed3f6ff6b542b2176c68d01c801c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -421,7 +421,7 @@ class Analyzer(
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
           (oldVersion, oldVersion.copy(generatorOutput = newOutput))
 
-        case oldVersion @ Window(_, windowExpressions, _, _, child)
+        case oldVersion @ Window(windowExpressions, _, _, child)
             if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
               .nonEmpty =>
           (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
@@ -658,10 +658,6 @@ class Analyzer(
         case p: Project =>
           val missing = missingAttrs -- p.child.outputSet
           Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
-        case w: Window =>
-          val missing = missingAttrs -- w.child.outputSet
-          w.copy(projectList = w.projectList ++ missingAttrs,
-            child = addMissingAttr(w.child, missing))
         case a: Aggregate =>
           // all the missing attributes should be grouping expressions
           // TODO: push down AggregateExpression
@@ -1166,7 +1162,6 @@ class Analyzer(
         // Set currentChild to the newly created Window operator.
         currentChild =
           Window(
-            currentChild.output,
             windowExpressions,
             partitionSpec,
             orderSpec,
@@ -1436,10 +1431,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
       val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
       Aggregate(grouping.map(trimAliases), cleanedAggs, child)
 
-    case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+    case w @ Window(windowExprs, partitionSpec, orderSpec, child) =>
       val cleanedWindowExprs =
         windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
-      Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+      Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
 
     // Operators that operate on objects should only have expressions from encoders, which should
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 63463265e3971606db3fe56a4c5d3a9245374365..dc5264e2660d8356147c572ea2ee76f782ef2fb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -268,6 +268,12 @@ package object dsl {
       Aggregate(groupingExprs, aliasedExprs, logicalPlan)
     }
 
+    def window(
+        windowExpressions: Seq[NamedExpression],
+        partitionSpec: Seq[Expression],
+        orderSpec: Seq[SortOrder]): LogicalPlan =
+      Window(windowExpressions, partitionSpec, orderSpec, logicalPlan)
+
     def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan)
 
     def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 650b4eef6e4490672d98bafe49c26f86adc8affb..85776670e5c4e0596a84f88894dcdf438ae9d553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -315,21 +315,17 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
  * - LeftSemiJoin
  */
 object ColumnPruning extends Rule[LogicalPlan] {
-  def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
+  private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
     output1.size == output2.size &&
       output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+    // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
     case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
       p.copy(
         child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
-    case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
-      p.copy(child = w.copy(
-        projectList = w.projectList.filter(p.references.contains),
-        windowExpressions = w.windowExpressions.filter(p.references.contains)))
     case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
       val newOutput = e.output.filter(a.references.contains(_))
       val newProjects = e.projections.map { proj =>
@@ -343,11 +339,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
       mp.copy(child = prunedChild(child, mp.references))
 
-    // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
+    // Prunes the unused columns from child of Aggregate/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
-    case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
-      w.copy(child = prunedChild(child, w.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
@@ -381,6 +375,14 @@ object ColumnPruning extends Rule[LogicalPlan] {
         p
       }
 
+    // Prune unnecessary window expressions
+    case p @ Project(_, w: Window) if (w.windowOutputSet -- p.references).nonEmpty =>
+      p.copy(child = w.copy(
+        windowExpressions = w.windowExpressions.filter(p.references.contains)))
+
+    // Eliminate no-op Window
+    case w: Window if w.windowExpressions.isEmpty => w.child
+
     // Eliminate no-op Projects
     case p @ Project(projectList, child) if sameOutput(child.output, p.output) =>
       child
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 3bc246a32dec88063522cc605a3d32346058fb3a..09ea3fea6a69414e8c6f374692190f071f615d14 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -434,14 +434,15 @@ case class Aggregate(
 }
 
 case class Window(
-    projectList: Seq[Attribute],
     windowExpressions: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: LogicalPlan) extends UnaryNode {
 
   override def output: Seq[Attribute] =
-    projectList ++ windowExpressions.map(_.toAttribute)
+    child.output ++ windowExpressions.map(_.toAttribute)
+
+  def windowOutputSet: AttributeSet = AttributeSet(windowExpressions.map(_.toAttribute))
 }
 
 private[sql] object Expand {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 409e92238e29f2347acfd8d1f4d524efa0a4c334..dd7d65ddc9e96a91a522a31f181b68635552285b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -33,7 +34,8 @@ class ColumnPruningSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Column pruning", FixedPoint(100),
-      ColumnPruning) :: Nil
+      ColumnPruning,
+      CollapseProject) :: Nil
   }
 
@@ -258,6 +260,68 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
   }
 
+  test("Column pruning on Window with useless aggregate functions") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.groupBy('a, 'c, 'd)('a, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).select('a, 'c)
+
+    val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning on Window with selected agg expressions") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.select('a, 'b, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).where('window > 1).select('a, 'c)
+
+    val correctAnswer =
+      input.select('a, 'b, 'c)
+        .window(WindowExpression(
+            AggregateExpression(Count('b), Complete, isDistinct = false),
+            WindowSpecDefinition('a :: Nil,
+              SortOrder('b, Ascending) :: Nil,
+              UnspecifiedFrame)).as('window) :: Nil,
+          'a :: Nil, 'b.asc :: Nil)
+        .select('a, 'c, 'window).where('window > 1).select('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning on Window in select") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int)
+
+    val originalQuery =
+      input.select('a, 'b, 'c, 'd,
+        WindowExpression(
+          AggregateExpression(Count('b), Complete, isDistinct = false),
+          WindowSpecDefinition('a :: Nil,
+            SortOrder('b, Ascending) :: Nil,
+            UnspecifiedFrame)).as('window)).select('a, 'c)
+
+    val correctAnswer = input.select('a, 'c).analyze
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("Column pruning on Union") {
     val input1 = LocalRelation('a.int, 'b.string, 'c.double)
     val input2 = LocalRelation('c.int, 'd.string, 'e.double)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index debd04aa95b9c5bf57f3fc4bd9933e36242b4568..bae075078808898d76c60ee0e6d29b48a3149b6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -344,9 +344,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
-        execution.Window(
-          projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
+      case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
+        execution.Window(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
         execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 84154a47de393b59c0603725c94b129bdedaa922..a4c0e1c9fba4136c38c45c79e4c0a96d90d98ba7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -81,14 +81,14 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf
  * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]].
  */
 case class Window(
-    projectList: Seq[Attribute],
     windowExpression: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: SparkPlan)
   extends UnaryNode {
 
-  override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute)
+  override def output: Seq[Attribute] =
+    child.output ++ windowExpression.map(_.toAttribute)
 
   override def requiredChildDistribution: Seq[Distribution] = {
     if (partitionSpec.isEmpty) {
@@ -275,7 +275,7 @@ case class Window(
     val unboundToRefMap = expressions.zip(references).toMap
     val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
     UnsafeProjection.create(
-      projectList ++ patchedWindowExpression,
+      child.output ++ patchedWindowExpression,
       child.output)
   }
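
Note (illustration, not part of the patch): a minimal sketch of what the new pruning rules do, mirroring the "Column pruning on Window in select" test above. It assumes the patched catalyst classes and the Optimize batch defined in ColumnPruningSuite (ColumnPruning + CollapseProject); the relation and alias names are taken from the test suite, not from any public API.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val input = LocalRelation('a.int, 'b.string)

    // 'window is never referenced by the outer projection, so the new
    // "Prune unnecessary window expressions" case filters it out of
    // windowExpressions, and the now-empty Window node collapses into its
    // child via the "Eliminate no-op Window" case.
    val query = input.select('a, 'b,
      WindowExpression(
        AggregateExpression(Count('b), Complete, isDistinct = false),
        WindowSpecDefinition('a :: Nil, SortOrder('b, Ascending) :: Nil, UnspecifiedFrame)
      ).as('window)).select('a)

    // Expected (cf. the test above): Optimize.execute(query.analyze)
    // is plan-equivalent to input.select('a).analyze

This only became expressible once Window stopped carrying its own projectList: with output defined as child.output ++ windowExpressions.map(_.toAttribute), pruning a window expression can no longer drop a pass-through column by accident.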