diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index cb73b377fca98f200148e431a996a470f264656a..9f543120748565194bd703aac1e129b560f48ca5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -24,10 +24,7 @@ import org.apache.spark.SparkEnv
  * Spills contents of an in-memory collection to disk when the memory threshold
  * has been exceeded.
  */
-private[spark] trait Spillable[C] {
-
-  this: Logging =>
-
+private[spark] trait Spillable[C] extends Logging {
   /**
    * Spills the current in-memory collection to disk, and releases the memory.
    *
@@ -45,15 +42,21 @@ private[spark] trait Spillable[C] {
   // Memory manager that can be used to acquire/release memory
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
-  // What threshold of elementsRead we start estimating collection size at
+  // Threshold for `elementsRead` before we start tracking this collection's memory usage
   private[this] val trackMemoryThreshold = 1000
 
+  // Initial threshold for the size of a collection before we start tracking its memory usage
+  // The config key is exposed so that tests can lower this threshold; the field itself stays private
+  private[this] val initialMemoryThreshold: Long =
+    SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
+
+  // Threshold for this collection's size in bytes before we start tracking its memory usage
+  // To avoid a large number of small spills, start this at initialMemoryThreshold rather than at 0
+  private[this] var myMemoryThreshold = initialMemoryThreshold
+
   // Number of elements read from input since last spill
   private[this] var _elementsRead = 0L
 
-  // How much of the shared memory pool this collection has claimed
-  private[this] var myMemoryThreshold = 0L
-
   // Number of bytes spilled in total
   private[this] var _memoryBytesSpilled = 0L
 
@@ -102,8 +105,9 @@ private[spark] trait Spillable[C] {
    * Release our memory back to the shuffle pool so that other threads can grab it.
    */
   private def releaseMemoryForThisThread(): Unit = {
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0L
+    // The initial threshold was never acquired from the shuffle pool, so exclude it when releasing
+    shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
+    myMemoryThreshold = initialMemoryThreshold
   }
 
   /**
@@ -114,7 +118,7 @@ private[spark] trait Spillable[C] {
   @inline private def logSpillage(size: Long) {
     val threadId = Thread.currentThread().getId
     logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)"
-        .format(threadId, org.apache.spark.util.Utils.bytesToString(size),
-            _spillCount, if (_spillCount > 1) "s" else ""))
+      .format(threadId, org.apache.spark.util.Utils.bytesToString(size),
+        _spillCount, if (_spillCount > 1) "s" else ""))
   }
 }
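
Note on the hunks above: maybeSpill itself is not touched by this patch; only the threshold bookkeeping changes. The following is a minimal, self-contained sketch of that bookkeeping, not code from Spark: ToyPool is a hypothetical stand-in for ShuffleMemoryManager, and the claim-up-to-double request assumes maybeSpill's existing growth strategy. It illustrates why releaseMemoryForThisThread must now subtract initialMemoryThreshold: the initial 5MB is granted to every collection "for free" and is never acquired from the shared pool.

object SpillThresholdSketch {
  // Hypothetical stand-in for ShuffleMemoryManager; not Spark's actual API
  final class ToyPool(var available: Long) {
    // Grant as much of the request as the pool can currently satisfy
    def tryToAcquire(numBytes: Long): Long = {
      val granted = math.min(numBytes, available)
      available -= granted
      granted
    }
    def release(numBytes: Long): Unit = available += numBytes
  }

  def main(args: Array[String]): Unit = {
    val initialMemoryThreshold = 5L * 1024 * 1024     // the new default
    val pool = new ToyPool(available = 512L * 1024)   // a nearly-exhausted pool
    var myMemoryThreshold = initialMemoryThreshold

    // The collection outgrows the threshold, so claim enough from the pool to
    // double our footprint (maybeSpill's existing strategy, unchanged by this patch)
    val currentMemory = 6L * 1024 * 1024
    myMemoryThreshold += pool.tryToAcquire(2 * currentMemory - myMemoryThreshold)

    // Granted too little to keep growing (threshold <= current size), so we spill
    assert(myMemoryThreshold <= currentMemory)

    // Release only what was actually acquired: the initial 5MB was never taken
    // from the pool, hence the subtraction this patch adds to releaseMemoryForThisThread
    pool.release(myMemoryThreshold - initialMemoryThreshold)
    myMemoryThreshold = initialMemoryThreshold
    assert(pool.available == 512L * 1024)
  }
}

Without that subtraction, every spill would return 5MB to the pool that was never taken out of it, inflating the pool's accounting over time.
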
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index f26e40fbd4b36284c75a823ff7be8e976df6cd7b..3cb42d416de4f8d59c922a2dbeaea9b18031116b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -127,6 +127,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   test("empty partitions with spilling") {
     val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
 
@@ -152,6 +153,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   test("empty partitions with spilling, bypass merge-sort") {
     val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
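
These test changes lower the initial threshold from the 5MB default to 512 bytes: with spark.shuffle.memoryFraction set to 0.001, the small collections built by these tests would otherwise never reach the 5MB threshold, and the spill paths under test would never execute. For reference, a hypothetical standalone equivalent of the suite's setup (assuming createSparkConf(false) amounts to new SparkConf(loadDefaults = false) plus serializer settings):

import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf(loadDefaults = false)
  .set("spark.shuffle.memoryFraction", "0.001")
  .set("spark.shuffle.spill.initialMemoryThreshold", "512")
  .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
val sc = new SparkContext("local", "test", conf)   // "local" master, app name "test"
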