From 5fd57955ef477347408f68eb1cb6ad1881fdb6e0 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Tue, 8 Sep 2015 12:05:41 -0700
Subject: [PATCH] [SPARK-10316] [SQL] respect nondeterministic expressions in
 PhysicalOperation

We did a lot of special handling for non-deterministic expressions in `Optimizer`. However, `PhysicalOperation` just collects all Projects and Filters and mess it up. We should respect the operators order caused by non-deterministic expressions in `PhysicalOperation`.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8486 from cloud-fan/fix.
---
 .../sql/catalyst/planning/patterns.scala      | 38 ++++---------------
 .../org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++
 2 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index e8abcd63f7..5353779951 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,35 +17,12 @@
 
 package org.apache.spark.sql.catalyst.planning
 
-import scala.annotation.tailrec
-
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 
-/**
- * A pattern that matches any number of filter operations on top of another relational operator.
- * Adjacent filter operators are collected and their conditions are broken up and returned as a
- * sequence of conjunctive predicates.
- *
- * @return A tuple containing a sequence of conjunctive predicates that should be used to filter the
- *         output and a relational operator.
- */
-object FilteredOperation extends PredicateHelper {
-  type ReturnType = (Seq[Expression], LogicalPlan)
-
-  def unapply(plan: LogicalPlan): Option[ReturnType] = Some(collectFilters(Nil, plan))
-
-  @tailrec
-  private def collectFilters(filters: Seq[Expression], plan: LogicalPlan): ReturnType = plan match {
-    case Filter(condition, child) =>
-      collectFilters(filters ++ splitConjunctivePredicates(condition), child)
-    case other => (filters, other)
-  }
-}
-
 /**
  * A pattern that matches any number of project or filter operations on top of another relational
  * operator.  All filter operators are collected and their conditions are broken up and returned
@@ -62,8 +39,9 @@ object PhysicalOperation extends PredicateHelper {
   }
 
   /**
-   * Collects projects and filters, in-lining/substituting aliases if necessary.  Here are two
-   * examples for alias in-lining/substitution.  Before:
+   * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary.
+   * Here are two examples for alias in-lining/substitution.
+   * Before:
    * {{{
    *   SELECT c1 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
    *   SELECT c1 AS c2 FROM (SELECT key AS c1 FROM t1) t2 WHERE c1 > 10
@@ -74,15 +52,15 @@ object PhysicalOperation extends PredicateHelper {
    *   SELECT key AS c2 FROM t1 WHERE key > 10
    * }}}
    */
-  def collectProjectsAndFilters(plan: LogicalPlan):
+  private def collectProjectsAndFilters(plan: LogicalPlan):
       (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, Map[Attribute, Expression]) =
     plan match {
-      case Project(fields, child) =>
+      case Project(fields, child) if fields.forall(_.deterministic) =>
         val (_, filters, other, aliases) = collectProjectsAndFilters(child)
         val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
         (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
 
-      case Filter(condition, child) =>
+      case Filter(condition, child) if condition.deterministic =>
         val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
         val substitutedCondition = substitute(aliases)(condition)
         (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
@@ -91,11 +69,11 @@ object PhysicalOperation extends PredicateHelper {
         (None, Nil, other, Map.empty)
     }
 
-  def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
+  private def collectAliases(fields: Seq[Expression]): Map[Attribute, Expression] = fields.collect {
     case a @ Alias(child, _) => a.toAttribute -> child
   }.toMap
 
-  def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
+  private def substitute(aliases: Map[Attribute, Expression])(expr: Expression): Expression = {
     expr.transform {
       case a @ Alias(ref: AttributeReference, name) =>
         aliases.get(ref).map(Alias(_, name)(a.exprId, a.qualifiers)).getOrElse(a)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b5b9f11785..dbed4fc247 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -22,6 +22,8 @@ import java.io.File
 import scala.language.postfixOps
 import scala.util.Random
 
+import org.scalatest.Matchers._
+
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -895,4 +897,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       .orderBy(sum('j))
     checkAnswer(query, Row(1, 2))
   }
+
+  test("SPARK-10316: respect non-deterministic expressions in PhysicalOperation") {
+    val input = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+      (1 to 10).map(i => s"""{"id": $i}""")))
+
+    val df = input.select($"id", rand(0).as('r))
+    df.as("a").join(df.filter($"r" < 0.5).as("b"), $"a.id" === $"b.id").collect().foreach { row =>
+      assert(row.getDouble(1) - row.getDouble(3) === 0.0 +- 0.001)
+    }
+  }
 }
-- 
GitLab