Skip to content
Snippets Groups Projects
Commit 4fde28c2 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

SPARK-2711. Create a ShuffleMemoryManager to track memory for all spilling collections

This tracks memory properly if there are multiple spilling collections in the same task (which was a problem before), and also implements an algorithm that lets each thread grow up to 1 / 2N of the memory pool (where N is the number of threads) before spilling, which avoids an inefficiency with small spills we had before (some threads would spill many times at 0-1 MB because the pool was allocated elsewhere).

Author: Matei Zaharia <matei@databricks.com>

Closes #1707 from mateiz/spark-2711 and squashes the following commits:

debf75b [Matei Zaharia] Review comments
24f28f3 [Matei Zaharia] Small rename
c8f3a8b [Matei Zaharia] Update ShuffleMemoryManager to be able to partially grant requests
315e3a5 [Matei Zaharia] Some review comments
b810120 [Matei Zaharia] Create central manager to track memory for all spilling collections
parent 066765d6
No related branches found
No related tags found
No related merge requests found
......@@ -35,7 +35,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
......@@ -66,12 +66,9 @@ class SparkEnv (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
val conf: SparkConf) extends Logging {
// A mapping of thread ID to amount of memory, in bytes, used for shuffle aggregations
// All accesses should be manually synchronized
val shuffleMemoryMap = mutable.HashMap[Long, Long]()
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
......@@ -252,6 +249,8 @@ object SparkEnv extends Logging {
val shuffleManager = instantiateClass[ShuffleManager](
"spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
......@@ -273,6 +272,7 @@ object SparkEnv extends Logging {
httpFileServer,
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
conf)
}
......
......@@ -276,10 +276,7 @@ private[spark] class Executor(
}
} finally {
// Release memory used by this thread for shuffles
val shuffleMemoryMap = env.shuffleMemoryMap
shuffleMemoryMap.synchronized {
shuffleMemoryMap.remove(Thread.currentThread().getId)
}
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
runningTasks.remove(taskId)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import scala.collection.mutable
import org.apache.spark.{Logging, SparkException, SparkConf}
/**
* Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling
* collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
* from this pool and release it as it spills data out. When a task ends, all its memory will be
* released by the Executor.
*
* This class tries to ensure that each thread gets a reasonable share of memory, instead of some
* thread ramping up to a large amount first and then causing others to spill to disk repeatedly.
* If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory
* before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
* set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever
* this set changes. This is all done by synchronizing access on "this" to mutate state and using
* wait() and notifyAll() to signal changes.
*/
private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes
def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf))
/**
* Try to acquire up to numBytes memory for the current thread, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
* in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the
* total memory pool (where N is the # of active threads) before it is forced to spill. This can
* happen if the number of threads increases but an older thread had a lot of memory already.
*/
def tryToAcquire(numBytes: Long): Long = synchronized {
val threadId = Thread.currentThread().getId
assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
// Add this thread to the threadMemory map just so we can keep an accurate count of the number
// of active threads, to let other threads ramp down their memory in calls to tryToAcquire
if (!threadMemory.contains(threadId)) {
threadMemory(threadId) = 0L
notifyAll() // Will later cause waiting threads to wake up and check numThreads again
}
// Keep looping until we're either sure that we don't want to grant this request (because this
// thread would have more than 1 / numActiveThreads of the memory) or we have enough free
// memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)).
while (true) {
val numActiveThreads = threadMemory.keys.size
val curMem = threadMemory(threadId)
val freeMemory = maxMemory - threadMemory.values.sum
// How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
if (curMem < maxMemory / (2 * numActiveThreads)) {
// We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
// if we can't give it this much now, wait for other threads to free up memory
// (this happens if older threads allocated lots of memory before N grew)
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) {
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
} else {
logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
// Only give it as much memory as is free, which might be none if it reached 1 / numThreads
val toGrant = math.min(maxToGrant, freeMemory)
threadMemory(threadId) += toGrant
return toGrant
}
}
0L // Never reached
}
/** Release numBytes bytes for the current thread. */
def release(numBytes: Long): Unit = synchronized {
val threadId = Thread.currentThread().getId
val curMem = threadMemory.getOrElse(threadId, 0L)
if (curMem < numBytes) {
throw new SparkException(
s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}")
}
threadMemory(threadId) -= numBytes
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
/** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */
def releaseMemoryForThisThread(): Unit = synchronized {
val threadId = Thread.currentThread().getId
threadMemory.remove(threadId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
}
private object ShuffleMemoryManager {
/**
* Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
* of the memory pool and a safety factor since collections can sometimes grow bigger than
* the size we target before we estimate their sizes again.
*/
def getMaxMemory(conf: SparkConf): Long = {
val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
}
......@@ -71,13 +71,7 @@ class ExternalAppendOnlyMap[K, V, C](
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
private val diskBlockManager = blockManager.diskBlockManager
// Collective memory threshold shared across all running tasks
private val maxMemoryThreshold = {
val memoryFraction = sparkConf.getDouble("spark.shuffle.memoryFraction", 0.2)
val safetyFraction = sparkConf.getDouble("spark.shuffle.safetyFraction", 0.8)
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
// Number of pairs inserted since last spill; note that we count them even if a value is merged
// with a previous key in case we're doing something like groupBy where the result grows
......@@ -140,28 +134,15 @@ class ExternalAppendOnlyMap[K, V, C](
if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
currentMap.estimateSize() >= myMemoryThreshold)
{
val currentSize = currentMap.estimateSize()
var shouldSpill = false
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
// Atomically check whether there is sufficient memory in the global pool for
// this map to grow and, if possible, allocate the required amount
shuffleMemoryMap.synchronized {
val threadId = Thread.currentThread().getId
val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
val availableMemory = maxMemoryThreshold -
(shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
// Try to allocate at least 2x more memory, otherwise spill
shouldSpill = availableMemory < currentSize * 2
if (!shouldSpill) {
shuffleMemoryMap(threadId) = currentSize * 2
myMemoryThreshold = currentSize * 2
}
}
// Do not synchronize spills
if (shouldSpill) {
spill(currentSize)
// Claim up to double our current memory from the shuffle memory pool
val currentMemory = currentMap.estimateSize()
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
myMemoryThreshold += granted
if (myMemoryThreshold <= currentMemory) {
// We were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold); spill the current collection
spill(currentMemory) // Will also release memory back to ShuffleMemoryManager
}
}
currentMap.changeValue(curEntry._1, update)
......@@ -245,12 +226,9 @@ class ExternalAppendOnlyMap[K, V, C](
currentMap = new SizeTrackingAppendOnlyMap[K, C]
spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
// Reset the amount of shuffle memory used by this map in the global pool
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
shuffleMemoryMap.synchronized {
shuffleMemoryMap(Thread.currentThread().getId) = 0
}
myMemoryThreshold = 0
// Release our memory back to the shuffle pool so that other threads can grab it
shuffleMemoryManager.release(myMemoryThreshold)
myMemoryThreshold = 0L
elementsRead = 0
_memoryBytesSpilled += mapSize
......
......@@ -78,6 +78,7 @@ private[spark] class ExternalSorter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
private val ser = Serializer.getSerializer(serializer)
private val serInstance = ser.newInstance()
......@@ -116,13 +117,6 @@ private[spark] class ExternalSorter[K, V, C](
private var _memoryBytesSpilled = 0L
private var _diskBytesSpilled = 0L
// Collective memory threshold shared across all running tasks
private val maxMemoryThreshold = {
val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
(Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
}
// How much of the shared memory pool this collection has claimed
private var myMemoryThreshold = 0L
......@@ -218,31 +212,15 @@ private[spark] class ExternalSorter[K, V, C](
if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
collection.estimateSize() >= myMemoryThreshold)
{
// TODO: This logic doesn't work if there are two external collections being used in the same
// task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711]
val currentSize = collection.estimateSize()
var shouldSpill = false
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
// Atomically check whether there is sufficient memory in the global pool for
// us to double our threshold
shuffleMemoryMap.synchronized {
val threadId = Thread.currentThread().getId
val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
val availableMemory = maxMemoryThreshold -
(shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
// Try to allocate at least 2x more memory, otherwise spill
shouldSpill = availableMemory < currentSize * 2
if (!shouldSpill) {
shuffleMemoryMap(threadId) = currentSize * 2
myMemoryThreshold = currentSize * 2
}
}
// Do not hold lock during spills
if (shouldSpill) {
spill(currentSize, usingMap)
// Claim up to double our current memory from the shuffle memory pool
val currentMemory = collection.estimateSize()
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
myMemoryThreshold += granted
if (myMemoryThreshold <= currentMemory) {
// We were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold); spill the current collection
spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager
}
}
}
......@@ -327,11 +305,8 @@ private[spark] class ExternalSorter[K, V, C](
buffer = new SizeTrackingPairBuffer[(Int, K), C]
}
// Reset the amount of shuffle memory used by this map in the global pool
val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
shuffleMemoryMap.synchronized {
shuffleMemoryMap(Thread.currentThread().getId) = 0
}
// Release our memory back to the shuffle pool so that other threads can grab it
shuffleMemoryManager.release(myMemoryThreshold)
myMemoryThreshold = 0
spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.CountDownLatch
class ShuffleMemoryManagerSuite extends FunSuite with Timeouts {
/** Launch a thread with the given body block and return it. */
private def startThread(name: String)(body: => Unit): Thread = {
val thread = new Thread("ShuffleMemorySuite " + name) {
override def run() {
body
}
}
thread.start()
thread
}
test("single thread requesting memory") {
val manager = new ShuffleMemoryManager(1000L)
assert(manager.tryToAcquire(100L) === 100L)
assert(manager.tryToAcquire(400L) === 400L)
assert(manager.tryToAcquire(400L) === 400L)
assert(manager.tryToAcquire(200L) === 100L)
assert(manager.tryToAcquire(100L) === 0L)
assert(manager.tryToAcquire(100L) === 0L)
manager.release(500L)
assert(manager.tryToAcquire(300L) === 300L)
assert(manager.tryToAcquire(300L) === 200L)
manager.releaseMemoryForThisThread()
assert(manager.tryToAcquire(1000L) === 1000L)
assert(manager.tryToAcquire(100L) === 0L)
}
test("two threads requesting full memory") {
// Two threads request 500 bytes first, wait for each other to get it, and then request
// 500 more; we should immediately return 0 as both are now at 1 / N
val manager = new ShuffleMemoryManager(1000L)
class State {
var t1Result1 = -1L
var t2Result1 = -1L
var t1Result2 = -1L
var t2Result2 = -1L
}
val state = new State
val t1 = startThread("t1") {
val r1 = manager.tryToAcquire(500L)
state.synchronized {
state.t1Result1 = r1
state.notifyAll()
while (state.t2Result1 === -1L) {
state.wait()
}
}
val r2 = manager.tryToAcquire(500L)
state.synchronized { state.t1Result2 = r2 }
}
val t2 = startThread("t2") {
val r1 = manager.tryToAcquire(500L)
state.synchronized {
state.t2Result1 = r1
state.notifyAll()
while (state.t1Result1 === -1L) {
state.wait()
}
}
val r2 = manager.tryToAcquire(500L)
state.synchronized { state.t2Result2 = r2 }
}
failAfter(20 seconds) {
t1.join()
t2.join()
}
assert(state.t1Result1 === 500L)
assert(state.t2Result1 === 500L)
assert(state.t1Result2 === 0L)
assert(state.t2Result2 === 0L)
}
test("threads cannot grow past 1 / N") {
// Two threads request 250 bytes first, wait for each other to get it, and then request
// 500 more; we should only grant 250 bytes to each of them on this second request
val manager = new ShuffleMemoryManager(1000L)
class State {
var t1Result1 = -1L
var t2Result1 = -1L
var t1Result2 = -1L
var t2Result2 = -1L
}
val state = new State
val t1 = startThread("t1") {
val r1 = manager.tryToAcquire(250L)
state.synchronized {
state.t1Result1 = r1
state.notifyAll()
while (state.t2Result1 === -1L) {
state.wait()
}
}
val r2 = manager.tryToAcquire(500L)
state.synchronized { state.t1Result2 = r2 }
}
val t2 = startThread("t2") {
val r1 = manager.tryToAcquire(250L)
state.synchronized {
state.t2Result1 = r1
state.notifyAll()
while (state.t1Result1 === -1L) {
state.wait()
}
}
val r2 = manager.tryToAcquire(500L)
state.synchronized { state.t2Result2 = r2 }
}
failAfter(20 seconds) {
t1.join()
t2.join()
}
assert(state.t1Result1 === 250L)
assert(state.t2Result1 === 250L)
assert(state.t1Result2 === 250L)
assert(state.t2Result2 === 250L)
}
test("threads can block to get at least 1 / 2N memory") {
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
// for a bit and releases 250 bytes, which should then be greanted to t2. Further requests
// by t2 will return false right away because it now has 1 / 2N of the memory.
val manager = new ShuffleMemoryManager(1000L)
class State {
var t1Requested = false
var t2Requested = false
var t1Result = -1L
var t2Result = -1L
var t2Result2 = -1L
var t2WaitTime = 0L
}
val state = new State
val t1 = startThread("t1") {
state.synchronized {
state.t1Result = manager.tryToAcquire(1000L)
state.t1Requested = true
state.notifyAll()
while (!state.t2Requested) {
state.wait()
}
}
// Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
// sure the other thread blocks for some time otherwise
Thread.sleep(300)
manager.release(250L)
}
val t2 = startThread("t2") {
state.synchronized {
while (!state.t1Requested) {
state.wait()
}
state.t2Requested = true
state.notifyAll()
}
val startTime = System.currentTimeMillis()
val result = manager.tryToAcquire(250L)
val endTime = System.currentTimeMillis()
state.synchronized {
state.t2Result = result
// A second call should return 0 because we're now already at 1 / 2N
state.t2Result2 = manager.tryToAcquire(100L)
state.t2WaitTime = endTime - startTime
}
}
failAfter(20 seconds) {
t1.join()
t2.join()
}
// Both threads should've been able to acquire their memory; the second one will have waited
// until the first one acquired 1000 bytes and then released 250
state.synchronized {
assert(state.t1Result === 1000L, "t1 could not allocate memory")
assert(state.t2Result === 250L, "t2 could not allocate memory")
assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
assert(state.t2Result2 === 0L, "t1 got extra memory the second time")
}
}
test("releaseMemoryForThisThread") {
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
// for a bit and releases all its memory. t2 should now be able to grab all the memory.
val manager = new ShuffleMemoryManager(1000L)
class State {
var t1Requested = false
var t2Requested = false
var t1Result = -1L
var t2Result1 = -1L
var t2Result2 = -1L
var t2Result3 = -1L
var t2WaitTime = 0L
}
val state = new State
val t1 = startThread("t1") {
state.synchronized {
state.t1Result = manager.tryToAcquire(1000L)
state.t1Requested = true
state.notifyAll()
while (!state.t2Requested) {
state.wait()
}
}
// Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
// sure the other thread blocks for some time otherwise
Thread.sleep(300)
manager.releaseMemoryForThisThread()
}
val t2 = startThread("t2") {
state.synchronized {
while (!state.t1Requested) {
state.wait()
}
state.t2Requested = true
state.notifyAll()
}
val startTime = System.currentTimeMillis()
val r1 = manager.tryToAcquire(500L)
val endTime = System.currentTimeMillis()
val r2 = manager.tryToAcquire(500L)
val r3 = manager.tryToAcquire(500L)
state.synchronized {
state.t2Result1 = r1
state.t2Result2 = r2
state.t2Result3 = r3
state.t2WaitTime = endTime - startTime
}
}
failAfter(20 seconds) {
t1.join()
t2.join()
}
// Both threads should've been able to acquire their memory; the second one will have waited
// until the first one acquired 1000 bytes and then released all of it
state.synchronized {
assert(state.t1Result === 1000L, "t1 could not allocate memory")
assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time")
assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time")
assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})")
assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment