Commit 85e654c5 authored by Josh Rosen

[SPARK-10984] Simplify *MemoryManager class structure

This patch refactors the MemoryManager class structure. After #9000, Spark had the following classes:

- MemoryManager
- StaticMemoryManager
- ExecutorMemoryManager
- TaskMemoryManager
- ShuffleMemoryManager

This is fairly confusing. To simplify things, this patch consolidates several of these classes (the resulting structure is sketched after this list):

- ShuffleMemoryManager and ExecutorMemoryManager were merged into MemoryManager.
- TaskMemoryManager is moved into Spark Core.
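
To make the resulting structure concrete, the sketch below (Scala, with hypothetical `Sketch*` names; the real signatures appear in the diff) shows how the consolidated classes relate: one `MemoryManager` per JVM arbitrates memory, and one `TaskMemoryManager` per task attempt delegates to it.

```scala
import org.apache.spark.SparkConf

// Illustrative sketch only -- not the real classes.
// One MemoryManager per JVM arbitrates execution/storage memory and
// fairness across tasks.
abstract class SketchMemoryManager(conf: SparkConf, numCores: Int) {
  def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long
  def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit
}

// One TaskMemoryManager per task attempt; it forwards per-task requests
// to the shared MemoryManager and handles Tungsten page bookkeeping.
final class SketchTaskMemoryManager(
    mm: SketchMemoryManager,
    taskAttemptId: Long) {
  def acquireExecutionMemory(size: Long): Long =
    mm.acquireExecutionMemory(size, taskAttemptId)
  def releaseExecutionMemory(size: Long): Unit =
    mm.releaseExecutionMemory(size, taskAttemptId)
}
```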

**Key changes and tasks**:

- [x] Merge ExecutorMemoryManager into MemoryManager.
  - [x] Move pooling logic into Allocator.
- [x] Move TaskMemoryManager from `spark-unsafe` to `spark-core`.
- [x] Refactor the existing Tungsten TaskMemoryManager interactions so that Tungsten code uses only TaskMemoryManager rather than both it and ShuffleMemoryManager.
- [x] Refactor non-Tungsten code to use the TaskMemoryManager instead of ShuffleMemoryManager.
- [x] Merge ShuffleMemoryManager into MemoryManager.
  - [x] Move code
  - [x] ~~Simplify 1/n calculation.~~ **Will defer to a follow-up, since this needs more work.**
- [x] Port ShuffleMemoryManagerSuite tests.
- [x] Move classes from `unsafe` package to `memory` package.
- [ ] Figure out how to handle the hacky use of the memory managers in HashedRelation's broadcast variable construction.
- [x] Test porting and cleanup: several tests relied on mock functionality (such as `TestShuffleMemoryManager.markAsOutOfMemory`) that was changed or broken during the memory manager consolidation:
  - [x] AbstractBytesToBytesMapSuite
  - [x] UnsafeExternalSorterSuite
  - [x] UnsafeFixedWidthAggregationMapSuite
  - [x] UnsafeKVExternalSorterSuite

**Compatibility notes**:

- This patch introduces breaking changes in `ExternalAppendOnlyMap`, which is marked as `DeveloperApi` (likely for legacy reasons): this class can no longer be used outside of a task (see the sketch below).
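
For illustration, a hedged sketch of the breaking change: the map now obtains its memory through the ambient task, so constructing or using it where `TaskContext.get()` returns null (e.g. on the driver) is assumed to fail. The combiner arguments are the ones the class already takes; the assertion is illustrative, not the actual failure mode.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.collection.ExternalAppendOnlyMap

// Assumption: usable only while a task is running, i.e. when an
// ambient TaskContext exists.
assert(TaskContext.get() != null,
  "ExternalAppendOnlyMap now requires an active task")
val map = new ExternalAppendOnlyMap[String, Int, Int](
  createCombiner = (v: Int) => v,
  mergeValue = (c: Int, v: Int) => c + v,
  mergeCombiners = (c1: Int, c2: Int) => c1 + c2)
```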

Author: Josh Rosen <joshrosen@databricks.com>

Closes #9127 from JoshRosen/SPARK-10984.
parent 63accc79
Showing changes with 323 additions and 428 deletions
......@@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.unsafe.memory;
package org.apache.spark.memory;
import java.util.*;
......@@ -23,6 +23,8 @@ import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.unsafe.memory.MemoryBlock;
/**
* Manages the memory allocated by an individual task.
* <p>
......@@ -87,13 +89,9 @@ public class TaskMemoryManager {
*/
private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
/**
* Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean
* up leaked memory.
*/
private final HashSet<MemoryBlock> allocatedNonPageMemory = new HashSet<MemoryBlock>();
private final MemoryManager memoryManager;
private final ExecutorMemoryManager executorMemoryManager;
private final long taskAttemptId;
/**
* Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
......@@ -103,16 +101,38 @@ public class TaskMemoryManager {
private final boolean inHeap;
/**
* Construct a new MemoryManager.
* Construct a new TaskMemoryManager.
*/
public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
this.inHeap = executorMemoryManager.inHeap;
this.executorMemoryManager = executorMemoryManager;
public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
this.memoryManager = memoryManager;
this.taskAttemptId = taskAttemptId;
}
/**
* Acquire N bytes of memory for execution, evicting cached blocks if necessary.
* @return number of bytes successfully granted (<= N).
*/
public long acquireExecutionMemory(long size) {
return memoryManager.acquireExecutionMemory(size, taskAttemptId);
}
/**
* Release N bytes of execution memory.
*/
public void releaseExecutionMemory(long size) {
memoryManager.releaseExecutionMemory(size, taskAttemptId);
}
public long pageSizeBytes() {
return memoryManager.pageSizeBytes();
}
/**
* Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
* intended for allocating large blocks of memory that will be shared between operators.
* intended for allocating large blocks of Tungsten memory that will be shared between operators.
*
* Returns `null` if there was not enough memory to allocate the page.
*/
public MemoryBlock allocatePage(long size) {
if (size > MAXIMUM_PAGE_SIZE_BYTES) {
......@@ -129,7 +149,15 @@ public class TaskMemoryManager {
}
allocatedPages.set(pageNumber);
}
final MemoryBlock page = executorMemoryManager.allocate(size);
final long acquiredExecutionMemory = acquireExecutionMemory(size);
if (acquiredExecutionMemory != size) {
releaseExecutionMemory(acquiredExecutionMemory);
synchronized (this) {
allocatedPages.clear(pageNumber);
}
return null;
}
final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
if (logger.isTraceEnabled()) {
......@@ -152,45 +180,16 @@ public class TaskMemoryManager {
if (logger.isTraceEnabled()) {
logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
}
// Cannot access a page once it's freed.
executorMemoryManager.free(page);
}
/**
* Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
* to be zeroed out (call `zero()` on the result if this is necessary). This method is intended
* to be used for allocating operators' internal data structures. For data pages that you want to
* exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since
* that will enable intra-memory pointers (see
* {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's
* top-level Javadoc for more details).
*/
public MemoryBlock allocate(long size) throws OutOfMemoryError {
assert(size > 0) : "Size must be positive, but got " + size;
final MemoryBlock memory = executorMemoryManager.allocate(size);
synchronized(allocatedNonPageMemory) {
allocatedNonPageMemory.add(memory);
}
return memory;
}
/**
* Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
*/
public void free(MemoryBlock memory) {
assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
executorMemoryManager.free(memory);
synchronized(allocatedNonPageMemory) {
final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
}
long pageSize = page.size();
memoryManager.tungstenMemoryAllocator().free(page);
releaseExecutionMemory(pageSize);
}
/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
*
* @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}.
* @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
* @param offsetInPage an offset in this page which incorporates the base offset. In other words,
* this should be the value that you would pass as the base offset into an
* UNSAFE call (e.g. page.baseOffset() + something).
......@@ -270,17 +269,15 @@ public class TaskMemoryManager {
}
}
synchronized (allocatedNonPageMemory) {
final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
while (iter.hasNext()) {
final MemoryBlock memory = iter.next();
freedBytes += memory.size();
// We don't call free() here because that calls Set.remove, which would lead to a
// ConcurrentModificationException here.
executorMemoryManager.free(memory);
iter.remove();
}
}
freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
return freedBytes;
}
/**
* Returns the memory consumption, in bytes, for the current task
*/
public long getMemoryConsumptionForThisTask() {
return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId);
}
}
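
Taken together, the new `TaskMemoryManager` surface gives consumers a single place to reserve execution memory and allocate pages. A minimal usage sketch (the helper names are hypothetical; error handling is elided):

```scala
import java.io.IOException
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.unsafe.memory.MemoryBlock

// Hypothetical helper: reserve raw execution memory (e.g. for a pointer
// array), mirroring initializeForWriting() in the sorter changes below.
def reserveOrFail(tmm: TaskMemoryManager, bytes: Long): Unit = {
  val granted = tmm.acquireExecutionMemory(bytes)
  if (granted != bytes) {
    tmm.releaseExecutionMemory(granted) // hand back a partial grant
    throw new IOException(s"Could not acquire $bytes bytes of memory")
  }
}

// Hypothetical helper: pages now reserve execution memory internally
// and signal failure by returning null instead of throwing.
def allocateOrNone(tmm: TaskMemoryManager): Option[MemoryBlock] =
  Option(tmm.allocatePage(tmm.pageSizeBytes()))
```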
......@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.sort;
import org.apache.spark.memory.TaskMemoryManager;
/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
* <p>
......@@ -26,7 +28,7 @@ package org.apache.spark.shuffle.sort;
* </pre>
* This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
* our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
* 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
* 13-bit page numbers assigned by {@link TaskMemoryManager}), this
* implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
* <p>
* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
......
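
To make the layout arithmetic concrete: the 40-bit record pointer splits into a 13-bit page number and a 27-bit in-page offset, so a page can span 2^27 bytes = 128 MB and a task can address 2^13 × 128 MB = 1 TB. A hedged sketch of the packing (illustrative shifts and bounds; the real constants live in `PackedRecordPointer`):

```scala
// [24-bit partition | 13-bit page number | 27-bit offset in page]
def pack(partitionId: Int, pageNumber: Int, offsetInPage: Long): Long = {
  require(partitionId < (1 << 24) && pageNumber < (1 << 13))
  require(offsetInPage < (1L << 27)) // max page size: 2^27 bytes = 128 MB
  (partitionId.toLong << 40) | (pageNumber.toLong << 27) | offsetInPage
}

def partitionOf(packed: Long): Int = (packed >>> 40).toInt

// 2^13 pages * 2^27 bytes/page = 2^40 bytes = 1 TB addressable per task.
```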
......@@ -33,14 +33,13 @@ import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
/**
......@@ -72,7 +71,6 @@ final class ShuffleExternalSorter {
@VisibleForTesting
final int maxRecordSizeBytes;
private final TaskMemoryManager taskMemoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
......@@ -105,7 +103,6 @@ final class ShuffleExternalSorter {
public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
int initialSize,
......@@ -113,7 +110,6 @@ final class ShuffleExternalSorter {
SparkConf conf,
ShuffleWriteMetrics writeMetrics) throws IOException {
this.taskMemoryManager = memoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.initialSize = initialSize;
......@@ -124,7 +120,7 @@ final class ShuffleExternalSorter {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.pageSizeBytes = (int) Math.min(
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
this.writeMetrics = writeMetrics;
initializeForWriting();
......@@ -140,9 +136,9 @@ final class ShuffleExternalSorter {
private void initializeForWriting() throws IOException {
// TODO: move this sizing calculation logic into a static method of sorter:
final long memoryRequested = initialSize * 8L;
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
if (memoryAcquired != memoryRequested) {
shuffleMemoryManager.release(memoryAcquired);
taskMemoryManager.releaseExecutionMemory(memoryAcquired);
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}
......@@ -272,6 +268,7 @@ final class ShuffleExternalSorter {
*/
@VisibleForTesting
void spill() throws IOException {
assert(inMemSorter != null);
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
......@@ -281,7 +278,7 @@ final class ShuffleExternalSorter {
writeSortedFile(false);
final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter = null;
shuffleMemoryManager.release(inMemSorterMemoryUsage);
taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage);
final long spillSize = freeMemory();
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
......@@ -316,9 +313,13 @@ final class ShuffleExternalSorter {
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
taskMemoryManager.freePage(block);
shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
if (inMemSorter != null) {
long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter = null;
taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
}
allocatedPages.clear();
currentPage = null;
currentPagePosition = -1;
......@@ -337,8 +338,9 @@ final class ShuffleExternalSorter {
}
}
if (inMemSorter != null) {
shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter = null;
taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
}
}
......@@ -353,21 +355,20 @@ final class ShuffleExternalSorter {
logger.debug("Attempting to expand sort pointer array");
final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray);
if (memoryAcquired < memoryToGrowPointerArray) {
shuffleMemoryManager.release(memoryAcquired);
taskMemoryManager.releaseExecutionMemory(memoryAcquired);
spill();
} else {
inMemSorter.expandPointerArray();
shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage);
}
}
}
/**
* Allocates more memory in order to insert an additional record. This will request additional
* memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
* obtained.
* memory from the memory manager and spill if the requested memory can not be obtained.
*
* @param requiredSpace the required space in the data page, in bytes, including space for storing
* the record size. This must be less than or equal to the page size (records
......@@ -386,17 +387,14 @@ final class ShuffleExternalSorter {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
pageSizeBytes + ")");
} else {
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquired < pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquired);
currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
if (currentPage == null) {
spill();
final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquiredAfterSpilling != pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
if (currentPage == null) {
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
......@@ -430,17 +428,14 @@ final class ShuffleExternalSorter {
long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
if (memoryGranted != overflowPageSize) {
shuffleMemoryManager.release(memoryGranted);
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
spill();
final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
if (memoryGrantedAfterSpill != overflowPageSize) {
shuffleMemoryManager.release(memoryGrantedAfterSpill);
overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
}
}
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
allocatedPages.add(overflowPage);
dataPage = overflowPage;
dataPagePosition = overflowPage.getBaseOffset();
......
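
The allocation sites above all follow one refactored pattern: call `allocatePage`, and on a null result spill once and retry before failing. A condensed sketch of that pattern (hypothetical helper name; shape taken from the diff):

```scala
import java.io.IOException
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.unsafe.memory.MemoryBlock

// Allocate-or-spill pattern shared by the sorters after this patch.
def allocateWithSpill(tmm: TaskMemoryManager, bytes: Long)
                     (spill: () => Unit): MemoryBlock = {
  var page = tmm.allocatePage(bytes)
  if (page == null) {
    spill()                        // release memory held by in-memory state
    page = tmm.allocatePage(bytes) // single retry after spilling
    if (page == null) {
      throw new IOException(s"Unable to acquire $bytes bytes of memory")
    }
  }
  page
}
```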
......@@ -49,12 +49,11 @@ import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
......@@ -69,7 +68,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final BlockManager blockManager;
private final IndexShuffleBlockResolver shuffleBlockResolver;
private final TaskMemoryManager memoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
......@@ -103,7 +101,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
BlockManager blockManager,
IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
ShuffleMemoryManager shuffleMemoryManager,
SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
......@@ -117,7 +114,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.blockManager = blockManager;
this.shuffleBlockResolver = shuffleBlockResolver;
this.memoryManager = memoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.shuffleId = dep.shuffleId();
......@@ -197,7 +193,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
memoryManager,
shuffleMemoryManager,
blockManager,
taskContext,
INITIAL_SORT_BUFFER_SIZE,
......
......@@ -26,7 +26,6 @@ import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
......@@ -34,7 +33,7 @@ import org.apache.spark.unsafe.bitset.BitSet;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
/**
* An append-only hash map where keys and values are contiguous regions of bytes.
......@@ -70,8 +69,6 @@ public final class BytesToBytesMap {
private final TaskMemoryManager taskMemoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
/**
* A linked list for tracking all allocated data pages so that we can free all of our memory.
*/
......@@ -169,13 +166,11 @@ public final class BytesToBytesMap {
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
this.taskMemoryManager = taskMemoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
......@@ -201,21 +196,18 @@ public final class BytesToBytesMap {
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes) {
this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
}
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
this(
taskMemoryManager,
shuffleMemoryManager,
initialCapacity,
0.70,
pageSizeBytes,
......@@ -260,7 +252,6 @@ public final class BytesToBytesMap {
if (destructive && currentPage != null) {
dataPagesIterator.remove();
this.bmap.taskMemoryManager.freePage(currentPage);
this.bmap.shuffleMemoryManager.release(currentPage.size());
}
currentPage = dataPagesIterator.next();
pageBaseObject = currentPage.getBaseObject();
......@@ -572,14 +563,12 @@ public final class BytesToBytesMap {
if (useOverflowPage) {
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
final long memoryRequested = requiredSize + 8;
final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
if (memoryGranted != memoryRequested) {
shuffleMemoryManager.release(memoryGranted);
logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
final long overflowPageSize = requiredSize + 8;
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
return false;
}
MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
dataPages.add(overflowPage);
dataPage = overflowPage;
dataPageBaseObject = overflowPage.getBaseObject();
......@@ -655,17 +644,15 @@ public final class BytesToBytesMap {
}
/**
* Acquire a new page from the {@link ShuffleMemoryManager}.
* Acquire a new page from the memory manager.
* @return whether there is enough space to allocate the new page.
*/
private boolean acquireNewPage() {
final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryGranted != pageSizeBytes) {
shuffleMemoryManager.release(memoryGranted);
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
if (newPage == null) {
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
return false;
}
MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
......@@ -705,7 +692,6 @@ public final class BytesToBytesMap {
MemoryBlock dataPage = dataPagesIterator.next();
dataPagesIterator.remove();
taskMemoryManager.freePage(dataPage);
shuffleMemoryManager.release(dataPage.size());
}
assert(dataPages.isEmpty());
}
......@@ -714,10 +700,6 @@ public final class BytesToBytesMap {
return taskMemoryManager;
}
public ShuffleMemoryManager getShuffleMemoryManager() {
return shuffleMemoryManager;
}
public long getPageSizeBytes() {
return pageSizeBytes;
}
......
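
With `ShuffleMemoryManager` removed from its constructors, `BytesToBytesMap` is now built from a `TaskMemoryManager` alone. A hedged construction sketch matching the signatures in the diff (the capacity value is arbitrary):

```scala
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.unsafe.map.BytesToBytesMap

// Post-patch constructor: (taskMemoryManager, initialCapacity, pageSizeBytes).
def newMap(tmm: TaskMemoryManager): BytesToBytesMap =
  new BytesToBytesMap(tmm, 64, tmm.pageSizeBytes())
```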
......@@ -17,9 +17,11 @@
package org.apache.spark.util.collection.unsafe.sort;
import org.apache.spark.memory.TaskMemoryManager;
final class RecordPointerAndKeyPrefix {
/**
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
* A pointer to a record; see {@link TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
public long recordPointer;
......
......@@ -32,12 +32,11 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
/**
......@@ -52,7 +51,6 @@ public final class UnsafeExternalSorter {
private final RecordComparator recordComparator;
private final int initialSize;
private final TaskMemoryManager taskMemoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private ShuffleWriteMetrics writeMetrics;
......@@ -82,7 +80,6 @@ public final class UnsafeExternalSorter {
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
......@@ -90,26 +87,24 @@ public final class UnsafeExternalSorter {
int initialSize,
long pageSizeBytes,
UnsafeInMemorySorter inMemorySorter) throws IOException {
return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
}
public static UnsafeExternalSorter create(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes) throws IOException {
return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
}
private UnsafeExternalSorter(
TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
......@@ -118,7 +113,6 @@ public final class UnsafeExternalSorter {
long pageSizeBytes,
@Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
this.taskMemoryManager = taskMemoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
......@@ -261,7 +255,6 @@ public final class UnsafeExternalSorter {
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
taskMemoryManager.freePage(block);
shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
// TODO: track in-memory sorter memory usage (SPARK-10474)
......@@ -309,8 +302,7 @@ public final class UnsafeExternalSorter {
/**
* Allocates more memory in order to insert an additional record. This will request additional
* memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
* obtained.
* memory from the memory manager and spill if the requested memory can not be obtained.
*
* @param requiredSpace the required space in the data page, in bytes, including space for storing
* the record size. This must be less than or equal to the page size (records
......@@ -335,23 +327,20 @@ public final class UnsafeExternalSorter {
}
/**
* Acquire a new page from the {@link ShuffleMemoryManager}.
* Acquire a new page from the memory manager.
*
* If there is not enough space to allocate the new page, spill all existing ones
* and try again. If there is still not enough space, report error to the caller.
*/
private void acquireNewPage() throws IOException {
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquired < pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquired);
currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
if (currentPage == null) {
spill();
final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquiredAfterSpilling != pageSizeBytes) {
shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
if (currentPage == null) {
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
......@@ -379,17 +368,14 @@ public final class UnsafeExternalSorter {
long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
if (memoryGranted != overflowPageSize) {
shuffleMemoryManager.release(memoryGranted);
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
spill();
final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
if (memoryGrantedAfterSpill != overflowPageSize) {
shuffleMemoryManager.release(memoryGrantedAfterSpill);
overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
}
}
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
allocatedPages.add(overflowPage);
dataPage = overflowPage;
dataPagePosition = overflowPage.getBaseOffset();
......@@ -441,17 +427,14 @@ public final class UnsafeExternalSorter {
long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
if (memoryGranted != overflowPageSize) {
shuffleMemoryManager.release(memoryGranted);
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
spill();
final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
if (memoryGrantedAfterSpill != overflowPageSize) {
shuffleMemoryManager.release(memoryGrantedAfterSpill);
overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
if (overflowPage == null) {
throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
}
}
MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
allocatedPages.add(overflowPage);
dataPage = overflowPage;
dataPagePosition = overflowPage.getBaseOffset();
......
......@@ -21,7 +21,7 @@ import java.util.Comparator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.Sorter;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
/**
* Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
......
......@@ -38,9 +38,8 @@ import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
/**
......@@ -70,10 +69,7 @@ class SparkEnv (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
// TODO: unify these *MemoryManager classes (SPARK-10984)
val memoryManager: MemoryManager,
val shuffleMemoryManager: ShuffleMemoryManager,
val executorMemoryManager: ExecutorMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
......@@ -340,13 +336,11 @@ object SparkEnv extends Logging {
val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false)
val memoryManager: MemoryManager =
if (useLegacyMemoryManager) {
new StaticMemoryManager(conf)
new StaticMemoryManager(conf, numUsableCores)
} else {
new UnifiedMemoryManager(conf)
new UnifiedMemoryManager(conf, numUsableCores)
}
val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores)
val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores)
val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
......@@ -405,15 +399,6 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
val executorMemoryManager: ExecutorMemoryManager = {
val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
MemoryAllocator.UNSAFE
} else {
MemoryAllocator.HEAP
}
new ExecutorMemoryManager(allocator)
}
val envInstance = new SparkEnv(
executorId,
rpcEnv,
......@@ -431,8 +416,6 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
memoryManager,
shuffleMemoryManager,
executorMemoryManager,
outputCommitCoordinator,
conf)
......
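
The choice between the two managers is driven by a single configuration flag, and the off-heap switch formerly read by `ExecutorMemoryManager` is now read by `MemoryManager` itself. A hedged configuration sketch:

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // true selects StaticMemoryManager (pre-1.6 behavior);
  // the default, false, selects UnifiedMemoryManager.
  .set("spark.memory.useLegacyMode", "false")
  // Tungsten allocation mode, now consumed by MemoryManager directly.
  .set("spark.unsafe.offHeap", "false")
```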
......@@ -21,8 +21,8 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
......
......@@ -20,9 +20,9 @@ package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
private[spark] class TaskContextImpl(
......
......@@ -29,10 +29,10 @@ import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
/**
......@@ -179,7 +179,7 @@ private[spark] class Executor(
}
override def run(): Unit = {
val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
......
......@@ -17,20 +17,38 @@
package org.apache.spark.memory
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Logging
import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.memory.MemoryAllocator
/**
* An abstract memory manager that enforces how memory is shared between execution and storage.
*
* In this context, execution memory refers to that used for computation in shuffles, joins,
* sorts and aggregations, while storage memory refers to that used for caching and propagating
* internal data across the cluster. There exists one of these per JVM.
* internal data across the cluster. There exists one MemoryManager per JVM.
*
* The MemoryManager abstract base class itself implements policies for sharing execution memory
* between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of
* some task ramping up to a large amount first and then causing others to spill to disk repeatedly.
* If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
* before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
* set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
* this set changes. This is all done by synchronizing access to mutable state and using wait() and
* notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across
* tasks was performed by the ShuffleMemoryManager.
*/
private[spark] abstract class MemoryManager extends Logging {
private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging {
// -- Methods related to memory allocation policies and bookkeeping ------------------------------
// The memory store used to evict cached blocks
private var _memoryStore: MemoryStore = _
......@@ -42,8 +60,10 @@ private[spark] abstract class MemoryManager extends Logging {
}
// Amount of execution/storage memory in use, accesses must be synchronized on `this`
protected var _executionMemoryUsed: Long = 0
protected var _storageMemoryUsed: Long = 0
@GuardedBy("this") protected var _executionMemoryUsed: Long = 0
@GuardedBy("this") protected var _storageMemoryUsed: Long = 0
// Map from taskAttemptId -> memory consumption in bytes
@GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]()
/**
* Set the [[MemoryStore]] used by this manager to evict cached blocks.
......@@ -65,15 +85,6 @@ private[spark] abstract class MemoryManager extends Logging {
// TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985)
/**
* Acquire N bytes of memory for execution, evicting cached blocks if necessary.
* Blocks evicted in the process, if any, are added to `evictedBlocks`.
* @return number of bytes successfully granted (<= N).
*/
def acquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
/**
* Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
* Blocks evicted in the process, if any, are added to `evictedBlocks`.
......@@ -102,9 +113,92 @@ private[spark] abstract class MemoryManager extends Logging {
}
/**
* Release N bytes of execution memory.
* Acquire N bytes of memory for execution, evicting cached blocks if necessary.
* Blocks evicted in the process, if any, are added to `evictedBlocks`.
* @return number of bytes successfully granted (<= N).
*/
@VisibleForTesting
private[memory] def doAcquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
/**
* Try to acquire up to `numBytes` of execution memory for the current task and return the number
* of bytes obtained, or 0 if none can be allocated.
*
* This call may block until there is enough free memory in some situations, to make sure each
* task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
* active tasks) before it is forced to spill. This can happen if the number of tasks increases
* but an older task had a lot of memory already.
*
* Subclasses should override `doAcquireExecutionMemory` in order to customize the policies
* that control global sharing of memory between execution and storage.
*/
def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
private[memory]
final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized {
assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
// Add this task to the taskMemory map just so we can keep an accurate count of the number
// of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
if (!executionMemoryForTask.contains(taskAttemptId)) {
executionMemoryForTask(taskAttemptId) = 0L
// This will later cause waiting tasks to wake up and check numTasks again
notifyAll()
}
// Once the cross-task memory allocation policy has decided to grant more memory to a task,
// this method is called in order to actually obtain that execution memory, potentially
// triggering eviction of storage memory:
def acquire(toGrant: Long): Long = synchronized {
val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks)
// Register evicted blocks, if any, with the active task metrics
Option(TaskContext.get()).foreach { tc =>
val metrics = tc.taskMetrics()
val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
}
executionMemoryForTask(taskAttemptId) += acquired
acquired
}
// Keep looping until we're either sure that we don't want to grant this request (because this
// task would have more than 1 / numActiveTasks of the memory) or we have enough free
// memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
// TODO: simplify this to limit each task to its own slot
while (true) {
val numActiveTasks = executionMemoryForTask.keys.size
val curMem = executionMemoryForTask(taskAttemptId)
val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum
// How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
// don't let it be negative
val maxToGrant =
math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem))
// Only give it as much memory as is free, which might be none if it reached 1 / numTasks
val toGrant = math.min(maxToGrant, freeMemory)
if (curMem < maxExecutionMemory / (2 * numActiveTasks)) {
// We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
// if we can't give it this much now, wait for other tasks to free up memory
// (this happens if older tasks allocated lots of memory before N grew)
if (
freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) {
return acquire(toGrant)
} else {
logInfo(
s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free")
wait()
}
} else {
return acquire(toGrant)
}
}
0L // Never reached
}
@VisibleForTesting
private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
if (numBytes > _executionMemoryUsed) {
logWarning(s"Attempted to release $numBytes bytes of execution " +
s"memory when we only have ${_executionMemoryUsed} bytes")
......@@ -114,6 +208,36 @@ private[spark] abstract class MemoryManager extends Logging {
}
}
/**
* Release numBytes of execution memory belonging to the given task.
*/
private[memory]
final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
if (curMem < numBytes) {
throw new SparkException(
s"Internal error: release called on $numBytes bytes but task only has $curMem")
}
if (executionMemoryForTask.contains(taskAttemptId)) {
executionMemoryForTask(taskAttemptId) -= numBytes
if (executionMemoryForTask(taskAttemptId) <= 0) {
executionMemoryForTask.remove(taskAttemptId)
}
releaseExecutionMemory(numBytes)
}
notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed
}
/**
* Release all memory for the given task and mark it as inactive (e.g. when a task ends).
* @return the number of bytes freed.
*/
private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized {
val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId)
releaseExecutionMemory(numBytesToFree, taskAttemptId)
numBytesToFree
}
/**
* Release N bytes of storage memory.
*/
......@@ -155,4 +279,43 @@ private[spark] abstract class MemoryManager extends Logging {
_storageMemoryUsed
}
/**
* Returns the execution memory consumption, in bytes, for the given task.
*/
private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized {
executionMemoryForTask.getOrElse(taskAttemptId, 0L)
}
// -- Fields related to Tungsten managed memory -------------------------------------------------
/**
* The default page size, in bytes.
*
* If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value
* by looking at the number of cores available to the process, and the total amount of memory,
* and then divide it by a factor of safety.
*/
val pageSizeBytes: Long = {
val minPageSize = 1L * 1024 * 1024 // 1MB
val maxPageSize = 64L * minPageSize // 64MB
val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
// Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
val safetyFactor = 16
val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor)
val default = math.min(maxPageSize, math.max(minPageSize, size))
conf.getSizeAsBytes("spark.buffer.pageSize", default)
}
/**
* Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using
* sun.misc.Unsafe.
*/
final val tungstenMemoryIsAllocatedInHeap: Boolean =
!conf.getBoolean("spark.unsafe.offHeap", false)
/**
* Allocates memory for use by Unsafe/Tungsten code.
*/
private[memory] final val tungstenMemoryAllocator: MemoryAllocator =
if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE
}
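
A worked example of the 1/2N–1/N policy that now lives in `MemoryManager`: with `maxExecutionMemory` = 1000 bytes and N = 2 active tasks, each task is capped at 1000 / 2 = 500 bytes, and a request blocks only while the task holds less than 1000 / (2 × 2) = 250 bytes and cannot be brought up to that floor. A self-contained sketch of just the grant math (blocking and storage eviction omitted):

```scala
// Simplified grant calculation from acquireExecutionMemory
// (no wait()/notifyAll(), no block eviction).
def grant(numBytes: Long, curMem: Long, totalHeld: Long,
          maxExecution: Long, numActiveTasks: Int): Long = {
  val freeMemory = maxExecution - totalHeld
  // Cap each task at 1/N of the pool, never negative.
  val maxToGrant =
    math.min(numBytes, math.max(0, maxExecution / numActiveTasks - curMem))
  math.min(maxToGrant, freeMemory) // can only grant what is actually free
}

// 1000-byte pool, two tasks, this task already holds 400 of the 700 in use:
// the 1/N cap is 1000/2 - 400 = 100, so only 100 of the 300 requested bytes
// are granted.
assert(grant(300L, curMem = 400L, totalHeld = 700L,
  maxExecution = 1000L, numActiveTasks = 2) == 100L)
```

By the same arithmetic, the default `pageSizeBytes` above works out to, e.g., nextPowerOf2(4 GiB / 8 cores / 16) = 32 MiB, clamped to the [1 MiB, 64 MiB] range.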
......@@ -33,14 +33,16 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
private[spark] class StaticMemoryManager(
conf: SparkConf,
override val maxExecutionMemory: Long,
override val maxStorageMemory: Long)
extends MemoryManager {
override val maxStorageMemory: Long,
numCores: Int)
extends MemoryManager(conf, numCores) {
def this(conf: SparkConf) {
def this(conf: SparkConf, numCores: Int) {
this(
conf,
StaticMemoryManager.getMaxExecutionMemory(conf),
StaticMemoryManager.getMaxStorageMemory(conf))
StaticMemoryManager.getMaxStorageMemory(conf),
numCores)
}
// Max number of bytes worth of blocks to evict when unrolling
......@@ -52,7 +54,7 @@ private[spark] class StaticMemoryManager(
* Acquire N bytes of memory for execution.
* @return number of bytes successfully granted (<= N).
*/
override def acquireExecutionMemory(
override def doAcquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
assert(numBytes >= 0)
......
......@@ -42,10 +42,14 @@ import org.apache.spark.storage.{BlockStatus, BlockId}
* up most of the storage space, in which case the new blocks will be evicted immediately
* according to their respective storage levels.
*/
private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager {
private[spark] class UnifiedMemoryManager(
conf: SparkConf,
maxMemory: Long,
numCores: Int)
extends MemoryManager(conf, numCores) {
def this(conf: SparkConf) {
this(conf, UnifiedMemoryManager.getMaxMemory(conf))
def this(conf: SparkConf, numCores: Int) {
this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores)
}
/**
......@@ -91,7 +95,7 @@ private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) exte
* Blocks evicted in the process, if any, are added to `evictedBlocks`.
* @return number of bytes successfully granted (<= N).
*/
override def acquireExecutionMemory(
private[memory] override def doAcquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
assert(numBytes >= 0)
......
......@@ -25,8 +25,8 @@ import scala.collection.mutable.HashMap
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
......@@ -89,10 +89,6 @@ private[spark] abstract class Task[T](
} finally {
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for shuffles
SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
}
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
......
......@@ -98,13 +98,14 @@ private[spark] class BlockStoreShuffleReader[K, C](
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
sorter.iterator
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
......
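
The reader change above exists because `ExternalSorter` now holds per-task execution memory that must be released when iteration finishes; wrapping the result in `CompletionIterator` runs `stop()` exactly once at exhaustion. A hedged usage sketch of that existing utility (it is `private[spark]`, so this assumes code compiled inside the `org.apache.spark` package):

```scala
import org.apache.spark.util.CompletionIterator

// The completion callback fires once, when the wrapped iterator is drained.
val rows: Iterator[Int] = Iterator(1, 2, 3)
val withCleanup =
  CompletionIterator[Int, Iterator[Int]](rows, println("resources released"))
withCleanup.foreach(println) // prints 1, 2, 3, then "resources released"
```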
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.shuffle
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import com.google.common.annotations.VisibleForTesting
import org.apache.spark._
import org.apache.spark.memory.{StaticMemoryManager, MemoryManager}
import org.apache.spark.storage.{BlockId, BlockStatus}
import org.apache.spark.unsafe.array.ByteArrayMethods
/**
* Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
* collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
* from this pool and release it as it spills data out. When a task ends, all its memory will be
* released by the Executor.
*
* This class tries to ensure that each task gets a reasonable share of memory, instead of some
* task ramping up to a large amount first and then causing others to spill to disk repeatedly.
* If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
* before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
* set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
* this set changes. This is all done by synchronizing access to `memoryManager` to mutate state
* and using wait() and notifyAll() to signal changes.
*
* Use `ShuffleMemoryManager.create()` factory method to create a new instance.
*
* @param memoryManager the interface through which this manager acquires execution memory
* @param pageSizeBytes number of bytes for each page, by default.
*/
private[spark]
class ShuffleMemoryManager protected (
memoryManager: MemoryManager,
val pageSizeBytes: Long)
extends Logging {
private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes
private def currentTaskAttemptId(): Long = {
// In case this is called on the driver, return an invalid task attempt id.
Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
}
/**
* Try to acquire up to numBytes memory for the current task, and return the number of bytes
* obtained, or 0 if none can be allocated. This call may block until there is enough free memory
* in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
* total memory pool (where N is the # of active tasks) before it is forced to spill. This can
* happen if the number of tasks increases but an older task had a lot of memory already.
*/
def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized {
val taskAttemptId = currentTaskAttemptId()
assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
// Add this task to the taskMemory map just so we can keep an accurate count of the number
// of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
if (!taskMemory.contains(taskAttemptId)) {
taskMemory(taskAttemptId) = 0L
// This will later cause waiting tasks to wake up and check numTasks again
memoryManager.notifyAll()
}
// Keep looping until we're either sure that we don't want to grant this request (because this
// task would have more than 1 / numActiveTasks of the memory) or we have enough free
// memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
// TODO: simplify this to limit each task to its own slot
while (true) {
val numActiveTasks = taskMemory.keys.size
val curMem = taskMemory(taskAttemptId)
val maxMemory = memoryManager.maxExecutionMemory
val freeMemory = maxMemory - taskMemory.values.sum
// How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
// don't let it be negative
val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
// Only give it as much memory as is free, which might be none if it reached 1 / numTasks
val toGrant = math.min(maxToGrant, freeMemory)
if (curMem < maxMemory / (2 * numActiveTasks)) {
// We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
// if we can't give it this much now, wait for other tasks to free up memory
// (this happens if older tasks allocated lots of memory before N grew)
if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
return acquire(toGrant)
} else {
logInfo(
s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
memoryManager.wait()
}
} else {
return acquire(toGrant)
}
}
0L // Never reached
}
/**
* Acquire N bytes of execution memory from the memory manager for the current task.
* @return number of bytes actually acquired (<= N).
*/
private def acquire(numBytes: Long): Long = memoryManager.synchronized {
val taskAttemptId = currentTaskAttemptId()
val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks)
// Register evicted blocks, if any, with the active task metrics
// TODO: just do this in `acquireExecutionMemory` (SPARK-10985)
Option(TaskContext.get()).foreach { tc =>
val metrics = tc.taskMetrics()
val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
}
taskMemory(taskAttemptId) += acquired
acquired
}
/** Release numBytes bytes for the current task. */
def release(numBytes: Long): Unit = memoryManager.synchronized {
val taskAttemptId = currentTaskAttemptId()
val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
if (curMem < numBytes) {
throw new SparkException(
s"Internal error: release called on $numBytes bytes but task only has $curMem")
}
if (taskMemory.contains(taskAttemptId)) {
taskMemory(taskAttemptId) -= numBytes
memoryManager.releaseExecutionMemory(numBytes)
}
memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
}
/** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
def releaseMemoryForThisTask(): Unit = memoryManager.synchronized {
val taskAttemptId = currentTaskAttemptId()
taskMemory.remove(taskAttemptId).foreach { numBytes =>
memoryManager.releaseExecutionMemory(numBytes)
}
memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
}
/** Returns the memory consumption, in bytes, for the current task */
def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized {
val taskAttemptId = currentTaskAttemptId()
taskMemory.getOrElse(taskAttemptId, 0L)
}
}
private[spark] object ShuffleMemoryManager {
def create(
conf: SparkConf,
memoryManager: MemoryManager,
numCores: Int): ShuffleMemoryManager = {
val maxMemory = memoryManager.maxExecutionMemory
val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores)
new ShuffleMemoryManager(memoryManager, pageSize)
}
/**
* Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size.
*/
def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = {
val conf = new SparkConf
val memoryManager = new StaticMemoryManager(
conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue)
new ShuffleMemoryManager(memoryManager, pageSizeBytes)
}
@VisibleForTesting
def createForTesting(maxMemory: Long): ShuffleMemoryManager = {
create(maxMemory, 4 * 1024 * 1024)
}
/**
* Sets the page size, in bytes.
*
* If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value
* by looking at the number of cores available to the process, and the total amount of memory,
* and then divide it by a factor of safety.
*/
private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = {
val minPageSize = 1L * 1024 * 1024 // 1MB
val maxPageSize = 64L * minPageSize // 64MB
val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
// Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
val safetyFactor = 16
val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor)
val default = math.min(maxPageSize, math.max(minPageSize, size))
conf.getSizeAsBytes("spark.buffer.pageSize", default)
}
}
......@@ -133,7 +133,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
env.shuffleMemoryManager,
unsafeShuffleHandle,
mapId,
context,
......
......@@ -52,13 +52,13 @@ private[spark] class SortShuffleWriter[K, V, C](
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)
......@@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C](
// (see SPARK-3570).
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
......