From 31a229aa739b6d05ec6d91b820fcca79b6b7d6fe Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Tue, 15 Sep 2015 13:36:52 -0700
Subject: [PATCH] [SPARK-10475] [SQL] improve column prunning for Project on
 Sort

Sometimes we can't push down the whole `Project` though `Sort`, but we still have a chance to push down part of it.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8644 from cloud-fan/column-prune.
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 19 +++++++++++++++----
 .../optimizer/ColumnPruningSuite.scala        | 11 +++++++++++
 2 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0f4caec745..648a65e7c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -228,10 +228,21 @@ object ColumnPruning extends Rule[LogicalPlan] {
     case Project(projectList, Limit(exp, child)) =>
       Limit(exp, Project(projectList, child))
 
-    // Push down project if possible when the child is sort
-    case p @ Project(projectList, s @ Sort(_, _, grandChild))
-      if s.references.subsetOf(p.outputSet) =>
-      s.copy(child = Project(projectList, grandChild))
+    // Push down project if possible when the child is sort.
+    case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
+      if (s.references.subsetOf(p.outputSet)) {
+        s.copy(child = Project(projectList, grandChild))
+      } else {
+        val neededReferences = s.references ++ p.references
+        if (neededReferences == grandChild.outputSet) {
+          // No column we can prune, return the original plan.
+          p
+        } else {
+          // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
+          val newProjectList = grandChild.output.filter(neededReferences.contains)
+          p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
+        }
+      }
 
     // Eliminate no-op Projects
     case Project(projectList, child) if child.output == projectList => child
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index dbebcb8680..4a1e7ceaf3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -80,5 +80,16 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("Column pruning for Project on Sort") {
+    val input = LocalRelation('a.int, 'b.string, 'c.double)
+
+    val query = input.orderBy('b.asc).select('a).analyze
+    val optimized = Optimize.execute(query)
+
+    val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   // todo: add more tests for column pruning
 }
-- 
GitLab