diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 73a5df65e0ab3344902091d0c81e868b0d0b8e5e..4bfe6e9eb32432506f2c9a3bc46e0e69297c25a8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -342,6 +342,15 @@ class SQLTests(ReusedPySparkTestCase):
         df = df.withColumn('b', udf(lambda x: 'x')(df.a))
         self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
 
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3fcbb05372d87107dd8c335319583c3f601b90b8..ac56ff13fa5bf45b97b96a0fd17a8e225e2e54a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -86,6 +85,18 @@ trait PredicateHelper {
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
+
+  /**
+   * Returns true iff `expr` can be evaluated as a condition within a join.
+   */
+  protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
+    case e: SubqueryExpression =>
+      // A non-correlated subquery will be replaced by a literal.
+      e.children.isEmpty
+    case a: AttributeReference => true
+    case e: Unevaluable => false
+    case e => e.children.forall(canEvaluateWithinJoin)
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 009c517ae4651a0434290599afed1ef9bc61ba24..20b3898f8a6fe207a87798d8ecc3ce980168b3f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -893,7 +893,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
+            commonFilterCondition.partition(canEvaluateWithinJoin)
           val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, joinType, newJoinCond)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 180ad2e0ad1fa52858217c38987451062c6a8c3a..bfe529e21e9ad915d3f1efa480110e9a6fca7859 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      val (joinConditions, others) = conditions.partition(
-        e => !SubqueryExpression.hasCorrelatedSubquery(e))
+      val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
       val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
       val innerJoinType = (leftJoinType, rightJoinType) match {
         case (Inner, Inner) => Inner
@@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
 
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
       val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 81bea2fef8bd449f830de3c15a1d3a065e5e0e02..2a3d1cf0b298a8f7ef5648a74209f2dd7085b0cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In}
 import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.BooleanType
@@ -86,13 +86,11 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
   test("Python UDF refers to the attributes from more than one child") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
     val df2 = Seq(("Hello", 4)).toDF("c", "d")
-    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
-
-    val e = intercept[RuntimeException] {
-      joinDF.queryExecution.executedPlan
-    }.getMessage
-    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
-      .forall(e.contains))
+    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
+    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
+      case b: BatchEvalPythonExec => b
+    }
+    assert(qualifiedPlanNodes.size == 1)
   }
 }
 