diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 20b9351ca8d68220ef729a293a0b5260a4434ad4..877ab88d172facfef18848bbccf6897fe477f417 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -342,6 +342,15 @@ class SQLTests(ReusedPySparkTestCase):
         df = df.withColumn('b', udf(lambda x: 'x')(df.a))
         self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
 
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3fcbb05372d87107dd8c335319583c3f601b90b8..ac56ff13fa5bf45b97b96a0fd17a8e225e2e54a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -86,6 +85,18 @@ trait PredicateHelper {
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
+
+  /**
+   * Returns true iff `expr` could be evaluated as a condition within join.
+   */
+  protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
+    case e: SubqueryExpression =>
+      // non-correlated subquery will be replaced as literal
+      e.children.isEmpty
+    case a: AttributeReference => true
+    case e: Unevaluable => false
+    case e => e.children.forall(canEvaluateWithinJoin)
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index dfd66aac2dd44930b0b22291b29e33ec8965e285..06fcbcb4ae2b25c0443b0bea9668fe5a88c86c3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -892,7 +892,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
+            commonFilterCondition.partition(canEvaluateWithinJoin)
           val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, joinType, newJoinCond)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 180ad2e0ad1fa52858217c38987451062c6a8c3a..bfe529e21e9ad915d3f1efa480110e9a6fca7859 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      val (joinConditions, others) = conditions.partition(
-        e => !SubqueryExpression.hasCorrelatedSubquery(e))
+      val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
       val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
       val innerJoinType = (leftJoinType, rightJoinType) match {
         case (Inner, Inner) => Inner
@@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
 
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
       val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan