diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 20b9351ca8d68220ef729a293a0b5260a4434ad4..877ab88d172facfef18848bbccf6897fe477f417 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -342,6 +342,15 @@ class SQLTests(ReusedPySparkTestCase): df = df.withColumn('b', udf(lambda x: 'x')(df.a)) self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')]) + def test_udf_in_filter_on_top_of_join(self): + # regression test for SPARK-18589 + from pyspark.sql.functions import udf + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a, b: a == b, BooleanType()) + df = left.crossJoin(right).filter(f("a", "b")) + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_without_arguments(self): self.spark.catalog.registerFunction("foo", lambda: "bar") [row] = self.spark.sql("SELECT foo()").collect() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3fcbb05372d87107dd8c335319583c3f601b90b8..ac56ff13fa5bf45b97b96a0fd17a8e225e2e54a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils object InterpretedPredicate { @@ -86,6 +85,18 @@ trait PredicateHelper { */ protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean = expr.references.subsetOf(plan.outputSet) + + /** + * Returns true iff `expr` could be evaluated as a condition within a join. 
+ */ + protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match { + case e: SubqueryExpression => + // a non-correlated subquery will be replaced by a literal + e.children.isEmpty + case a: AttributeReference => true + case e: Unevaluable => false + case e => e.children.forall(canEvaluateWithinJoin) + } } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index dfd66aac2dd44930b0b22291b29e33ec8965e285..06fcbcb4ae2b25c0443b0bea9668fe5a88c86c3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -892,7 +892,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newRight = rightFilterConditions. reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) val (newJoinConditions, others) = - commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e)) + commonFilterCondition.partition(canEvaluateWithinJoin) val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And) val join = Join(newLeft, newRight, joinType, newJoinCond) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 180ad2e0ad1fa52858217c38987451062c6a8c3a..bfe529e21e9ad915d3f1efa480110e9a6fca7859 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { : LogicalPlan = { assert(input.size >= 2) if (input.size == 2) { - val (joinConditions, others) = conditions.partition( - e => 
!SubqueryExpression.hasCorrelatedSubquery(e)) + val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin) val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1)) val innerJoinType = (leftJoinType, rightJoinType) match { case (Inner, Inner) => Inner @@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { val joinedRefs = left.outputSet ++ right.outputSet val (joinConditions, others) = conditions.partition( - e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e)) + e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e)) val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And)) // should not have reference to same logical plan