diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index af913454fce69ec02a31f7806e427bf1d8dc5af9..4d884dec079164620e97ec5f328ece970b9fb0ce 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -18,7 +18,8 @@
 package org.apache.spark
 
 import java.lang.ref.{ReferenceQueue, WeakReference}
-import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit}
+import java.util.Collections
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit}
 
 import scala.collection.JavaConverters._
 
@@ -58,7 +59,12 @@ private class CleanupTaskWeakReference(
  */
 private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
-  private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]()
+  /**
+   * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
+   * have not been handled by the reference queue.
+   */
+  private val referenceBuffer =
+    Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)
 
   private val referenceQueue = new ReferenceQueue[AnyRef]
 
@@ -176,10 +182,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
           .map(_.asInstanceOf[CleanupTaskWeakReference])
         // Synchronize here to avoid being interrupted on stop()
         synchronized {
-          reference.map(_.task).foreach { task =>
-            logDebug("Got cleaning task " + task)
-            referenceBuffer.remove(reference.get)
-            task match {
+          reference.foreach { ref =>
+            logDebug("Got cleaning task " + ref.task)
+            referenceBuffer.remove(ref)
+            ref.task match {
               case CleanRDD(rddId) =>
                 doCleanupRDD(rddId, blocking = blockOnCleanupTasks)
               case CleanShuffle(shuffleId) =>
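
Note for reviewers, not part of the patch: below is a minimal standalone Scala sketch of the two pieces the change combines, a `Collections.newSetFromMap`-backed concurrent set and the weak-reference/`ReferenceQueue` pattern it protects. `CleanerSketch`, `TaskRef`, and the string payload are illustrative stand-ins, not Spark code. The motivation for the swap is that `ConcurrentLinkedQueue.remove(Object)` must traverse the queue, so each removal is O(n), while the `ConcurrentHashMap`-backed set view removes in expected constant time; `referenceBuffer` takes one `remove` per processed weak reference, so this matters under heavy cleanup load.

import java.lang.ref.{ReferenceQueue, WeakReference}
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap

object CleanerSketch {
  // Illustrative stand-in for CleanupTaskWeakReference: a weak reference
  // that carries a payload describing the cleanup to perform.
  private class TaskRef(referent: AnyRef, val task: String, queue: ReferenceQueue[AnyRef])
    extends WeakReference[AnyRef](referent, queue)

  def main(args: Array[String]): Unit = {
    val queue = new ReferenceQueue[AnyRef]
    // Concurrent set view over a ConcurrentHashMap: thread-safe add/remove
    // in expected O(1), unlike ConcurrentLinkedQueue.remove, which scans
    // the whole queue linearly.
    val buffer = Collections.newSetFromMap[TaskRef](new ConcurrentHashMap)

    var referent: AnyRef = new AnyRef
    val ref = new TaskRef(referent, "clean rdd 0", queue)
    // Holding the TaskRef strongly keeps it reachable until the reference
    // queue has delivered it, mirroring what referenceBuffer does for the
    // CleanupTaskWeakReferences in the patch.
    buffer.add(ref)

    referent = null // drop the only strong reference to the referent
    System.gc()     // request collection; prompt enqueueing is not guaranteed

    // Poll with a timeout, as ContextCleaner does via referenceQueue.remove;
    // Option(...) absorbs the null returned when nothing was enqueued in time.
    Option(queue.remove(1000)).map(_.asInstanceOf[TaskRef]).foreach { r =>
      println("Got cleaning task " + r.task)
      buffer.remove(r) // constant-time removal, the point of this patch
    }
  }
}

Beyond the complexity win, the set semantics also deduplicate entries, and keeping each `CleanupTaskWeakReference` strongly reachable until it is dequeued is exactly what the doc comment added in the patch describes: without the buffer, the weak references themselves could be collected before the reference queue ever delivers them.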