diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 000bbd6b532addbde9f8313e1bc64c6223cb5e14..5f31bfba3f8d62a039d60c226f00ff192935d617 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import java.io.{ObjectInputStream, Serializable}
 import java.util.concurrent.atomic.AtomicLong
+import java.lang.ThreadLocal
 
 import scala.collection.generic.Growable
 import scala.collection.mutable.Map
@@ -278,10 +279,12 @@ object AccumulatorParam {
 
 // TODO: The multi-thread support in accumulators is kind of lame; check
 // if there's a more intuitive way of doing it right
-private object Accumulators {
+private[spark] object Accumulators {
   // TODO: Use soft references? => need to make readObject work properly then
   val originals = Map[Long, Accumulable[_, _]]()
-  val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
+  val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
+    override protected def initialValue() = Map[Long, Accumulable[_, _]]()
+  }
   var lastId: Long = 0
 
   def newId(): Long = synchronized {
@@ -293,22 +296,21 @@ private object Accumulators {
     if (original) {
       originals(a.id) = a
     } else {
-      val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
-      accums(a.id) = a
+      localAccums.get()(a.id) = a
     }
   }
 
   // Clear the local (non-original) accumulators for the current thread
   def clear() {
     synchronized {
-      localAccums.remove(Thread.currentThread)
+      localAccums.get().clear()
     }
   }
 
   // Get the values of the local accumulators for the current thread (by ID)
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
-    for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
+      for ((id, accum) <- localAccums.get()) {
       ret(id) = accum.localValue
     }
     return ret
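
A minimal, runnable sketch of the per-thread isolation the hunks above rely on (the object and field names here are illustrative, not part of the patch): a ThreadLocal with an overridden initialValue() hands each thread its own lazily created map, so one thread's registrations can never be observed or clobbered by another, and no cross-thread locking is needed on the read/write path.

import scala.collection.mutable.Map

object PerThreadMapDemo {
  // Each thread lazily receives its own empty map on first access.
  val perThread = new ThreadLocal[Map[Long, String]]() {
    override protected def initialValue() = Map[Long, String]()
  }

  def main(args: Array[String]): Unit = {
    val writer = new Thread(new Runnable {
      def run(): Unit = {
        perThread.get()(1L) = "owned by writer"
        println(s"writer sees: ${perThread.get()}")  // Map(1 -> owned by writer)
      }
    })
    writer.start(); writer.join()
    // The main thread gets its own fresh map, untouched by the writer thread.
    println(s"main sees: ${perThread.get()}")        // Map()
  }
}
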
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 835157fc520aa77bcfbf2eebc7163d5d91df144f..52de6980ecbf8bcde7a9a80c1ee84fa56a16453d 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -172,7 +172,6 @@ private[spark] class Executor(
       val startGCTime = gcTime
 
       try {
-        Accumulators.clear()
         val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
         updateDependencies(taskFiles, taskJars)
         task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
@@ -278,6 +277,8 @@ private[spark] class Executor(
         env.shuffleMemoryManager.releaseMemoryForThisThread()
         // Release memory used by this thread for unrolling blocks
         env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
+        // Clear the accumulators registered by this thread
+        Accumulators.clear()
         runningTasks.remove(taskId)
       }
     }
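
The Executor change is the other half of the same fix: executor threads come from a pool, so per-thread accumulator state must be wiped when a task finishes rather than when the next task happens to start, or updates from a failed or completed task would leak into whatever task the pool schedules on that thread next. A hedged sketch of that cleanup-in-finally pattern (the pool setup and names below are illustrative, not Spark's API):

import java.util.concurrent.Executors
import scala.collection.mutable.Map

object FinallyCleanupDemo {
  val taskState = new ThreadLocal[Map[Long, Any]]() {
    override protected def initialValue() = Map[Long, Any]()
  }

  def runTask(id: Long): Unit = {
    try {
      taskState.get()(id) = s"partial result of task $id"
      // ... task body runs here and may throw ...
    } finally {
      // Always runs, even on failure, so the pooled thread starts its
      // next task with no leftover per-thread state.
      taskState.get().clear()
    }
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(2)
    (1L to 4L).foreach { id =>
      pool.execute(new Runnable { def run(): Unit = runTask(id) })
    }
    pool.shutdown()
  }
}
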