diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index c19ed1529bbf658c3f21a48e8f3bbed5e79e5113..2ec9846e33f5a5e667c694e6bc6cc499c104630c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -169,6 +169,11 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
 
   var noLocality = true  // if true if no preferredLocations exists for parent RDD
 
+  // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
+  def currPrefLocs(part: Partition, prev: RDD[_]): Seq[String] = {
+    prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host)
+  }
+
   class PartitionLocations(prev: RDD[_]) {
 
     // contains all the partitions from the previous RDD that don't have preferred locations
@@ -184,7 +189,7 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
       val tmpPartsWithLocs = mutable.LinkedHashMap[Partition, Seq[String]]()
       // first get the locations for each partition, only do this once since it can be expensive
       prev.partitions.foreach(p => {
-          val locs = prev.context.getPreferredLocs(prev, p.index).map(tl => tl.host)
+          val locs = currPrefLocs(p, prev)
           if (locs.size > 0) {
             tmpPartsWithLocs.put(p, locs)
           } else {
@@ -287,9 +292,8 @@ private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10)
       balanceSlack: Double,
       partitionLocs: PartitionLocations): PartitionGroup = {
     val slack = (balanceSlack * prev.partitions.length).toInt
-    val preflocs = partitionLocs.partsWithLocs.filter(_._2 == p).map(_._1).toSeq
     // least loaded pref locs
-    val pref = preflocs.map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs
+    val pref = currPrefLocs(p, prev).map(getLeastGroupHash(_)).sortWith(compare)
     val prefPart = if (pref == Nil) None else pref.head
 
     val r1 = rnd.nextInt(groupArr.size)