diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index e75f1dbf8107a991902ab05aafb526829d3ca5d8..c19ed1529bbf658c3f21a48e8f3bbed5e79e5113 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -169,42 +169,37 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
 
   var noLocality = true  // if true if no preferredLocations exists for parent RDD
 
-  // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
-  def currPrefLocs(part: Partition, prev: RDD[_]): Seq[String] = {
-    prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host)
-  }
-
-  // this class just keeps iterating and rotating infinitely over the partitions of the RDD
-  // next() returns the next preferred machine that a partition is replicated on
-  // the rotator first goes through the first replica copy of each partition, then second, third
-  // the iterators return type is a tuple: (replicaString, partition)
-  class LocationIterator(prev: RDD[_]) extends Iterator[(String, Partition)] {
-
-    var it: Iterator[(String, Partition)] = resetIterator()
-
-    override val isEmpty = !it.hasNext
-
-    // initializes/resets to start iterating from the beginning
-    def resetIterator(): Iterator[(String, Partition)] = {
-      val iterators = (0 to 2).map { x =>
-        prev.partitions.iterator.flatMap { p =>
-          if (currPrefLocs(p, prev).size > x) Some((currPrefLocs(p, prev)(x), p)) else None
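+  // keeps track of the parent RDD's partitions and their preferred locations, computed once
+  // up front and split into partitions with preferred locations and partitions without any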
+  class PartitionLocations(prev: RDD[_]) {
+
+    // contains all the partitions from the previous RDD that don't have preferred locations
+    val partsWithoutLocs = ArrayBuffer[Partition]()
+    // contains all the partitions from the previous RDD that have preferred locations
+    val partsWithLocs = ArrayBuffer[(String, Partition)]()
+
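+    // eagerly compute the preferred locations when this class is instantiated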
+    getAllPrefLocs(prev)
+
+    // gets all the preferred locations of the previous RDD and splits the partitions into
+    // ones with preferred locations and ones without
+    def getAllPrefLocs(prev: RDD[_]) {
+      val tmpPartsWithLocs = mutable.LinkedHashMap[Partition, Seq[String]]()
+      // first get the locations for each partition, only do this once since it can be expensive
+      prev.partitions.foreach(p => {
+          val locs = prev.context.getPreferredLocs(prev, p.index).map(tl => tl.host)
+          if (locs.size > 0) {
+            tmpPartsWithLocs.put(p, locs)
+          } else {
+            partsWithoutLocs += p
+          }
         }
-      }
-      iterators.reduceLeft((x, y) => x ++ y)
-    }
-
-    // hasNext() is false iff there are no preferredLocations for any of the partitions of the RDD
-    override def hasNext: Boolean = { !isEmpty }
-
-    // return the next preferredLocation of some partition of the RDD
-    override def next(): (String, Partition) = {
-      if (it.hasNext) {
-        it.next()
-      } else {
-        it = resetIterator() // ran out of preferred locations, reset and rotate to the beginning
-        it.next()
-      }
+      )
+      // convert the map into an array of (host, partition) pairs: list the first preferred
+      // location of every partition, then the second, then the third
+      (0 to 2).foreach { x =>
+        tmpPartsWithLocs.foreach { case (p, locs) =>
+          if (locs.size > x) partsWithLocs += ((locs(x), p))
+        }
+      }
     }
   }
 
@@ -228,33 +223,32 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
   }
 
   /**
-   * Initializes targetLen partition groups and assigns a preferredLocation
-   * This uses coupon collector to estimate how many preferredLocations it must rotate through
-   * until it has seen most of the preferred locations (2 * n log(n))
+   * Initializes targetLen partition groups. If there are preferred locations, each group
+   * is assigned a preferredLocation. This uses a coupon collector estimate of how many
+   * preferredLocations it must rotate through until it has seen most of the preferred
+   * locations (2 * n log(n))
    * @param targetLen
    */
-  def setupGroups(targetLen: Int, prev: RDD[_]) {
-    val rotIt = new LocationIterator(prev)
-
+  def setupGroups(targetLen: Int, partitionLocs: PartitionLocations) {
     // deal with empty case, just create targetLen partition groups with no preferred location
-    if (!rotIt.hasNext) {
+    if (partitionLocs.partsWithLocs.isEmpty) {
       (1 to targetLen).foreach(x => groupArr += new PartitionGroup())
       return
     }
 
     noLocality = false
-
     // number of iterations needed to be certain that we've seen most preferred locations
     val expectedCoupons2 = 2 * (math.log(targetLen)*targetLen + targetLen + 0.5).toInt
     var numCreated = 0
     var tries = 0
 
     // rotate through until either targetLen unique/distinct preferred locations have been created
-    // OR we've rotated expectedCoupons2, in which case we have likely seen all preferred locations,
-    // i.e. likely targetLen >> number of preferred locations (more buckets than there are machines)
-    while (numCreated < targetLen && tries < expectedCoupons2) {
+    // OR we have gone through all partitions with preferred locations OR we've rotated
+    // expectedCoupons2 times - in which case we have likely seen all preferred locations
+    val numPartsToLookAt = math.min(expectedCoupons2, partitionLocs.partsWithLocs.length)
+    while (numCreated < targetLen && tries < numPartsToLookAt) {
+      val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries)
       tries += 1
-      val (nxt_replica, nxt_part) = rotIt.next()
       if (!groupHash.contains(nxt_replica)) {
         val pgroup = new PartitionGroup(Some(nxt_replica))
         groupArr += pgroup
@@ -263,20 +257,18 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
         numCreated += 1
       }
     }
-
-    while (numCreated < targetLen) {  // if we don't have enough partition groups, create duplicates
-      var (nxt_replica, nxt_part) = rotIt.next()
+    tries = 0
+    // if we don't have enough partition groups, create duplicates
+    while (numCreated < targetLen) {
+      val (nxt_replica, nxt_part) = partitionLocs.partsWithLocs(tries)
+      tries += 1
       val pgroup = new PartitionGroup(Some(nxt_replica))
       groupArr += pgroup
       groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup
-      var tries = 0
-      while (!addPartToPGroup(nxt_part, pgroup) && tries < targetLen) { // ensure at least one part
-        nxt_part = rotIt.next()._2
-        tries += 1
-      }
+      addPartToPGroup(nxt_part, pgroup)
       numCreated += 1
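+      // wrap around and reuse the partitions with preferred locations if we run out of
+      // them before every group has been created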
+      if (tries >= partitionLocs.partsWithLocs.length) tries = 0
     }
-
   }
 
   /**
@@ -289,10 +281,15 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
    *                     imbalance in favor of locality
    * @return partition group (bin to be put in)
    */
-  def pickBin(p: Partition, prev: RDD[_], balanceSlack: Double): PartitionGroup = {
+  def pickBin(
+      p: Partition,
+      prev: RDD[_],
+      balanceSlack: Double,
+      partitionLocs: PartitionLocations): PartitionGroup = {
     val slack = (balanceSlack * prev.partitions.length).toInt
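+    // look up this partition's preferred locations from the cached PartitionLocations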
+    val preflocs = partitionLocs.partsWithLocs.filter(_._2 == p).map(_._1).toSeq
     // least loaded pref locs
-    val pref = currPrefLocs(p, prev).map(getLeastGroupHash(_)).sortWith(compare)
+    val pref = preflocs.map(getLeastGroupHash(_)).sortWith(compare)
     val prefPart = if (pref == Nil) None else pref.head
 
     val r1 = rnd.nextInt(groupArr.size)
@@ -320,7 +317,10 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
     }
   }
 
-  def throwBalls(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
+  def throwBalls(
+      maxPartitions: Int,
+      prev: RDD[_],
+      balanceSlack: Double,
+      partitionLocs: PartitionLocations) {
     if (noLocality) {  // no preferredLocations in parent RDD, no randomization needed
       if (maxPartitions > groupArr.size) { // just return prev.partitions
         for ((p, i) <- prev.partitions.zipWithIndex) {
@@ -334,8 +334,39 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
         }
       }
     } else {
+      // It is possible to have a UnionRDD where one parent RDD has preferred locations and
+      // another does not. To make sure we end up with the requested number of partitions,
+      // make sure to put a partition in every group.
+
+      // if we don't have a partition assigned to every group, first try to fill the empty
+      // groups with the partitions that have preferred locations
+      val partIter = partitionLocs.partsWithLocs.iterator
+      groupArr.filter(pg => pg.numPartitions == 0).foreach { pg =>
+        while (partIter.hasNext && pg.numPartitions == 0) {
+          val (_, nxt_part) = partIter.next()
+          if (!initialHash.contains(nxt_part)) {
+            pg.partitions += nxt_part
+            initialHash += nxt_part
+          }
+        }
+      }
+
+      // if we still don't have one partition per group after using the partitions with
+      // preferred locations, fall back to the partitions without preferred locations
+      val partNoLocIter = partitionLocs.partsWithoutLocs.iterator
+      groupArr.filter(pg => pg.numPartitions == 0).foreach { pg =>
+        while (partNoLocIter.hasNext && pg.numPartitions == 0) {
+          val nxt_part = partNoLocIter.next()
+          if (!initialHash.contains(nxt_part)) {
+            pg.partitions += nxt_part
+            initialHash += nxt_part
+          }
+        }
+      }
+
+      // finally, pick a bin for the remaining partitions
       for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group
-        pickBin(p, prev, balanceSlack).partitions += p
+        pickBin(p, prev, balanceSlack, partitionLocs).partitions += p
       }
     }
   }
@@ -349,8 +380,11 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
     * @return array of partition groups
    */
   def coalesce(maxPartitions: Int, prev: RDD[_]): Array[PartitionGroup] = {
-    setupGroups(math.min(prev.partitions.length, maxPartitions), prev)   // setup the groups (bins)
-    throwBalls(maxPartitions, prev, balanceSlack) // assign partitions (balls) to each group (bins)
+    val partitionLocs = new PartitionLocations(prev)
+    // setup the groups (bins)
+    setupGroups(math.min(prev.partitions.length, maxPartitions), partitionLocs)
+    // assign partitions (balls) to each group (bins)
+    throwBalls(maxPartitions, prev, balanceSlack, partitionLocs)
     getPartitions
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 8dc463d56d18297167306a149fa262bcffa12180..a663dab772bf9a23d72da1bcace78c1ca6265af8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -377,6 +377,33 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
       map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back")
   }
 
+ test("coalesced RDDs with partial locality") {
+    // Make an RDD that has some locality preferences and some without. This can happen
+    // with UnionRDD
+    val data = sc.makeRDD((1 to 9).map(i => {
+      if (i > 4) {
+        (i, (i to (i + 2)).map { j => "m" + (j % 6) })
+      } else {
+        (i, Vector())
+      }
+    }))
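+    // partitions 5 to 9 get preferred locations (m0 - m5), partitions 1 to 4 get none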
+    val coalesced1 = data.coalesce(3)
+    assert(coalesced1.collect().toList.sorted === (1 to 9).toList, "Data got *lost* in coalescing")
+
+    val splits = coalesced1.glom().collect().map(_.toList).toList
+    assert(splits.length === 3, "Supposed to coalesce to 3 but got " + splits.length)
+
+    assert(splits.forall(_.length >= 1) === true, "Some partitions were empty")
+
+    // If we try to coalesce into more partitions than the original RDD, it should just
+    // keep the original number of partitions.
+    val coalesced4 = data.coalesce(20)
+    val listOfLists = coalesced4.glom().collect().map(_.toList).toList
+    val sortedList = listOfLists.sortWith{ (x, y) => !x.isEmpty && (y.isEmpty || (x(0) < y(0))) }
+    assert(sortedList === (1 to 9).
+      map{x => List(x)}.toList, "Tried coalescing 9 partitions to 20 but didn't get 9 back")
+  }
+
   test("coalesced RDDs with locality, large scale (10K partitions)") {
     // large scale experiment
     import collection.mutable
@@ -418,6 +445,48 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     }
   }
 
+  test("coalesced RDDs with partial locality, large scale (10K partitions)") {
+    // large scale experiment
+    import collection.mutable
+    val halfPartitions = 5000
+    val partitions = 10000
+    val numMachines = 50
+    val machines = mutable.ListBuffer[String]()
+    (1 to numMachines).foreach(machines += "m" + _)
+    val rnd = scala.util.Random
+    for (seed <- 1 to 5) {
+      rnd.setSeed(seed)
+
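+      // half of the partitions get 3 random preferred machines, the other half get none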
+      val firstBlocks = (1 to halfPartitions).map { i =>
+        (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList)
+      }
+      val blocksNoLocality = (halfPartitions + 1 to partitions).map { i =>
+        (i, List())
+      }
+      val blocks = firstBlocks ++ blocksNoLocality
+
+      val data2 = sc.makeRDD(blocks)
+
+      // first try going to same number of partitions
+      val coalesced2 = data2.coalesce(partitions)
+
+      // test that we have 10000 partitions
+      assert(coalesced2.partitions.size == 10000, "Expected 10000 partitions, but got " +
+        coalesced2.partitions.size)
+
+      // test that we have 100 partitions
+      val coalesced3 = data2.coalesce(numMachines * 2)
+      assert(coalesced3.partitions.size == 100, "Expected 100 partitions, but got " +
+        coalesced3.partitions.size)
+
+      // test that the groups are load balanced with 100 +/- 20 elements in each
+      val maxImbalance3 = coalesced3.partitions
+        .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size)
+        .foldLeft(0)((dev, curr) => math.max(math.abs(100 - curr), dev))
+      assert(maxImbalance3 <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance3)
+    }
+  }
+
   // Test for SPARK-2412 -- ensure that the second pass of the algorithm does not throw an exception
   test("coalesced RDDs with locality, fail first pass") {
     val initialPartitions = 1000