From 4e290522c2a6310636317c54589dc35c91d95486 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Tue, 12 May 2015 11:51:55 -0700
Subject: [PATCH] [SPARK-7276] [DATAFRAME] speed up DataFrame.select by
 collapsing Project

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5831 from cloud-fan/7276 and squashes the following commits:

ee4a1e1 [Wenchen Fan] fix rebase mistake
a3b565d [Wenchen Fan] refactor
99deb5d [Wenchen Fan] add test
f1f67ad [Wenchen Fan] fix 7276
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 40 +++++++++++--------
 .../optimizer/FilterPushdownSuite.scala       |  3 +-
 .../org/apache/spark/sql/DataFrame.scala      |  4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++
 4 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1ee5fb245f..b163707cc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -43,6 +43,7 @@ object DefaultOptimizer extends Optimizer {
       PushPredicateThroughJoin,
       PushPredicateThroughGenerate,
       ColumnPruning,
+      ProjectCollapsing,
       CombineLimits) ::
     Batch("ConstantFolding", FixedPoint(100),
       NullPropagation,
@@ -114,7 +115,7 @@ object UnionPushdown extends Rule[LogicalPlan] {
  *   - Aggregate
  *   - Project <- Join
  *   - LeftSemiJoin
- *  - Collapse adjacent projections, performing alias substitution.
+ *  - Performing alias substitution.
  */
 object ColumnPruning extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -152,7 +153,28 @@ object ColumnPruning extends Rule[LogicalPlan] {
 
       Join(left, prunedChild(right, allReferences), LeftSemi, condition)
 
-    // Combine adjacent Projects.
+    case Project(projectList, Limit(exp, child)) =>
+      Limit(exp, Project(projectList, child))
+
+    // Eliminate no-op Projects
+    case Project(projectList, child) if child.output == projectList => child
+  }
+
+  /** Applies a projection only when the child is producing unnecessary attributes */
+  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
+    if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
+      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+    } else {
+      c
+    }
+}
+
+/**
+ * Combines two adjacent [[Project]] operators into one, merging the
+ * expressions into one single expression.
+ */
+object ProjectCollapsing extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case Project(projectList1, Project(projectList2, child)) =>
       // Create a map of Aliases to their values from the child projection.
       // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
@@ -169,21 +191,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       }).asInstanceOf[Seq[NamedExpression]]
 
       Project(substitutedProjection, child)
-
-    case Project(projectList, Limit(exp, child)) =>
-      Limit(exp, Project(projectList, child))
-      
-    // Eliminate no-op Projects
-    case Project(projectList, child) if child.output == projectList => child
   }
-
-  /** Applies a projection only when the child is producing unnecessary attributes */
-  private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
-    if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
-      Project(allReferences.filter(c.outputSet.contains).toSeq, c)
-    } else {
-      c
-    }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 58d415d901..0c428f7231 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -38,7 +38,8 @@ class FilterPushdownSuite extends PlanTest {
         PushPredicateThroughProject,
         PushPredicateThroughJoin,
         PushPredicateThroughGenerate,
-        ColumnPruning) :: Nil
+        ColumnPruning,
+        ProjectCollapsing) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index f3107f7b51..1f85dac682 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -567,7 +567,9 @@ class DataFrame private[sql](
       case Column(expr: NamedExpression) => expr
       case Column(expr: Expression) => Alias(expr, expr.prettyString)()
     }
-    Project(namedExpressions.toSeq, logicalPlan)
+    // When user continuously call `select`, speed up analysis by collapsing `Project`
+    import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing
+    Project(namedExpressions.toSeq, ProjectCollapsing(logicalPlan))
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d58438e5d1..52aa1f6558 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -493,4 +493,16 @@ class DataFrameSuite extends QueryTest {
       testData.dropDuplicates(Seq("value2")),
       Seq(Row(2, 1, 2), Row(1, 1, 1)))
   }
+
+  test("SPARK-7276: Project collapse for continuous select") {
+    var df = testData
+    for (i <- 1 to 5) {
+      df = df.select($"*")
+    }
+
+    import org.apache.spark.sql.catalyst.plans.logical.Project
+    // make sure df have at most two Projects
+    val p = df.logicalPlan.asInstanceOf[Project].child.asInstanceOf[Project]
+    assert(!p.child.isInstanceOf[Project])
+  }
 }
-- 
GitLab