Skip to content
Snippets Groups Projects
Commit 26f38bb8 authored by Davies Liu's avatar Davies Liu Committed by Josh Rosen
Browse files

[SPARK-13351][SQL] fix column pruning on Expand

Currently, the columns in the projections of an Expand that are not used by the parent Aggregate are not pruned; this PR fixes that.

Author: Davies Liu <davies@databricks.com>

Closes #11225 from davies/fix_pruning_expand.
parent 78562535
No related branches found
No related tags found
No related merge requests found
......@@ -300,6 +300,16 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// An Aggregate sitting directly on an Expand whose output contains attributes the Aggregate
// never references: rebuild the Expand with only the referenced output attributes, and drop
// the expression at the same position from every projection.
case agg @ Aggregate(_, _, expand @ Expand(projects, output, child))
    if (expand.outputSet -- agg.references).nonEmpty =>
  val referenced = agg.references
  val keptOutput = output.filter(attr => referenced.contains(attr))
  val keptProjects = projects.map { projection =>
    // Each projection is paired positionally with `output`; keep only the expressions
    // whose corresponding output attribute survived the filter above.
    projection.zip(output).collect {
      case (expr, attr) if keptOutput.contains(attr) => expr
    }
  }
  agg.copy(child = Expand(keptProjects, keptOutput, child))
// An Aggregate over an Expand whose child outputs attributes that neither the Expand nor the
// Aggregate references: replace the child with the result of `prunedChild`, passing the union
// of both operators' references.
// NOTE(review): `prunedChild` is defined outside this excerpt; presumably it wraps the child
// in a Project restricted to the given attributes -- confirm against its definition.
case agg @ Aggregate(_, _, expand @ Expand(_, _, child))
    if (child.outputSet -- expand.references -- agg.references).nonEmpty =>
  val required = expand.references ++ agg.references
  agg.copy(child = expand.copy(child = prunedChild(child, required)))
......
......@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.StringType
......@@ -96,5 +96,34 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
test("Column pruning for Expand") {
  // Relation with three columns. The Aggregate below sums `c` and groups by the two
  // Expand-produced attributes `aa` and `gid`; `a` is only read inside a projection.
  val relation = LocalRelation('a.int, 'b.string, 'c.double)

  // Plan under test: every projection still carries a, b and c through the Expand.
  val unprunedExpand =
    Expand(
      Seq(
        Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
        Seq('a, 'b, 'c, 'a, 2)),
      Seq('a, 'b, 'c, 'aa.int, 'gid.int),
      relation)
  val actualPlan = Optimize.execute(
    Aggregate(Seq('aa, 'gid), Seq(sum('c).as("sum")), unprunedExpand).analyze)

  // Expected plan: a and b are gone from the Expand output and from each projection, and
  // the relation is narrowed by a Project to the two columns that are still read.
  val prunedExpand =
    Expand(
      Seq(
        Seq('c, Literal.create(null, StringType), 1),
        Seq('c, 'a, 2)),
      Seq('c, 'aa.int, 'gid.int),
      Project(Seq('c, 'a), relation))
  val expectedPlan =
    Aggregate(Seq('aa, 'gid), Seq(sum('c).as("sum")), prunedExpand).analyze

  comparePlans(actualPlan, expectedPlan)
}
// todo: add more tests for column pruning
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment