From 3f6d28a5ca98cf7d20c2c029094350cc4f9545a0 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Fri, 17 Jul 2015 00:59:15 -0700
Subject: [PATCH] [SPARK-9102] [SQL] Improve project collapse with
 nondeterministic expressions

Currently we will stop project collapse when the lower projection has nondeterministic expressions. However it's overkill sometimes, we should be able to optimize `df.select(Rand(10)).select('a)` to `df.select('a)`

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7445 from cloud-fan/non-deterministic and squashes the following commits:

0deaef6 [Wenchen Fan] Improve project collapse with nondeterministic expressions
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 38 ++++++++++---------
 .../optimizer/ProjectCollapsingSuite.scala    | 26 +++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala | 10 ++---
 3 files changed, 51 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 2f94b457f4..d5beeec0ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -206,31 +206,33 @@ object ColumnPruning extends Rule[LogicalPlan] {
  */
 object ProjectCollapsing extends Rule[LogicalPlan] {
 
-  /** Returns true if any expression in projectList is non-deterministic. */
-  private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = {
-    projectList.exists(expr => expr.find(!_.deterministic).isDefined)
-  }
-
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    // We only collapse these two Projects if the child Project's expressions are all
-    // deterministic.
-    case Project(projectList1, Project(projectList2, child))
-         if !hasNondeterministic(projectList2) =>
+    case p @ Project(projectList1, Project(projectList2, child)) =>
       // Create a map of Aliases to their values from the child projection.
       // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
       val aliasMap = AttributeMap(projectList2.collect {
-        case a @ Alias(e, _) => (a.toAttribute, a)
+        case a: Alias => (a.toAttribute, a)
       })
 
-      // Substitute any attributes that are produced by the child projection, so that we safely
-      // eliminate it.
-      // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
-      // TODO: Fix TransformBase to avoid the cast below.
-      val substitutedProjection = projectList1.map(_.transform {
-        case a: Attribute if aliasMap.contains(a) => aliasMap(a)
-      }).asInstanceOf[Seq[NamedExpression]]
+      // We only collapse these two Projects if their overlapped expressions are all
+      // deterministic.
+      val hasNondeterministic = projectList1.flatMap(_.collect {
+        case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
+      }).exists(_.find(!_.deterministic).isDefined)
 
-      Project(substitutedProjection, child)
+      if (hasNondeterministic) {
+        p
+      } else {
+        // Substitute any attributes that are produced by the child projection, so that we safely
+        // eliminate it.
+        // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
+        // TODO: Fix TransformBase to avoid the cast below.
+        val substitutedProjection = projectList1.map(_.transform {
+          case a: Attribute => aliasMap.getOrElse(a, a)
+        }).asInstanceOf[Seq[NamedExpression]]
+
+        Project(substitutedProjection, child)
+      }
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
index 151654bffb..1aa89991cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala
@@ -70,4 +70,30 @@ class ProjectCollapsingSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("collapse two nondeterministic, independent projects into one") {
+    val query = testRelation
+      .select(Rand(10).as('rand))
+      .select(Rand(20).as('rand2))
+
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer = testRelation
+      .select(Rand(20).as('rand2)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("collapse one nondeterministic, one deterministic, independent projects into one") {
+    val query = testRelation
+      .select(Rand(10).as('rand), 'a)
+      .select(('a + 1).as('a_plus_1))
+
+    val optimized = Optimize.execute(query.analyze)
+
+    val correctAnswer = testRelation
+      .select(('a + 1).as('a_plus_1)).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 23244fd310..192cc0a6e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -745,8 +745,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   test("SPARK-8072: Better Exception for Duplicate Columns") {
     // only one duplicate column present
     val e = intercept[org.apache.spark.sql.AnalysisException] {
-      val df1 = Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1")
-                .write.format("parquet").save("temp")
+      Seq((1, 2, 3), (2, 3, 4), (3, 4, 5)).toDF("column1", "column2", "column1")
+        .write.format("parquet").save("temp")
     }
     assert(e.getMessage.contains("Duplicate column(s)"))
     assert(e.getMessage.contains("parquet"))
@@ -755,9 +755,9 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
 
     // multiple duplicate columns present
     val f = intercept[org.apache.spark.sql.AnalysisException] {
-      val df2 = Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7))
-                .toDF("column1", "column2", "column3", "column1", "column3")
-                .write.format("json").save("temp")
+      Seq((1, 2, 3, 4, 5), (2, 3, 4, 5, 6), (3, 4, 5, 6, 7))
+        .toDF("column1", "column2", "column3", "column1", "column3")
+        .write.format("json").save("temp")
     }
     assert(f.getMessage.contains("Duplicate column(s)"))
     assert(f.getMessage.contains("JSON"))
-- 
GitLab