From d75579d09912cfb1eeac0589d625ea0452701fa0 Mon Sep 17 00:00:00 2001
From: Tianshuo Deng <tdeng@twitter.com>
Date: Wed, 19 Nov 2014 10:01:09 -0800
Subject: [PATCH] [SPARK-4467] fix elements read count for ExtrenalSorter

the elementsRead variable should be reset to 0 after each spilling

Author: Tianshuo Deng <tdeng@twitter.com>

Closes #3302 from tsdeng/fix_external_sorter_record_count and squashes the following commits:

7b56ca0 [Tianshuo Deng] fix method signature
782c7de [Tianshuo Deng] make elementsRead private, fix comment
bb7ff28 [Tianshuo Deng] update elemetsRead through addElementsRead method
74ca246 [Tianshuo Deng] fix elements read count
---
 .../spark/util/collection/ExternalAppendOnlyMap.scala  |  8 +-------
 .../apache/spark/util/collection/ExternalSorter.scala  |  8 ++------
 .../org/apache/spark/util/collection/Spillable.scala   | 10 +++++++++-
 3 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 26fa0cb6d7..8a0f5a602d 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -76,10 +76,6 @@ class ExternalAppendOnlyMap[K, V, C](
   private val sparkConf = SparkEnv.get.conf
   private val diskBlockManager = blockManager.diskBlockManager
 
-  // Number of pairs inserted since last spill; note that we count them even if a value is merged
-  // with a previous key in case we're doing something like groupBy where the result grows
-  protected[this] var elementsRead = 0L
-
   /**
    * Size of object batches when reading/writing from serializers.
    *
@@ -132,7 +128,7 @@ class ExternalAppendOnlyMap[K, V, C](
         currentMap = new SizeTrackingAppendOnlyMap[K, C]
       }
       currentMap.changeValue(curEntry._1, update)
-      elementsRead += 1
+      addElementsRead()
     }
   }
 
@@ -209,8 +205,6 @@ class ExternalAppendOnlyMap[K, V, C](
     }
 
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
-
-    elementsRead = 0
   }
 
   def diskBytesSpilled: Long = _diskBytesSpilled
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c1ce13683b..c617ff5c51 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -119,10 +119,6 @@ private[spark] class ExternalSorter[K, V, C](
   private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
   private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
 
-  // Number of pairs read from input since last spill; note that we count them even if a value is
-  // merged with a previous key in case we're doing something like groupBy where the result grows
-  protected[this] var elementsRead = 0L
-
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
 
@@ -204,7 +200,7 @@ private[spark] class ExternalSorter[K, V, C](
         if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
       }
       while (records.hasNext) {
-        elementsRead += 1
+        addElementsRead()
         kv = records.next()
         map.changeValue((getPartition(kv._1), kv._1), update)
         maybeSpillCollection(usingMap = true)
@@ -212,7 +208,7 @@ private[spark] class ExternalSorter[K, V, C](
     } else {
       // Stick values into our buffer
       while (records.hasNext) {
-        elementsRead += 1
+        addElementsRead()
         val kv = records.next()
         buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
         maybeSpillCollection(usingMap = false)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 0e4c6d633a..cb73b377fc 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -36,7 +36,11 @@ private[spark] trait Spillable[C] {
   protected def spill(collection: C): Unit
 
   // Number of elements read from input since last spill
-  protected var elementsRead: Long
+  protected def elementsRead: Long = _elementsRead
+
+  // Called by subclasses every time a record is read
+  // It's used for checking spilling frequency
+  protected def addElementsRead(): Unit = { _elementsRead += 1 }
 
   // Memory manager that can be used to acquire/release memory
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
@@ -44,6 +48,9 @@ private[spark] trait Spillable[C] {
   // What threshold of elementsRead we start estimating collection size at
   private[this] val trackMemoryThreshold = 1000
 
+  // Number of elements read from input since last spill
+  private[this] var _elementsRead = 0L
+
   // How much of the shared memory pool this collection has claimed
   private[this] var myMemoryThreshold = 0L
 
@@ -76,6 +83,7 @@ private[spark] trait Spillable[C] {
 
         spill(collection)
 
+        _elementsRead = 0
         // Keep track of spills, and release memory
         _memoryBytesSpilled += currentMemory
         releaseMemoryForThisThread()
-- 
GitLab