diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index ad68512dccb7993d818f81dd0d0d5b3c1cf89410..4b9d59975bdc215dba7e54be65bb96b65a2a1e74 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -56,7 +56,7 @@ object Partitioner {
    */
   def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
     val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
-    for (r <- bySize if r.partitioner.isDefined) {
+    for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) {
       return r.partitioner.get
     }
     if (rdd.context.conf.contains("spark.default.parallelism")) {
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index dfa102f432a029e167595b034c7323967c4cbd04..1321ec84735b5e48ae94666900fc8ff40c13eca6 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -282,6 +282,29 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
     ))
   }
 
+  // See SPARK-9326
+  test("cogroup with empty RDD") {
+    import scala.reflect.classTag
+    val intPairCT = classTag[(Int, Int)]
+
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.emptyRDD[(Int, Int)](intPairCT)
+
+    val joined = rdd1.cogroup(rdd2).collect()
+    assert(joined.size > 0)
+  }
+
+  // See SPARK-9326
+  test("cogroup with groupByed RDD having 0 partitions") {
+    import scala.reflect.classTag
+    val intCT = classTag[Int]
+
+    val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+    val rdd2 = sc.emptyRDD[Int](intCT).groupBy((x) => 5)
+    val joined = rdd1.cogroup(rdd2).collect()
+    assert(joined.size > 0)
+  }
+
   test("rightOuterJoin") {
     val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
     val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))