From 4d286c903b1f88ef175209156b72ccbc3b9e8ae7 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Fri, 20 Jan 2017 16:11:40 -0800
Subject: [PATCH] [SPARK-18589][SQL] Fix Python UDF accessing attributes from
 both sides of a join

PythonUDF is unevaluable, so it cannot be used inside a join condition. Currently the optimizer will push a PythonUDF that accesses attributes from both sides of a join into the join condition, and the query will then fail to plan.

This PR fixes the issue by checking whether an expression is evaluable before pushing it into the Join.
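
For example (a minimal sketch mirroring the regression test below; the SparkSession setup is assumed), a filter built from a Python UDF that references columns from both sides of a cross join would fail to plan before this change:

    from pyspark.sql import Row, SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import BooleanType

    spark = SparkSession.builder.getOrCreate()
    left = spark.createDataFrame([Row(a=1)])
    right = spark.createDataFrame([Row(b=1)])
    # The UDF references columns from both sides, so the optimizer used to
    # push it into the join condition, where the unevaluable PythonUDF
    # cannot run and planning fails.
    f = udf(lambda a, b: a == b, BooleanType())
    df = left.crossJoin(right).filter(f("a", "b"))
    df.collect()  # with this patch: [Row(a=1, b=1)]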

Add a regression test.

Author: Davies Liu <davies@databricks.com>

Closes #16581 from davies/pyudf_join.
---
 python/pyspark/sql/tests.py                         |  9 +++++++++
 .../spark/sql/catalyst/expressions/predicates.scala | 13 ++++++++++++-
 .../spark/sql/catalyst/optimizer/Optimizer.scala    |  2 +-
 .../apache/spark/sql/catalyst/optimizer/joins.scala |  5 ++---
 4 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 20b9351ca8..877ab88d17 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -342,6 +342,15 @@ class SQLTests(ReusedPySparkTestCase):
         df = df.withColumn('b', udf(lambda x: 'x')(df.a))
         self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
 
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3fcbb05372..ac56ff13fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -86,6 +85,18 @@ trait PredicateHelper {
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
+
+  /**
+   * Returns true iff `expr` can be evaluated as a condition within a join.
+   */
+  protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
+    case e: SubqueryExpression =>
+      // a non-correlated subquery will be replaced with a literal
+      e.children.isEmpty
+    case a: AttributeReference => true
+    case e: Unevaluable => false
+    case e => e.children.forall(canEvaluateWithinJoin)
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index dfd66aac2d..06fcbcb4ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -892,7 +892,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
+            commonFilterCondition.partition(canEvaluateWithinJoin)
           val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, joinType, newJoinCond)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 180ad2e0ad..bfe529e21e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      val (joinConditions, others) = conditions.partition(
-        e => !SubqueryExpression.hasCorrelatedSubquery(e))
+      val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
       val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
       val innerJoinType = (leftJoinType, rightJoinType) match {
         case (Inner, Inner) => Inner
@@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
 
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
       val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan
-- 
GitLab