diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 8f78fc5a416294faa0aa03c8e88b565b90f2665e..4c54ba4bce4082238f6b74127c154df886012eea 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -138,6 +138,11 @@ public final class UnsafeExternalSorter {
       this.inMemSorter = existingInMemorySorter;
     }
 
+    // Acquire a new page as soon as we construct the sorter to ensure that we have at
+    // least one page to work with. Otherwise, other operators in the same task may starve
+    // this sorter (SPARK-9709).
+    acquireNewPage();
+
     // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
     // the end of the task. This is necessary to avoid memory leaks when the downstream operator
     // does not fully consume the sorter's output (e.g. sort followed by limit).
@@ -343,22 +348,32 @@ public final class UnsafeExternalSorter {
         throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
           pageSizeBytes + ")");
       } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryAcquired < pageSizeBytes) {
-          shuffleMemoryManager.release(memoryAcquired);
-          spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
-            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
+        acquireNewPage();
+      }
+    }
+  }
+
+  /**
+   * Acquire a new page from the {@link ShuffleMemoryManager}.
+   *
+   * If there is not enough space to allocate the new page, spill all existing pages
+   * and try again. If there is still not enough space, report an error to the caller.
+   */
+  private void acquireNewPage() throws IOException {
+    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+    if (memoryAcquired < pageSizeBytes) {
+      shuffleMemoryManager.release(memoryAcquired);
+      spill();
+      final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+      if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+        shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
       }
     }
+    currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+    currentPagePosition = currentPage.getBaseOffset();
+    freeSpaceInCurrentPage = pageSizeBytes;
+    allocatedPages.add(currentPage);
   }
 
   /**
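The Javadoc above describes the acquire-spill-retry behaviour of `acquireNewPage()`. The following is a minimal, self-contained Scala sketch of that pattern; `ToyMemoryManager` and `ToySorter` are hypothetical stand-ins for illustration only, not Spark's `ShuffleMemoryManager` or `TaskMemoryManager`:

```scala
import java.io.IOException

// Toy stand-in for ShuffleMemoryManager: grants up to `bytes`, possibly less if the pool is low.
class ToyMemoryManager(private var available: Long) {
  def tryToAcquire(bytes: Long): Long = {
    val granted = math.min(bytes, available)
    available -= granted
    granted
  }
  def release(bytes: Long): Unit = { available += bytes }
}

// Toy stand-in for the sorter's page bookkeeping.
class ToySorter(manager: ToyMemoryManager, pageSizeBytes: Long) {
  private var allocatedPages: List[Long] = Nil

  // Stand-in for spill(): pretend all in-memory pages are written to disk and their memory freed.
  def spill(): Unit = {
    allocatedPages.foreach(manager.release)
    allocatedPages = Nil
  }

  // Mirrors the control flow of acquireNewPage(): try once, spill and retry, then give up.
  def acquireNewPage(): Unit = {
    var acquired = manager.tryToAcquire(pageSizeBytes)
    if (acquired < pageSizeBytes) {
      manager.release(acquired)       // return the partial grant before spilling
      spill()
      acquired = manager.tryToAcquire(pageSizeBytes)
      if (acquired != pageSizeBytes) {
        manager.release(acquired)
        throw new IOException(s"Unable to acquire $pageSizeBytes bytes of memory")
      }
    }
    allocatedPages = pageSizeBytes :: allocatedPages
  }
}
```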
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index a838aac6e8d1a17008571ebb1d89039e0c5e61b6..4312d3a417759239a63ebeba52b6e129b42b4abb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -21,6 +21,9 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, TaskContext}
 
+/**
+ * An RDD that applies the provided function to every partition of the parent RDD.
+ */
 private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     prev: RDD[T],
     f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..b475bd8d79f8586e626e69a8dd68773d560ba96e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, Partitioner, TaskContext}
+
+/**
+ * An RDD that applies a user-provided function to every partition of the parent RDD, and
+ * additionally allows the user to prepare each partition before computing the parent partition.
+ */
+private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
+    prev: RDD[T],
+    preparePartition: () => M,
+    executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
+    preservesPartitioning: Boolean = false)
+  extends RDD[U](prev) {
+
+  override val partitioner: Option[Partitioner] = {
+    if (preservesPartitioning) firstParent[T].partitioner else None
+  }
+
+  override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+  /**
+   * Prepare a partition before computing it from its parent.
+   */
+  override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+    val preparedArgument = preparePartition()
+    val parentIterator = firstParent[T].iterator(partition, context)
+    executePartition(context, partition.index, preparedArgument, parentIterator)
+  }
+}
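As a rough illustration of the new API, here is a small driver-side sketch (hypothetical example, not part of this patch); because the class is `private[spark]`, callers must live under an `org.apache.spark` package, as `TungstenSort` does below:

```scala
package org.apache.spark.rdd

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

// Hypothetical usage sketch: prepare a per-partition resource before the parent partition runs.
object MapPartitionsWithPreparationExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("prepare-example"))
    try {
      val parent = sc.parallelize(1 to 10, 2)

      // Runs first in each task, before the parent partition is computed
      // (in TungstenSort this is where the sorter is built and its first page reserved).
      val preparePartition = () => new StringBuilder("partition-")

      // Runs afterwards, with both the prepared value and the parent iterator.
      val executePartition = (
          context: TaskContext,
          index: Int,
          prepared: StringBuilder,
          iter: Iterator[Int]) => iter.map(i => s"${prepared.toString}$index-$i")

      val rdd = new MapPartitionsWithPreparationRDD[String, Int, StringBuilder](
        parent, preparePartition, executePartition)
      rdd.collect().foreach(println)
    } finally {
      sc.stop()
    }
  }
}
```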
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 00c1e078a441c75ccd51231c8e42c2d3b00af56f..e3d229cc9982147402f558b8bb3eef68fe889eea 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -124,7 +124,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
   }
 }
 
-private object ShuffleMemoryManager {
+private[spark] object ShuffleMemoryManager {
   /**
    * Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
    * of the memory pool and a safety factor since collections can sometimes grow bigger than
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 117745f9a9c00f63d31764f11b77eecdfad227d9..f5300373d87ea1b6f2d99029c478072d06827a00 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -340,7 +340,8 @@ public class UnsafeExternalSorterSuite {
       for (int i = 0; i < numRecordsPerPage * 10; i++) {
         insertNumber(sorter, i);
         newPeakMemory = sorter.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0) {
+        // The first page is pre-allocated on instantiation
+        if (i % numRecordsPerPage == 0 && i > 0) {
           // We allocated a new page for this record, so peak memory should change
           assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
         } else {
@@ -364,5 +365,21 @@ public class UnsafeExternalSorterSuite {
     }
   }
 
+  @Test
+  public void testReservePageOnInstantiation() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    try {
+      assertEquals(1, sorter.getNumberOfAllocatedPages());
+      // Inserting a new record doesn't allocate more memory since we already have a page
+      long peakMemory = sorter.getPeakMemoryUsedBytes();
+      insertNumber(sorter, 100);
+      assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
+      assertEquals(1, sorter.getNumberOfAllocatedPages());
+    } finally {
+      sorter.cleanupResources();
+      assertSpillFilesWereCleanedUp();
+    }
+  }
+
 }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..c16930e7d649182a5edae2099a93bc072256062e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}
+
+class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {
+
+  test("prepare called before parent partition is computed") {
+    sc = new SparkContext("local", "test")
+
+    // Have the parent partition push a number to the list
+    val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
+      TestObject.things.append(20)
+      iter
+    }
+
+    // Push a different number during the prepare phase
+    val preparePartition = () => { TestObject.things.append(10) }
+
+    // Push yet another number during the execution phase
+    val executePartition = (
+        taskContext: TaskContext,
+        partitionIndex: Int,
+        notUsed: Unit,
+        parentIterator: Iterator[Int]) => {
+      TestObject.things.append(30)
+      TestObject.things.iterator
+    }
+
+    // Verify that the numbers are pushed in the order expected
+    val result = {
+      new MapPartitionsWithPreparationRDD[Int, Int, Unit](
+        parent, preparePartition, executePartition).collect()
+    }
+    assert(result === Array(10, 20, 30))
+  }
+
+}
+
+private object TestObject {
+  val things = new mutable.ListBuffer[Int]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2f29067f5646a45f80989ad1e87201ecf547b31d..490428965a61d643a6c423e0967ed839ff72d091 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -158,7 +158,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   final def prepare(): Unit = {
     if (prepareCalled.compareAndSet(false, true)) {
-      doPrepare
+      doPrepare()
       children.foreach(_.prepare())
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 3192b6ebe9075dbbe651530cbbe34bbfa8b930fa..7f69cdb08aa7898f190e362346acc64c0943022a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.{InternalAccumulator, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -123,7 +123,12 @@ case class TungstenSort(
     val schema = child.schema
     val childOutput = child.output
     val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
-    child.execute().mapPartitions({ iter =>
+
+    /**
+     * Set up the sorter in each partition before computing the parent partition.
+     * This makes sure our sorter is not starved by other sorters used in the same task.
+     */
+    def preparePartition(): UnsafeExternalRowSorter = {
       val ordering = newOrdering(sortOrder, childOutput)
 
       // The comparator for comparing prefix
@@ -143,12 +148,25 @@ case class TungstenSort(
       if (testSpillFrequency > 0) {
         sorter.setTestSpillFrequency(testSpillFrequency)
       }
-      val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
-      val taskContext = TaskContext.get()
+      sorter
+    }
+
+    /** Compute a partition using the sorter set up in the preparation step. */
+    def executePartition(
+        taskContext: TaskContext,
+        partitionIndex: Int,
+        sorter: UnsafeExternalRowSorter,
+        parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+      val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
       taskContext.internalMetricsToAccumulators(
         InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
       sortedIterator
-    }, preservesPartitioning = true)
+    }
+
+    // Note: we need to set up the external sorter in each partition before computing
+    // the parent partition, so we cannot simply use `mapPartitions` here (SPARK-9709).
+    new MapPartitionsWithPreparationRDD[InternalRow, InternalRow, UnsafeExternalRowSorter](
+      child.execute(), preparePartition, executePartition, preservesPartitioning = true)
   }
 
 }
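For reference, the note above is the crux of the change: with plain `mapPartitions` the parent iterator is obtained before the user function runs, whereas the new RDD runs the preparation step first. A simplified sketch of the two evaluation orders (made-up helper functions, not the actual RDD code):

```scala
// Simplified sketch contrasting the two evaluation orders (hypothetical helpers, not Spark's code).
object EvaluationOrderSketch {

  // mapPartitions-style compute: the parent partition is computed first, then f runs.
  // If f builds a sorter, an upstream operator may already have claimed the task's memory.
  def computeLikeMapPartitions[T, U](
      parentIterator: () => Iterator[T],   // stand-in for firstParent.iterator(split, context)
      f: Iterator[T] => Iterator[U]): Iterator[U] = {
    f(parentIterator())
  }

  // MapPartitionsWithPreparationRDD-style compute: prepare first, then the parent, then execute.
  def computeWithPreparation[T, U, M](
      parentIterator: () => Iterator[T],
      preparePartition: () => M,           // e.g. build the row sorter, reserving its first page
      executePartition: (M, Iterator[T]) => Iterator[U]): Iterator[U] = {
    val prepared = preparePartition()      // runs before the parent partition is computed
    executePartition(prepared, parentIterator())
  }
}
```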