diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index d474853355e5baa17459876a981706ed584b4eb3..c0845e1a0102fdbee225015f71b3d65b73549c25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import scala.collection.Map
-
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 47b06cae154363d2b084c6aded3f0d08ef53ebc7..42457d5318b4870acd9b62f41462b26e884522a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -165,6 +165,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
  *
  *  - Inserting Projections beneath the following operators:
  *   - Aggregate
+ *   - Generate
  *   - Project <- Join
  *   - LeftSemiJoin
  */
@@ -178,6 +179,21 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = Project(a.references.toSeq, child))
 
+    // Eliminate attributes that are not needed to calculate the Generate.
+    case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
+      g.copy(child = Project(g.references.toSeq, g.child))
+
+    case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
+      p.copy(child = g.copy(join = false))
+
+    case p @ Project(projectList, g: Generate) if g.join =>
+      val neededChildOutput = p.references -- g.generatorOutput ++ g.references
+      if (neededChildOutput == g.child.outputSet) {
+        p
+      } else {
+        Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+      }
+
     case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
         if (a.outputSet -- p.references).nonEmpty =>
       Project(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..dbebcb86809de37f7ef73f49bcf0d6c048012ee2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.types.StringType
+
+class ColumnPruningSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Column pruning", FixedPoint(100),
+      ColumnPruning) :: Nil
+  }
+
+  test("Column pruning for Generate when Generate.join = false") {
+    val input = LocalRelation('a.int, 'b.array(StringType))
+
+    val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      Generate(Explode('b), false, false, None, 's.string :: Nil,
+        Project('b.attr :: Nil, input)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Column pruning for Generate when Generate.join = true") {
+    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
+
+    val query =
+      Project(Seq('a, 's),
+        Generate(Explode('c), true, false, None, 's.string :: Nil,
+          input)).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      Project(Seq('a, 's),
+        Generate(Explode('c), true, false, None, 's.string :: Nil,
+          Project(Seq('a, 'c),
+            input))).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Turn Generate.join to false if possible") {
+    val input = LocalRelation('b.array(StringType))
+
+    val query =
+      Project(('s + 1).as("s+1") :: Nil,
+        Generate(Explode('b), true, false, None, 's.string :: Nil,
+          input)).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer =
+      Project(('s + 1).as("s+1") :: Nil,
+        Generate(Explode('b), false, false, None, 's.string :: Nil,
+          input)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  // todo: add more tests for column pruning
+}
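
For illustration only (not part of the patch): a minimal end-user sketch of the kind of query the new Generate cases target, assuming a Spark 1.4/1.5-era SQLContext and the org.apache.spark.sql.functions.explode API. When only the generated column is selected, ColumnPruning can set Generate.join = false and push a Project of just the array column beneath the Generate, which is what the new suite verifies at the plan level. The object name ExplodePruningExample and the sample data are hypothetical.

// Hypothetical usage sketch; not taken from the patch above.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.explode

object ExplodePruningExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("explode-pruning").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // A relation with an unrelated column `a` and an array column `b`.
    val df = Seq((1, Seq("x", "y")), (2, Seq("z"))).toDF("a", "b")

    // Only the generated column `s` is selected, so after this patch the
    // optimizer is expected to drop Generate.join and read only `b` from the
    // child relation rather than all of its columns.
    val exploded = df.select(explode($"b").as("s"))
    exploded.explain(true)
    exploded.show()

    sc.stop()
  }
}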