diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 0e3a8a6bd30a893da6997f86117944c5a88520b1..4544b32958c7e1d8f8de1d894db688460be7b361 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,11 +129,12 @@ object HashFilteredJoin extends Logging with PredicateHelper {
   // as join keys.
   def splitPredicates(allPredicates: Seq[Expression], join: Join): Option[ReturnType] = {
     val Join(left, right, joinType, _) = join
-    val (joinPredicates, otherPredicates) = allPredicates.partition {
-      case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
-        (canEvaluate(l, right) && canEvaluate(r, left)) => true
-      case _ => false
-    }
+    val (joinPredicates, otherPredicates) =
+      allPredicates.flatMap(splitConjunctivePredicates).partition {
+        case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
+          (canEvaluate(l, right) && canEvaluate(r, left)) => true
+        case _ => false
+      }
 
     val joinKeys = joinPredicates.map {
       case Equals(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index e24c74a7a557263c19e58d186eed7611eb4e45b4..c563d636275447305e1ee88cac619f57dc280e74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.test.TestSQLContext._
 import org.apache.spark.sql.test.TestSQLContext.planner._
@@ -57,4 +57,21 @@ class PlannerSuite extends FunSuite {
     val planned = PartialAggregation(query)
     assert(planned.isEmpty)
   }
+
+  test("equi-join is hash-join") {
+    val x = testData2.as('x)
+    val y = testData2.as('y)
+    val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
+    val planned = planner.HashJoin(join)
+    assert(planned.size === 1)
+  }
+
+  test("multiple-key equi-join is hash-join") {
+    val x = testData2.as('x)
+    val y = testData2.as('y)
+    val join = x.join(y, Inner,
+      Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
+    val planned = planner.HashJoin(join)
+    assert(planned.size === 1)
+  }
 }