diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 358bb372501582831c9b4b6a5edaf850b530777d..ca70d7f4a4311d09c29548067041304efbdc325e 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -144,14 +144,16 @@ public class TaskMemoryManager {
   public void freePage(MemoryBlock page) {
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
-    executorMemoryManager.free(page);
+    assert(allocatedPages.get(page.pageNumber));
+    pageTable[page.pageNumber] = null;
     synchronized (this) {
       allocatedPages.clear(page.pageNumber);
     }
-    pageTable[page.pageNumber] = null;
     if (logger.isTraceEnabled()) {
       logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
     }
+    // Cannot access a page once it's freed.
+    executorMemoryManager.free(page);
   }
 
   /**
@@ -166,7 +168,9 @@ public class TaskMemoryManager {
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
     assert(size > 0) : "Size must be positive, but got " + size;
     final MemoryBlock memory = executorMemoryManager.allocate(size);
-    allocatedNonPageMemory.add(memory);
+    synchronized(allocatedNonPageMemory) {
+      allocatedNonPageMemory.add(memory);
+    }
     return memory;
   }
 
@@ -176,8 +180,10 @@ public class TaskMemoryManager {
   public void free(MemoryBlock memory) {
     assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
     executorMemoryManager.free(memory);
-    final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
-    assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
+    synchronized(allocatedNonPageMemory) {
+      final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
+      assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
+    }
   }
 
   /**
@@ -223,9 +229,10 @@ public class TaskMemoryManager {
     if (inHeap) {
       final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
       assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
-      final Object page = pageTable[pageNumber].getBaseObject();
+      final MemoryBlock page = pageTable[pageNumber];
       assert (page != null);
-      return page;
+      assert (page.getBaseObject() != null);
+      return page.getBaseObject();
     } else {
       return null;
     }
@@ -244,7 +251,9 @@ public class TaskMemoryManager {
       // converted the absolute address into a relative address. Here, we invert that operation:
       final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
       assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
-      return pageTable[pageNumber].getBaseOffset() + offsetInPage;
+      final MemoryBlock page = pageTable[pageNumber];
+      assert (page != null);
+      return page.getBaseOffset() + offsetInPage;
     }
   }
 
@@ -260,14 +269,17 @@ public class TaskMemoryManager {
         freePage(page);
       }
     }
-    final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
-    while (iter.hasNext()) {
-      final MemoryBlock memory = iter.next();
-      freedBytes += memory.size();
-      // We don't call free() here because that calls Set.remove, which would lead to a
-      // ConcurrentModificationException here.
-      executorMemoryManager.free(memory);
-      iter.remove();
+
+    synchronized (allocatedNonPageMemory) {
+      final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
+      while (iter.hasNext()) {
+        final MemoryBlock memory = iter.next();
+        freedBytes += memory.size();
+        // We don't call free() here because that calls Set.remove, which would lead to a
+        // ConcurrentModificationException here.
+        executorMemoryManager.free(memory);
+        iter.remove();
+      }
     }
     return freedBytes;
   }