Skip to content
Snippets Groups Projects
Commit b0dbaec4 authored by Wenchen Fan's avatar Wenchen Fan Committed by Michael Armbrust
Browse files

[SPARK-6489] [SQL] add column pruning for Generate

This PR takes over https://github.com/apache/spark/pull/5358

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8268 from cloud-fan/6489.
parent e0dd1309
No related branches found
No related tags found
No related merge requests found
......@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
......
......@@ -165,6 +165,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] {
*
* - Inserting Projections beneath the following operators:
* - Aggregate
* - Generate
* - Project <- Join
* - LeftSemiJoin
*/
......@@ -178,6 +179,21 @@ object ColumnPruning extends Rule[LogicalPlan] {
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
// Eliminate attributes that are not needed to calculate the Generate.
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
g.copy(child = Project(g.references.toSeq, g.child))
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
case p @ Project(projectList, g: Generate) if g.join =>
val neededChildOutput = p.references -- g.generatorOutput ++ g.references
if (neededChildOutput == g.child.outputSet) {
p
} else {
Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
}
case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
if (a.outputSet -- p.references).nonEmpty =>
Project(
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType
class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
ColumnPruning) :: Nil
}
test("Column pruning for Generate when Generate.join = false") {
val input = LocalRelation('a.int, 'b.array(StringType))
val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
val optimized = Optimize.execute(query)
val correctAnswer =
Generate(Explode('b), false, false, None, 's.string :: Nil,
Project('b.attr :: Nil, input)).analyze
comparePlans(optimized, correctAnswer)
}
test("Column pruning for Generate when Generate.join = true") {
val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
val query =
Project(Seq('a, 's),
Generate(Explode('c), true, false, None, 's.string :: Nil,
input)).analyze
val optimized = Optimize.execute(query)
val correctAnswer =
Project(Seq('a, 's),
Generate(Explode('c), true, false, None, 's.string :: Nil,
Project(Seq('a, 'c),
input))).analyze
comparePlans(optimized, correctAnswer)
}
test("Turn Generate.join to false if possible") {
val input = LocalRelation('b.array(StringType))
val query =
Project(('s + 1).as("s+1") :: Nil,
Generate(Explode('b), true, false, None, 's.string :: Nil,
input)).analyze
val optimized = Optimize.execute(query)
val correctAnswer =
Project(('s + 1).as("s+1") :: Nil,
Generate(Explode('b), false, false, None, 's.string :: Nil,
input)).analyze
comparePlans(optimized, correctAnswer)
}
// todo: add more tests for column pruning
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment