diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 98c3abe93b55354e2b84e0ec713e83bb94d2f609..93dfbc0e6ed652fa1478f7dbef6352796d2b6122 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -55,14 +55,20 @@ object Partitioner {
    * We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
    */
   def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
-    val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.length).reverse
-    for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
-      return r.partitioner.get
-    }
-    if (rdd.context.conf.contains("spark.default.parallelism")) {
-      new HashPartitioner(rdd.context.defaultParallelism)
+    val rdds = Seq(rdd) ++ others
+    // Consider only inputs whose existing partitioner has at least one partition.
+    val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
+    if (hasPartitioner.nonEmpty) {
+      // Reuse the partitioner of the eligible input with the most partitions.
+      hasPartitioner.maxBy(_.partitions.length).partitioner.get
     } else {
-      new HashPartitioner(bySize.head.partitions.length)
+      // Otherwise fall back to hash partitioning: use spark.default.parallelism
+      // if set, else as many partitions as the largest input has.
+      if (rdd.context.conf.contains("spark.default.parallelism")) {
+        new HashPartitioner(rdd.context.defaultParallelism)
+      } else {
+        new HashPartitioner(rdds.map(_.partitions.length).max)
+      }
     }
   }
 }
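
For context, a minimal sketch of the selection behavior after this change, written as a standalone application; the `local[2]` master, the `DefaultPartitionerSketch` object name, and the partition counts are illustrative assumptions, not part of this patch:

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.Partitioner.defaultPartitioner

object DefaultPartitionerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))

    // `a` already carries a HashPartitioner with 8 partitions; `b` has none.
    val a = sc.parallelize(0 until 100).map(i => (i, i)).partitionBy(new HashPartitioner(8))
    val b = sc.parallelize(0 until 100, numSlices = 16).map(i => (i, i))

    // Even though b has more partitions, only a has a partitioner, so a's
    // partitioner wins; a join of the two can then avoid re-shuffling a.
    assert(defaultPartitioner(a, b) == new HashPartitioner(8))

    // With no existing partitioners (and spark.default.parallelism unset),
    // the fallback is a HashPartitioner sized by the largest input.
    val c = sc.parallelize(0 until 100, numSlices = 16).map(i => (i, i))
    assert(defaultPartitioner(b, c) == new HashPartitioner(16))

    sc.stop()
  }
}
```

The second assertion only holds when `spark.default.parallelism` is unset, since that setting takes precedence over the largest input's partition count.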