diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
index 191cfde5651b53457ac6be208fc7e7d7e4527b88..d8700becb0e870ff5242c77219b53c777771c210 100644
--- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala
@@ -33,8 +33,9 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
   extends NarrowDependency[T](rdd) {
 
   @transient
-  val partitions: Array[Partition] = rdd.partitions.filter(s => partitionFilterFunc(s.index))
-    .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
+  val partitions: Array[Partition] = rdd.partitions.zipWithIndex
+    .filter(s => partitionFilterFunc(s._2))
+    .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
 
   override def getParents(partitionId: Int) = List(partitions(partitionId).index)
 }
diff --git a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..88352b639f9b1d953246bc438a82a0e364a7f1bf
--- /dev/null
+++ b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala
@@ -0,0 +1,28 @@
+package spark
+
+import org.scalatest.FunSuite
+import spark.SparkContext._
+import spark.rdd.PartitionPruningRDD
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+  test("Pruned Partitions inherit locality prefs correctly") {
+    class TestPartition(i: Int) extends Partition {
+      def index = i
+    }
+    val rdd = new RDD[Int](sc, Nil) {
+      override protected def getPartitions = {
+        Array[Partition](
+            new TestPartition(1),
+            new TestPartition(2), 
+            new TestPartition(3))
+      }
+      def compute(split: Partition, context: TaskContext) = {Iterator()}
+    }
+    val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
+    val p = prunedRDD.partitions(0)
+    assert(p.index == 2)
+    assert(prunedRDD.partitions.length == 1)
+  }
+}