diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 598953ac3bcc85e4469c8aaa3ade0fcdb4b9482f..55e563ee968bedbe8d2ee438439642b528ecf88b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -207,6 +207,7 @@ private[spark] class PythonRDD(
 
     override def run(): Unit = Utils.logUncaughtExceptions {
       try {
+        TaskContext.setTaskContext(context)
         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
         // Partition index
@@ -263,11 +264,6 @@ private[spark] class PythonRDD(
           if (!worker.isClosed) {
             Utils.tryLog(worker.shutdownOutput())
           }
-      } finally {
-        // Release memory used by this thread for shuffles
-        env.shuffleMemoryManager.releaseMemoryForThisThread()
-        // Release memory used by this thread for unrolling blocks
-        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 23a470d6afcae85cc61e4e7fe8d3f533495820e3..1cf2824f862ee0a7d645783045d4c603cf1c63aa 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -112,6 +112,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     partition: Int): Unit = {
 
     val env = SparkEnv.get
+    val taskContext = TaskContext.get()
     val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
     val stream = new BufferedOutputStream(output, bufferSize)
 
@@ -119,6 +120,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
       override def run(): Unit = {
         try {
           SparkEnv.set(env)
+          TaskContext.setTaskContext(taskContext)
           val dataOut = new DataOutputStream(stream)
           dataOut.writeInt(partition)
 
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e76664f1bd7b00f88a36ef483a06c9cda4622a3f..7bc7fce7ae8dd18ac527ba76baba4ff0390e8ed5 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -313,10 +313,6 @@ private[spark] class Executor(
           }
 
       } finally {
-        // Release memory used by this thread for shuffles
-        env.shuffleMemoryManager.releaseMemoryForThisThread()
-        // Release memory used by this thread for unrolling blocks
-        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
         runningTasks.remove(taskId)
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index defdabf95ac4b79f781271e9e9a1fd251afa07c6..3bb9998e1db44350426d4fc17ad6b07abf84909a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -133,6 +133,7 @@ private[spark] class PipedRDD[T: ClassTag](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + command) {
       override def run() {
+        TaskContext.setTaskContext(context)
         val out = new PrintWriter(proc.getOutputStream)
 
         // scalastyle:off println
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index d11a00956a9a91caeceb1b4ded418620b79561f6..1978305cfefbdefc2e81fe15a953dcc30e3c3341 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.{TaskContextImpl, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.unsafe.memory.TaskMemoryManager
@@ -86,7 +86,18 @@ private[spark] abstract class Task[T](
       (runTask(context), context.collectAccumulators())
     } finally {
       context.markTaskCompleted()
-      TaskContext.unset()
+      try {
+        Utils.tryLogNonFatalError {
+          // Release memory used by this thread for shuffles
+          SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
+        }
+        Utils.tryLogNonFatalError {
+          // Release memory used by this thread for unrolling blocks
+          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
+        }
+      } finally {
+        TaskContext.unset()
+      }
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 3bcc7178a3d8b2bf50c3a810e8e7c566b19f7b82..f038b722957b8abac85e4ab0ce56d305938578b1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -19,95 +19,101 @@ package org.apache.spark.shuffle
 
 import scala.collection.mutable
 
-import org.apache.spark.{Logging, SparkException, SparkConf}
+import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext}
 
 /**
- * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
+ * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
  * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
  * from this pool and release it as it spills data out. When a task ends, all its memory will be
  * released by the Executor.
  *
- * This class tries to ensure that each thread gets a reasonable share of memory, instead of some
- * thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
+ * This class tries to ensure that each task gets a reasonable share of memory, instead of some
+ * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory
  * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
  * this set changes. This is all done by synchronizing access on "this" to mutate state and using
  * wait() and notifyAll() to signal changes.
  */
 private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
-  private val threadMemory = new mutable.HashMap[Long, Long]()  // threadId -> memory bytes
+  private val taskMemory = new mutable.HashMap[Long, Long]()  // taskAttemptId -> memory bytes
 
   def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
 
+  private def currentTaskAttemptId(): Long = {
+    // In case this is called on the driver, return an invalid task attempt id.
+    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+  }
+
   /**
-   * Try to acquire up to numBytes memory for the current thread, and return the number of bytes
+   * Try to acquire up to numBytes memory for the current task, and return the number of bytes
    * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
-   * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
-   * total memory pool (where N is the # of active threads) before it is forced to spill. This can
-   * happen if the number of threads increases but an older thread had a lot of memory already.
+   * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
+   * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
+   * happen if the number of tasks increases but an older task had a lot of memory already.
    */
   def tryToAcquire(numBytes: Long): Long = synchronized {
-    val threadId = Thread.currentThread().getId
+    val taskAttemptId = currentTaskAttemptId()
     assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
 
-    // Add this thread to the threadMemory map just so we can keep an accurate count of the number
-    // of active threads, to let other threads ramp down their memory in calls to tryToAcquire
-    if (!threadMemory.contains(threadId)) {
-      threadMemory(threadId) = 0L
-      notifyAll()  // Will later cause waiting threads to wake up and check numThreads again
+    // Add this task to the taskMemory map just so we can keep an accurate count of the number
+    // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
+    if (!taskMemory.contains(taskAttemptId)) {
+      taskMemory(taskAttemptId) = 0L
+      notifyAll()  // Will later cause waiting tasks to wake up and check numThreads again
     }
 
     // Keep looping until we're either sure that we don't want to grant this request (because this
-    // thread would have more than 1 / numActiveThreads of the memory) or we have enough free
-    // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
+    // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+    // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
     while (true) {
-      val numActiveThreads = threadMemory.keys.size
-      val curMem = threadMemory(threadId)
-      val freeMemory = maxMemory - threadMemory.values.sum
+      val numActiveTasks = taskMemory.keys.size
+      val curMem = taskMemory(taskAttemptId)
+      val freeMemory = maxMemory - taskMemory.values.sum
 
-      // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
+      // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
       // don't let it be negative
-      val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
+      val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
 
-      if (curMem < maxMemory / (2 * numActiveThreads)) {
-        // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
-        // if we can't give it this much now, wait for other threads to free up memory
-        // (this happens if older threads allocated lots of memory before N grew)
-        if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
+      if (curMem < maxMemory / (2 * numActiveTasks)) {
+        // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+        // if we can't give it this much now, wait for other tasks to free up memory
+        // (this happens if older tasks allocated lots of memory before N grew)
+        if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
           val toGrant = math.min(maxToGrant, freeMemory)
-          threadMemory(threadId) += toGrant
+          taskMemory(taskAttemptId) += toGrant
           return toGrant
         } else {
-          logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
+          logInfo(
+            s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
           wait()
         }
       } else {
         // Only give it as much memory as is free, which might be none if it reached 1 / numThreads
         val toGrant = math.min(maxToGrant, freeMemory)
-        threadMemory(threadId) += toGrant
+        taskMemory(taskAttemptId) += toGrant
         return toGrant
       }
     }
     0L  // Never reached
   }
 
-  /** Release numBytes bytes for the current thread. */
+  /** Release numBytes bytes for the current task. */
   def release(numBytes: Long): Unit = synchronized {
-    val threadId = Thread.currentThread().getId
-    val curMem = threadMemory.getOrElse(threadId, 0L)
+    val taskAttemptId = currentTaskAttemptId()
+    val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
     if (curMem < numBytes) {
       throw new SparkException(
-        s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
+        s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}")
     }
-    threadMemory(threadId) -= numBytes
+    taskMemory(taskAttemptId) -= numBytes
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 
-  /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
-  def releaseMemoryForThisThread(): Unit = synchronized {
-    val threadId = Thread.currentThread().getId
-    threadMemory.remove(threadId)
+  /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
+  def releaseMemoryForThisTask(): Unit = synchronized {
+    val taskAttemptId = currentTaskAttemptId()
+    taskMemory.remove(taskAttemptId)
     notifyAll()  // Notify waiters who locked "this" in tryToAcquire that memory has been freed
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index ed609772e697994c95fbeed63d744e9f5c1370a2..6f27f00307f8c05e814eddbe81ce1bf17f6fc7f4 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -23,6 +23,7 @@ import java.util.LinkedHashMap
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.TaskContext
 import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 
@@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   // Ensure only one thread is putting, and if necessary, dropping blocks at any given time
   private val accountingLock = new Object
 
-  // A mapping from thread ID to amount of memory used for unrolling a block (in bytes)
+  // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes)
   // All accesses of this map are assumed to have manually synchronized on `accountingLock`
   private val unrollMemoryMap = mutable.HashMap[Long, Long]()
   // Same as `unrollMemoryMap`, but for pending unroll memory as defined below.
-  // Pending unroll memory refers to the intermediate memory occupied by a thread
+  // Pending unroll memory refers to the intermediate memory occupied by a task
   // after the unroll but before the actual putting of the block in the cache.
   // This chunk of memory is expected to be released *as soon as* we finish
   // caching the corresponding block as opposed to until after the task finishes.
@@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     var elementsUnrolled = 0
     // Whether there is still enough memory for us to continue unrolling this block
     var keepUnrolling = true
-    // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing.
+    // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing.
     val initialMemoryThreshold = unrollMemoryThreshold
     // How often to check whether we need to request more memory
     val memoryCheckPeriod = 16
-    // Memory currently reserved by this thread for this particular unrolling operation
+    // Memory currently reserved by this task for this particular unrolling operation
     var memoryThreshold = initialMemoryThreshold
     // Memory to request as a multiple of current vector size
     val memoryGrowthFactor = 1.5
-    // Previous unroll memory held by this thread, for releasing later (only at the very end)
-    val previousMemoryReserved = currentUnrollMemoryForThisThread
+    // Previous unroll memory held by this task, for releasing later (only at the very end)
+    val previousMemoryReserved = currentUnrollMemoryForThisTask
     // Underlying vector for unrolling the block
     var vector = new SizeTrackingVector[Any]
 
     // Request enough memory to begin unrolling
-    keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold)
+    keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold)
 
     if (!keepUnrolling) {
       logWarning(s"Failed to reserve initial memory threshold of " +
@@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
             // Hold the accounting lock, in case another thread concurrently puts a block that
             // takes up the unrolling space we just ensured here
             accountingLock.synchronized {
-              if (!reserveUnrollMemoryForThisThread(amountToRequest)) {
+              if (!reserveUnrollMemoryForThisTask(amountToRequest)) {
                 // If the first request is not granted, try again after ensuring free space
                 // If there is still not enough space, give up and drop the partition
                 val spaceToEnsure = maxUnrollMemory - currentUnrollMemory
@@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
                   val result = ensureFreeSpace(blockId, spaceToEnsure)
                   droppedBlocks ++= result.droppedBlocks
                 }
-                keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest)
+                keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest)
               }
             }
             // New threshold is currentSize * memoryGrowthFactor
@@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       // later when the task finishes.
       if (keepUnrolling) {
         accountingLock.synchronized {
-          val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved
-          releaseUnrollMemoryForThisThread(amountToRelease)
-          reservePendingUnrollMemoryForThisThread(amountToRelease)
+          val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved
+          releaseUnrollMemoryForThisTask(amountToRelease)
+          reservePendingUnrollMemoryForThisTask(amountToRelease)
         }
       }
     }
@@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
         droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
       }
       // Release the unroll memory used because we no longer need the underlying Array
-      releasePendingUnrollMemoryForThisThread()
+      releasePendingUnrollMemoryForThisTask()
     }
     ResultWithDroppedBlocks(putSuccess, droppedBlocks)
   }
@@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
 
     // Take into account the amount of memory currently occupied by unrolling blocks
     // and minus the pending unroll memory for that block on current thread.
-    val threadId = Thread.currentThread().getId
+    val taskAttemptId = currentTaskAttemptId()
     val actualFreeMemory = freeMemory - currentUnrollMemory +
-      pendingUnrollMemoryMap.getOrElse(threadId, 0L)
+      pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L)
 
     if (actualFreeMemory < space) {
       val rddToAdd = getRddId(blockIdToAdd)
@@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
         logInfo(s"${selectedBlocks.size} blocks selected for dropping")
         for (blockId <- selectedBlocks) {
           val entry = entries.synchronized { entries.get(blockId) }
-          // This should never be null as only one thread should be dropping
+          // This should never be null as only one task should be dropping
           // blocks and removing entries. However the check is still here for
           // future safety.
           if (entry != null) {
@@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     entries.synchronized { entries.containsKey(blockId) }
   }
 
+  private def currentTaskAttemptId(): Long = {
+    // In case this is called on the driver, return an invalid task attempt id.
+    Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
+  }
+
   /**
-   * Reserve additional memory for unrolling blocks used by this thread.
+   * Reserve additional memory for unrolling blocks used by this task.
    * Return whether the request is granted.
    */
-  def reserveUnrollMemoryForThisThread(memory: Long): Boolean = {
+  def reserveUnrollMemoryForThisTask(memory: Long): Boolean = {
     accountingLock.synchronized {
       val granted = freeMemory > currentUnrollMemory + memory
       if (granted) {
-        val threadId = Thread.currentThread().getId
-        unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory
+        val taskAttemptId = currentTaskAttemptId()
+        unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
       }
       granted
     }
   }
 
   /**
-   * Release memory used by this thread for unrolling blocks.
-   * If the amount is not specified, remove the current thread's allocation altogether.
+   * Release memory used by this task for unrolling blocks.
+   * If the amount is not specified, remove the current task's allocation altogether.
    */
-  def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = {
-    val threadId = Thread.currentThread().getId
+  def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = {
+    val taskAttemptId = currentTaskAttemptId()
     accountingLock.synchronized {
       if (memory < 0) {
-        unrollMemoryMap.remove(threadId)
+        unrollMemoryMap.remove(taskAttemptId)
       } else {
-        unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory
-        // If this thread claims no more unroll memory, release it completely
-        if (unrollMemoryMap(threadId) <= 0) {
-          unrollMemoryMap.remove(threadId)
+        unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory
+        // If this task claims no more unroll memory, release it completely
+        if (unrollMemoryMap(taskAttemptId) <= 0) {
+          unrollMemoryMap.remove(taskAttemptId)
         }
       }
     }
   }
 
   /**
-   * Reserve the unroll memory of current unroll successful block used by this thread
+   * Reserve the unroll memory of current unroll successful block used by this task
    * until actually put the block into memory entry.
    */
-  def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = {
-    val threadId = Thread.currentThread().getId
+  def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = {
+    val taskAttemptId = currentTaskAttemptId()
     accountingLock.synchronized {
-       pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory
+       pendingUnrollMemoryMap(taskAttemptId) =
+         pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory
     }
   }
 
   /**
-   * Release pending unroll memory of current unroll successful block used by this thread
+   * Release pending unroll memory of current unroll successful block used by this task
    */
-  def releasePendingUnrollMemoryForThisThread(): Unit = {
-    val threadId = Thread.currentThread().getId
+  def releasePendingUnrollMemoryForThisTask(): Unit = {
+    val taskAttemptId = currentTaskAttemptId()
     accountingLock.synchronized {
-      pendingUnrollMemoryMap.remove(threadId)
+      pendingUnrollMemoryMap.remove(taskAttemptId)
     }
   }
 
   /**
-   * Return the amount of memory currently occupied for unrolling blocks across all threads.
+   * Return the amount of memory currently occupied for unrolling blocks across all tasks.
    */
   def currentUnrollMemory: Long = accountingLock.synchronized {
     unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum
   }
 
   /**
-   * Return the amount of memory currently occupied for unrolling blocks by this thread.
+   * Return the amount of memory currently occupied for unrolling blocks by this task.
    */
-  def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized {
-    unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
+  def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized {
+    unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L)
   }
 
   /**
-   * Return the number of threads currently unrolling blocks.
+   * Return the number of tasks currently unrolling blocks.
    */
-  def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
+  def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size }
 
   /**
    * Log information about current memory usage.
@@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     logInfo(
       s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " +
       s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " +
-      s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " +
+      s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " +
       s"Storage limit = ${Utils.bytesToString(maxMemory)}."
     )
   }
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
index 96778c9ebafb1e551da775316d9ee904a0083327..f495b6a0379581def58622bd3d0cfdfba9b8c77a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
@@ -17,26 +17,39 @@
 
 package org.apache.spark.shuffle
 
+import java.util.concurrent.CountDownLatch
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.Mockito._
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.time.SpanSugar._
-import java.util.concurrent.atomic.AtomicBoolean
-import java.util.concurrent.CountDownLatch
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, TaskContext}
 
 class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
+
+  val nextTaskAttemptId = new AtomicInteger()
+
   /** Launch a thread with the given body block and return it. */
   private def startThread(name: String)(body: => Unit): Thread = {
     val thread = new Thread("ShuffleMemorySuite " + name) {
       override def run() {
-        body
+        try {
+          val taskAttemptId = nextTaskAttemptId.getAndIncrement
+          val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS)
+          when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
+          TaskContext.setTaskContext(mockTaskContext)
+          body
+        } finally {
+          TaskContext.unset()
+        }
       }
     }
     thread.start()
     thread
   }
 
-  test("single thread requesting memory") {
+  test("single task requesting memory") {
     val manager = new ShuffleMemoryManager(1000L)
 
     assert(manager.tryToAcquire(100L) === 100L)
@@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
     assert(manager.tryToAcquire(300L) === 300L)
     assert(manager.tryToAcquire(300L) === 200L)
 
-    manager.releaseMemoryForThisThread()
+    manager.releaseMemoryForThisTask()
     assert(manager.tryToAcquire(1000L) === 1000L)
     assert(manager.tryToAcquire(100L) === 0L)
   }
@@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
   }
 
 
-  test("threads cannot grow past 1 / N") {
-    // Two threads request 250 bytes first, wait for each other to get it, and then request
+  test("tasks cannot grow past 1 / N") {
+    // Two tasks request 250 bytes first, wait for each other to get it, and then request
     // 500 more; we should only grant 250 bytes to each of them on this second request
 
     val manager = new ShuffleMemoryManager(1000L)
@@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
     assert(state.t2Result2 === 250L)
   }
 
-  test("threads can block to get at least 1 / 2N memory") {
+  test("tasks can block to get at least 1 / 2N memory") {
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
     // for a bit and releases 250 bytes, which should then be granted to t2. Further requests
     // by t2 will return false right away because it now has 1 / 2N of the memory.
@@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
     }
   }
 
-  test("releaseMemoryForThisThread") {
+  test("releaseMemoryForThisTask") {
     // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
     // for a bit and releases all its memory. t2 should now be able to grab all the memory.
 
@@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
         }
       }
       // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
-      // sure the other thread blocks for some time otherwise
+      // sure the other task blocks for some time otherwise
       Thread.sleep(300)
-      manager.releaseMemoryForThisThread()
+      manager.releaseMemoryForThisTask()
     }
 
     val t2 = startThread("t2") {
@@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
       t2.join()
     }
 
-    // Both threads should've been able to acquire their memory; the second one will have waited
+    // Both tasks should've been able to acquire their memory; the second one will have waited
     // until the first one acquired 1000 bytes and then released all of it
     state.synchronized {
       assert(state.t1Result === 1000L, "t1 could not allocate memory")
@@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
     }
   }
 
-  test("threads should not be granted a negative size") {
+  test("tasks should not be granted a negative size") {
     val manager = new ShuffleMemoryManager(1000L)
     manager.tryToAcquire(700L)
 
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index bcee901f5dd5fd13e84fcda1ee9baecf23973c46..f480fd107a0c2557b74f2571330ea27132743016 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store = makeBlockManager(12000)
     val memoryStore = store.memoryStore
     assert(memoryStore.currentUnrollMemory === 0)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Reserve
-    memoryStore.reserveUnrollMemoryForThisThread(100)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 100)
-    memoryStore.reserveUnrollMemoryForThisThread(200)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 300)
-    memoryStore.reserveUnrollMemoryForThisThread(500)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 800)
-    memoryStore.reserveUnrollMemoryForThisThread(1000000)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted
+    memoryStore.reserveUnrollMemoryForThisTask(100)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 100)
+    memoryStore.reserveUnrollMemoryForThisTask(200)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 300)
+    memoryStore.reserveUnrollMemoryForThisTask(500)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 800)
+    memoryStore.reserveUnrollMemoryForThisTask(1000000)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted
     // Release
-    memoryStore.releaseUnrollMemoryForThisThread(100)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 700)
-    memoryStore.releaseUnrollMemoryForThisThread(100)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 600)
+    memoryStore.releaseUnrollMemoryForThisTask(100)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 700)
+    memoryStore.releaseUnrollMemoryForThisTask(100)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 600)
     // Reserve again
-    memoryStore.reserveUnrollMemoryForThisThread(4400)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 5000)
-    memoryStore.reserveUnrollMemoryForThisThread(20000)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted
+    memoryStore.reserveUnrollMemoryForThisTask(4400)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 5000)
+    memoryStore.reserveUnrollMemoryForThisTask(20000)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted
     // Release again
-    memoryStore.releaseUnrollMemoryForThisThread(1000)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 4000)
-    memoryStore.releaseUnrollMemoryForThisThread() // release all
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    memoryStore.releaseUnrollMemoryForThisTask(1000)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 4000)
+    memoryStore.releaseUnrollMemoryForThisTask() // release all
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
   }
 
   /**
@@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val bigList = List.fill(40)(new Array[Byte](1000))
     val memoryStore = store.memoryStore
     val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll with all the space in the world. This should succeed and return an array.
     var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
     verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
-    memoryStore.releasePendingUnrollMemoryForThisThread()
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
+    memoryStore.releasePendingUnrollMemoryForThisTask()
 
     // Unroll with not enough space. This should succeed after kicking out someBlock1.
     store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY)
     store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY)
     unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks)
     verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     assert(droppedBlocks.size === 1)
     assert(droppedBlocks.head._1 === TestBlockId("someBlock1"))
     droppedBlocks.clear()
-    memoryStore.releasePendingUnrollMemoryForThisThread()
+    memoryStore.releasePendingUnrollMemoryForThisTask()
 
     // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 =
     // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator.
@@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY)
     unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks)
     verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false)
-    assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+    assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
     assert(droppedBlocks.size === 1)
     assert(droppedBlocks.head._1 === TestBlockId("someBlock2"))
     droppedBlocks.clear()
@@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val bigList = List.fill(40)(new Array[Byte](1000))
     def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
     def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]]
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll with plenty of space. This should succeed and cache both blocks.
     val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
@@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(result2.size > 0)
     assert(result1.data.isLeft) // unroll did not drop this block to disk
     assert(result2.data.isLeft)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Re-put these two blocks so block manager knows about them too. Otherwise, block manager
     // would not know how to drop them from memory later.
@@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(!memoryStore.contains("b1"))
     assert(memoryStore.contains("b2"))
     assert(memoryStore.contains("b3"))
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     memoryStore.remove("b3")
     store.putIterator("b3", smallIterator, memOnly)
 
@@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(!memoryStore.contains("b2"))
     assert(memoryStore.contains("b3"))
     assert(!memoryStore.contains("b4"))
-    assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+    assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
   }
 
   /**
@@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val bigList = List.fill(40)(new Array[Byte](1000))
     def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
     def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]]
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     store.putIterator("b1", smallIterator, memAndDisk)
     store.putIterator("b2", smallIterator, memAndDisk)
@@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(!diskStore.contains("b3"))
     memoryStore.remove("b3")
     store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll huge block with not enough space. This should fail and drop the new block to disk
     // directly in addition to kicking out b2 in the process. Memory store should contain only
@@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(diskStore.contains("b2"))
     assert(!diskStore.contains("b3"))
     assert(diskStore.contains("b4"))
-    assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator
+    assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator
   }
 
   test("multiple unrolls by the same thread") {
@@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     val memoryStore = store.memoryStore
     val smallList = List.fill(40)(new Array[Byte](100))
     def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]]
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // All unroll memory used is released because unrollSafely returned an array
     memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
     memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true)
-    assert(memoryStore.currentUnrollMemoryForThisThread === 0)
+    assert(memoryStore.currentUnrollMemoryForThisTask === 0)
 
     // Unroll memory is not released because unrollSafely returned an iterator
     // that still depends on the underlying vector used in the process
     memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true)
-    val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread
+    val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask
     assert(unrollMemoryAfterB3 > 0)
 
     // The unroll memory owned by this thread builds on top of its value after the previous unrolls
     memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true)
-    val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread
+    val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask
     assert(unrollMemoryAfterB4 > unrollMemoryAfterB3)
 
     // ... but only to a certain extent (until we run out of free space to grant new unroll memory)
     memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true)
-    val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread
+    val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask
     memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true)
-    val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread
+    val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask
     memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true)
-    val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread
+    val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask
     assert(unrollMemoryAfterB5 === unrollMemoryAfterB4)
     assert(unrollMemoryAfterB6 === unrollMemoryAfterB4)
     assert(unrollMemoryAfterB7 === unrollMemoryAfterB4)