diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index cb73b377fca98f200148e431a996a470f264656a..9f543120748565194bd703aac1e129b560f48ca5 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -24,10 +24,7 @@ import org.apache.spark.SparkEnv
  * Spills contents of an in-memory collection to disk when the memory threshold
  * has been exceeded.
  */
-private[spark] trait Spillable[C] {
-
-  this: Logging =>
-
+private[spark] trait Spillable[C] extends Logging {
   /**
    * Spills the current in-memory collection to disk, and releases the memory.
    *
@@ -45,15 +42,21 @@ private[spark] trait Spillable[C] {
   // Memory manager that can be used to acquire/release memory
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
-  // What threshold of elementsRead we start estimating collection size at
+  // Threshold for `elementsRead` before we start tracking this collection's memory usage
   private[this] val trackMemoryThreshold = 1000
 
+  // Initial threshold for the size of a collection before we start tracking its memory usage
+  // Exposed for testing
+  private[this] val initialMemoryThreshold: Long =
+    SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
+
+  // Threshold for this collection's size in bytes before we start tracking its memory usage
+  // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
+  private[this] var myMemoryThreshold = initialMemoryThreshold
+
   // Number of elements read from input since last spill
   private[this] var _elementsRead = 0L
 
-  // How much of the shared memory pool this collection has claimed
-  private[this] var myMemoryThreshold = 0L
-
   // Number of bytes spilled in total
   private[this] var _memoryBytesSpilled = 0L
 
@@ -102,8 +105,9 @@ private[spark] trait Spillable[C] {
    * Release our memory back to the shuffle pool so that other threads can grab it.
    */
   private def releaseMemoryForThisThread(): Unit = {
-    shuffleMemoryManager.release(myMemoryThreshold)
-    myMemoryThreshold = 0L
+    // The amount we requested does not include the initial memory tracking threshold
+    shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
+    myMemoryThreshold = initialMemoryThreshold
   }
 
   /**
@@ -114,7 +118,7 @@ private[spark] trait Spillable[C] {
   @inline private def logSpillage(size: Long) {
     val threadId = Thread.currentThread().getId
     logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)"
-        .format(threadId, org.apache.spark.util.Utils.bytesToString(size),
-        _spillCount, if (_spillCount > 1) "s" else ""))
+      .format(threadId, org.apache.spark.util.Utils.bytesToString(size),
+      _spillCount, if (_spillCount > 1) "s" else ""))
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index f26e40fbd4b36284c75a823ff7be8e976df6cd7b..3cb42d416de4f8d59c922a2dbeaea9b18031116b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -127,6 +127,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   test("empty partitions with spilling") {
     val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
 
@@ -152,6 +153,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   test("empty partitions with spilling, bypass merge-sort") {
     val conf = createSparkConf(false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
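
Usage note (not part of the patch itself): the new spark.shuffle.spill.initialMemoryThreshold setting is what the updated tests lower to 512 bytes so that spilling is triggered even for tiny collections. Below is a minimal sketch of setting it from application code; the key names and the 512-byte value come from the test changes above, while the master, app name, and local-mode setup are illustrative assumptions.

    import org.apache.spark.{SparkConf, SparkContext}

    // Sketch only: force very early spilling by shrinking both the shuffle
    // memory fraction and the initial per-collection tracking threshold.
    // The patch's default initial threshold is 5 * 1024 * 1024 bytes (5 MB).
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("spill-threshold-sketch")
      .set("spark.shuffle.memoryFraction", "0.001")
      .set("spark.shuffle.spill.initialMemoryThreshold", "512")
    val sc = new SparkContext(conf)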