diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 73a5df65e0ab3344902091d0c81e868b0d0b8e5e..4bfe6e9eb32432506f2c9a3bc46e0e69297c25a8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -342,6 +342,17 @@ class SQLTests(ReusedPySparkTestCase):
         df = df.withColumn('b', udf(lambda x: 'x')(df.a))
         self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
 
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
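+        # Before the fix, this filter was pushed into the join condition,
+        # where a Python UDF cannot be evaluated.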
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3fcbb05372d87107dd8c335319583c3f601b90b8..ac56ff13fa5bf45b97b96a0fd17a8e225e2e54a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -86,6 +85,20 @@ trait PredicateHelper {
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
+
+  /**
+   * Returns true iff `expr` can be evaluated as a condition within a join.
+   */
+  protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
+    case e: SubqueryExpression =>
+      // A non-correlated subquery will be replaced with a literal
+      e.children.isEmpty
+    case a: AttributeReference => true
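+    // Unevaluable expressions (e.g. PythonUDF) have to be extracted and
+    // evaluated outside the join, so they cannot stay in a join condition.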
+    case e: Unevaluable => false
+    case e => e.children.forall(canEvaluateWithinJoin)
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 009c517ae4651a0434290599afed1ef9bc61ba24..20b3898f8a6fe207a87798d8ecc3ce980168b3f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -893,7 +893,9 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
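+          // Conditions that cannot be evaluated within the join (e.g. Python
+          // UDFs or correlated subqueries) stay in a Filter above the join.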
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
+            commonFilterCondition.partition(canEvaluateWithinJoin)
           val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, joinType, newJoinCond)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 180ad2e0ad1fa52858217c38987451062c6a8c3a..bfe529e21e9ad915d3f1efa480110e9a6fca7859 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -46,8 +46,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      val (joinConditions, others) = conditions.partition(
-        e => !SubqueryExpression.hasCorrelatedSubquery(e))
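+      // Conditions that cannot be evaluated within the join end up in `others`
+      // and are applied as a Filter on top of the join.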
+      val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
       val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
       val innerJoinType = (leftJoinType, rightJoinType) match {
         case (Inner, Inner) => Inner
@@ -75,7 +74,8 @@
 
       val joinedRefs = left.outputSet ++ right.outputSet
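+      // As above, only conditions evaluable within the join may become join conditions.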
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
       val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 81bea2fef8bd449f830de3c15a1d3a065e5e0e02..2a3d1cf0b298a8f7ef5648a74209f2dd7085b0cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In}
 import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.BooleanType
@@ -86,13 +86,13 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
   test("Python UDF refers to the attributes from more than one child") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
     val df2 = Seq(("Hello", 4)).toDF("c", "d")
-    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
-
-    val e = intercept[RuntimeException] {
-      joinDF.queryExecution.executedPlan
-    }.getMessage
-    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
-      .forall(e.contains))
+    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
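+    // With the fix, the UDF filter is no longer pushed into the join condition,
+    // so planning succeeds and produces a single BatchEvalPythonExec.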
+    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
+      case b: BatchEvalPythonExec => b
+    }
+    assert(qualifiedPlanNodes.size == 1)
   }
 }