diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 26fa0cb6d7bdef509430b9882d43809e9047f45f..8a0f5a602de12bfa61ddf980290181eeba6a9b6c 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -76,10 +76,6 @@ class ExternalAppendOnlyMap[K, V, C](
   private val sparkConf = SparkEnv.get.conf
   private val diskBlockManager = blockManager.diskBlockManager
 
-  // Number of pairs inserted since last spill; note that we count them even if a value is merged
-  // with a previous key in case we're doing something like groupBy where the result grows
-  protected[this] var elementsRead = 0L
-
   /**
    * Size of object batches when reading/writing from serializers.
    *
@@ -132,7 +128,7 @@ class ExternalAppendOnlyMap[K, V, C](
         currentMap = new SizeTrackingAppendOnlyMap[K, C]
       }
       currentMap.changeValue(curEntry._1, update)
-      elementsRead += 1
+      addElementsRead()
     }
   }
 
@@ -209,8 +205,6 @@ class ExternalAppendOnlyMap[K, V, C](
     }
 
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
-
-    elementsRead = 0
   }
 
   def diskBytesSpilled: Long = _diskBytesSpilled
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c1ce13683b5690b54e8daecc073494d838f1f8d7..c617ff5c51d045dc1d95c84ae64cdd0dc084c6b3 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -119,10 +119,6 @@ private[spark] class ExternalSorter[K, V, C](
   private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
   private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
 
-  // Number of pairs read from input since last spill; note that we count them even if a value is
-  // merged with a previous key in case we're doing something like groupBy where the result grows
-  protected[this] var elementsRead = 0L
-
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
 
@@ -204,7 +200,7 @@ private[spark] class ExternalSorter[K, V, C](
         if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
       }
       while (records.hasNext) {
-        elementsRead += 1
+        addElementsRead()
         kv = records.next()
         map.changeValue((getPartition(kv._1), kv._1), update)
         maybeSpillCollection(usingMap = true)
@@ -212,7 +208,7 @@ private[spark] class ExternalSorter[K, V, C](
     } else {
       // Stick values into our buffer
       while (records.hasNext) {
-        elementsRead += 1
+        addElementsRead()
         val kv = records.next()
         buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
         maybeSpillCollection(usingMap = false)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 0e4c6d633a4a9ed0e428b3286617566421cd9cf4..cb73b377fca98f200148e431a996a470f264656a 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -36,7 +36,11 @@ private[spark] trait Spillable[C] {
   protected def spill(collection: C): Unit
 
   // Number of elements read from input since last spill
-  protected var elementsRead: Long
+  protected def elementsRead: Long = _elementsRead
+
+  // Called by subclasses every time a record is read
+  // It is used to check the spilling frequency
+  protected def addElementsRead(): Unit = { _elementsRead += 1 }
 
   // Memory manager that can be used to acquire/release memory
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
@@ -44,6 +48,9 @@ private[spark] trait Spillable[C] {
   // What threshold of elementsRead we start estimating collection size at
   private[this] val trackMemoryThreshold = 1000
 
+  // Number of elements read from input since last spill
+  private[this] var _elementsRead = 0L
+
   // How much of the shared memory pool this collection has claimed
   private[this] var myMemoryThreshold = 0L
 
@@ -76,6 +83,7 @@ private[spark] trait Spillable[C] {
 
         spill(collection)
 
+        _elementsRead = 0
         // Keep track of spills, and release memory
         _memoryBytesSpilled += currentMemory
         releaseMemoryForThisThread()
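
For context, the change above moves ownership of the per-spill element counter out of the concrete collections (ExternalAppendOnlyMap, ExternalSorter) and into the Spillable trait: subclasses now report reads through addElementsRead() instead of mutating a shared protected var, and the trait itself resets the counter once spill() has run, so individual collections can no longer forget to do it. The sketch below is a minimal, self-contained illustration of that ownership pattern, not the real Spark classes; the names SimpleSpillable, BufferingSink, and spillThreshold are invented for the example, and the record-count threshold stands in for Spark's actual memory-based maybeSpill check.

```scala
import scala.collection.mutable.ArrayBuffer

// Minimal sketch of the counter-ownership pattern in this diff.
// SimpleSpillable, BufferingSink and spillThreshold are illustrative names,
// not part of Spark; the real trait spills based on estimated memory use,
// not a plain record count.
trait SimpleSpillable[C] {
  // Subclasses implement the actual spilling of the in-memory collection.
  protected def spill(collection: C): Unit

  // Read-only view for subclasses; the backing field belongs to the trait.
  protected def elementsRead: Long = _elementsRead

  // Called by subclasses every time a record is read.
  protected def addElementsRead(): Unit = { _elementsRead += 1 }

  // Number of elements read from input since the last spill.
  private[this] var _elementsRead = 0L

  // Stand-in for Spark's memory-based spill condition.
  protected def spillThreshold: Long = 1000L

  // The trait, not the subclass, resets the counter after a spill.
  protected def maybeSpill(collection: C): Boolean = {
    if (_elementsRead >= spillThreshold) {
      spill(collection)
      _elementsRead = 0
      true
    } else {
      false
    }
  }
}

// Toy subclass: buffers strings and "spills" by clearing the buffer.
class BufferingSink extends SimpleSpillable[ArrayBuffer[String]] {
  private val buffer = ArrayBuffer.empty[String]
  private var spills = 0

  override protected def spill(collection: ArrayBuffer[String]): Unit = {
    spills += 1
    collection.clear()
  }

  def insert(record: String): Unit = {
    addElementsRead()   // replaces the old `elementsRead += 1` in each subclass
    buffer += record
    maybeSpill(buffer)
  }

  def spillCount: Int = spills
}

object BufferingSinkExample {
  def main(args: Array[String]): Unit = {
    val sink = new BufferingSink
    (1 to 2500).foreach(i => sink.insert(s"record-$i"))
    println(s"spills = ${sink.spillCount}")   // 2 with the 1000-record threshold
  }
}
```

Keeping _elementsRead private[this] behind a read-only elementsRead accessor makes the trait the single writer of the counter, which is the invariant the patch is after: the reset no longer has to be duplicated (or forgotten) in every spilling collection.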