diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 008799cc7739526722f35665b319b4bddb599a21..8fbdb72832adffc54f504cd19a241746582ea27a 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -20,6 +20,7 @@ package org.apache.spark.memory; import java.io.IOException; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -28,9 +29,9 @@ import org.apache.spark.unsafe.memory.MemoryBlock; */ public abstract class MemoryConsumer { - private final TaskMemoryManager taskMemoryManager; + protected final TaskMemoryManager taskMemoryManager; private final long pageSize; - private long used; + protected long used; protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { this.taskMemoryManager = taskMemoryManager; @@ -74,26 +75,29 @@ public abstract class MemoryConsumer { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Acquire `size` bytes memory. - * - * If there is not enough memory, throws OutOfMemoryError. + * Allocates a LongArray of `size`. 
*/ - protected void acquireMemory(long size) { - long got = taskMemoryManager.acquireExecutionMemory(size, this); - if (got < size) { - taskMemoryManager.releaseExecutionMemory(got, this); + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); } - used += got; + used += required; + return new LongArray(page); } /** - * Release `size` bytes memory. + * Frees a LongArray. */ - protected void releaseMemory(long size) { - used -= size; - taskMemoryManager.releaseExecutionMemory(size, this); + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); } /** @@ -109,7 +113,7 @@ public abstract class MemoryConsumer { long got = 0; if (page != null) { got = page.size(); - freePage(page); + taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 4230575446d310e43c84413ef23562b0f9db258c..6440f9c0f30de8b35a15b51faef6269209c986a7 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -137,7 +137,7 @@ public class TaskMemoryManager { if (got < required) { // Call spill() on other consumers to release memory for (MemoryConsumer c: consumers) { - if (c != null && c != consumer && c.getUsed() > 0) { + if (c != consumer && c.getUsed() > 0) { try { 
long released = c.spill(required - got, consumer); if (released > 0) { @@ -173,7 +173,9 @@ public class TaskMemoryManager { } } - consumers.add(consumer); + if (consumer != null) { + consumers.add(consumer); + } logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); return got; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 400d8520019b9639dd4618fce0c30d1589327f23..9affff80143d71eaa8311694bca9a7b13086d1dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -114,8 +115,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - acquireMemory(initialSize * 8L); - this.inMemSorter = new ShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); this.peakMemoryUsedBytes = getMemoryUsage(); } @@ -301,9 +301,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { public void cleanupResources() { freeMemory(); if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { @@ -321,9 +320,10 @@ final class ShuffleExternalSorter extends MemoryConsumer { 
assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling assert(inMemSorter.hasSpaceForAnotherRecord()); @@ -331,16 +331,9 @@ final class ShuffleExternalSorter extends MemoryConsumer { } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -404,9 +397,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { // Do not count the final file towards the spill count. 
writeSortedFile(true); freeMemory(); - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index e630575d1ae19fd2eb0cb26adddaad4cc0a1589a..58ad88e1ed87bde0a5a5b5d913923aa182410dec 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -19,11 +19,14 @@ package org.apache.spark.shuffle.sort; import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; final class ShuffleInMemorySorter { - private final Sorter<PackedRecordPointer, long[]> sorter; + private final Sorter<PackedRecordPointer, LongArray> sorter; private static final class SortComparator implements Comparator<PackedRecordPointer> { @Override public int compare(PackedRecordPointer left, PackedRecordPointer right) { @@ -32,24 +35,34 @@ final class ShuffleInMemorySorter { } private static final SortComparator SORT_COMPARATOR = new SortComparator(); + private final MemoryConsumer consumer; + /** * An array of record pointers and partition ids that have been encoded by * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] array; + private LongArray array; /** * The position in the pointer array where new records can be inserted. 
*/ private int pos = 0; - public ShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + this.consumer = consumer; assert (initialSize > 0); - this.array = new long[initialSize]; + this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } + public void free() { + if (array != null) { + consumer.freeArray(array); + array = null; + } + } + public int numRecords() { return pos; } @@ -58,30 +71,25 @@ final class ShuffleInMemorySorter { pos = 0; } - private int newLength() { - // Guard against overflow: - return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE; - } - - /** - * Returns the memory needed to expand - */ - public long getMemoryToExpand() { - return ((long) (newLength() - array.length)) * 8; - } - - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + assert(newArray.size() > array.size()); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L + ); + consumer.freeArray(array); + array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pos < array.length; + return pos < array.size(); } public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } /** @@ -96,14 +104,9 @@ final class ShuffleInMemorySorter { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (array.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Sort pointer array has reached maximum size"); - } else { - expandPointerArray(); - } + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - array[pos] = - PackedRecordPointer.packPointer(recordPointer, partitionId); + 
array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); pos++; } @@ -112,12 +115,12 @@ final class ShuffleInMemorySorter { */ public static final class ShuffleSorterIterator { - private final long[] pointerArray; + private final LongArray pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public ShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, LongArray pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -127,7 +130,7 @@ final class ShuffleInMemorySorter { } public void loadNext() { - packedRecordPointer.set(pointerArray[position]); + packedRecordPointer.set(pointerArray.get(position)); position++; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 8a1e5aec6ff0e4ce9dc2d656e71b9fc2dd51e3f2..8f4e3229976dc8ea4743564f9a1e8f1114f419ad 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,16 +17,19 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; -final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> { +final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> { public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); private ShuffleSortDataFormat() { } @Override - public PackedRecordPointer getKey(long[] data, int pos) { + public PackedRecordPointer getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. 
throw new UnsupportedOperationException(); } @@ -37,31 +40,38 @@ final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, lo } @Override - public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { - reuse.set(data[pos]); + public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) { + reuse.set(data.get(pos)); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - final long temp = data[pos0]; - data[pos0] = data[pos1]; - data[pos1] = temp; + public void swap(LongArray data, int pos0, int pos1) { + final long temp = data.get(pos0); + data.set(pos0, data.get(pos1)); + data.set(pos1, temp); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos] = src[srcPos]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos, src.get(srcPos)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos, dst, dstPos, length); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8, + length * 8 + ); } @Override - public long[] allocate(int length) { - return new long[length]; + public LongArray allocate(int length) { + // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap. 
+ return new LongArray(MemoryBlock.fromLongArray(new long[length])); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6656fd1d0bc5973e985820bb9928ade33a697df8..04694dc54418c8501d0424bdb155f859943d089c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -20,7 +20,6 @@ package org.apache.spark.unsafe.map; import javax.annotation.Nullable; import java.io.File; import java.io.IOException; -import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; @@ -724,11 +723,10 @@ public final class BytesToBytesMap extends MemoryConsumer { */ private void allocate(int capacity) { assert (capacity >= 0); - // The capacity needs to be divisible by 64 so that our bit set can be sized properly capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - acquireMemory(capacity * 16); - longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); + longArray = allocateArray(capacity * 2); + longArray.zeroOut(); this.growthThreshold = (int) (capacity * loadFactor); this.mask = capacity - 1; @@ -743,9 +741,8 @@ public final class BytesToBytesMap extends MemoryConsumer { public void free() { updatePeakMemoryUsed(); if (longArray != null) { - long used = longArray.memoryBlock().size(); + freeArray(longArray); longArray = null; - releaseMemory(used); } Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { @@ -834,9 +831,9 @@ public final class BytesToBytesMap extends MemoryConsumer { /** * Returns the underline long[] of longArray. 
*/ - public long[] getArray() { + public LongArray getArray() { assert(longArray != null); - return (long[]) longArray.memoryBlock().getBaseObject(); + return longArray; } /** @@ -844,7 +841,8 @@ public final class BytesToBytesMap extends MemoryConsumer { */ public void reset() { numElements = 0; - Arrays.fill(getArray(), 0); + longArray.zeroOut(); + while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); @@ -887,7 +885,7 @@ public final class BytesToBytesMap extends MemoryConsumer { longArray.set(newPos * 2, keyPointer); longArray.set(newPos * 2 + 1, hashcode); } - releaseMemory(oldLongArray.memoryBlock().size()); + freeArray(oldLongArray); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index cba043bc48cc877a382c0b89a1a7a9fd109f5aa9..9a7b2ad06cab68428b6dc8a376f776ae1e8c36e2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,6 +32,7 @@ import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; @@ -123,9 +124,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer { this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { - this.inMemSorter = - new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); - acquireMemory(inMemSorter.getMemoryUsage()); + 
this.inMemSorter = new UnsafeInMemorySorter( + this, taskMemoryManager, recordComparator, prefixComparator, initialSize); } else { this.inMemSorter = existingInMemorySorter; } @@ -277,9 +277,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer { deleteSpillFiles(); freeMemory(); if (inMemSorter != null) { - long used = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(used); } } } @@ -293,9 +292,10 @@ public final class UnsafeExternalSorter extends MemoryConsumer { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling assert(inMemSorter.hasSpaceForAnotherRecord()); @@ -303,16 +303,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer { } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -498,9 +491,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer { nextUpstream = null; assert(inMemSorter != null); - long used = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(used); } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index d57213b9b8bfc839208da0fc1ee51119ea714d61..a218ad4623f463e6e850fa9d305aca088ca14ecb 100644 --- 
a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,8 +19,10 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; /** @@ -62,15 +64,16 @@ public final class UnsafeInMemorySorter { } } + private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; - private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter; + private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter; private final Comparator<RecordPointerAndKeyPrefix> sortComparator; /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] array; + private LongArray array; /** * The position in the sort buffer where new records can be inserted. 
@@ -78,22 +81,33 @@ public final class UnsafeInMemorySorter { private int pos = 0; public UnsafeInMemorySorter( + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, int initialSize) { - this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]); + this(consumer, memoryManager, recordComparator, prefixComparator, + consumer.allocateArray(initialSize * 2)); } public UnsafeInMemorySorter( + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - long[] array) { - this.array = array; + LongArray array) { + this.consumer = consumer; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + this.array = array; + } + + /** + * Free the memory used by pointer array. + */ + public void free() { + consumer.freeArray(array); } public void reset() { @@ -107,26 +121,26 @@ public final class UnsafeInMemorySorter { return pos / 2; } - private int newLength() { - return array.length < Integer.MAX_VALUE / 2 ? 
(array.length * 2) : Integer.MAX_VALUE; - } - - public long getMemoryToExpand() { - return (long) (newLength() - array.length) * 8L; - } - public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } public boolean hasSpaceForAnotherRecord() { - return pos + 2 <= array.length; + return pos + 2 <= array.size(); } - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + if (newArray.size() < array.size()) { + throw new OutOfMemoryError("Not enough memory to grow pointer array"); + } + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L); + consumer.freeArray(array); + array = newArray; } /** @@ -138,11 +152,11 @@ public final class UnsafeInMemorySorter { */ public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(); + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - array[pos] = recordPointer; + array.set(pos, recordPointer); pos++; - array[pos] = keyPrefix; + array.set(pos, keyPrefix); pos++; } @@ -150,7 +164,7 @@ public final class UnsafeInMemorySorter { private final TaskMemoryManager memoryManager; private final int sortBufferInsertPosition; - private final long[] sortBuffer; + private final LongArray sortBuffer; private int position = 0; private Object baseObject; private long baseOffset; @@ -160,7 +174,7 @@ public final class UnsafeInMemorySorter { private SortedIterator( TaskMemoryManager memoryManager, int sortBufferInsertPosition, - long[] sortBuffer) { + LongArray sortBuffer) { this.memoryManager = memoryManager; this.sortBufferInsertPosition = sortBufferInsertPosition; this.sortBuffer = sortBuffer; @@ -188,11 +202,11 @@ public final class UnsafeInMemorySorter { @Override public void loadNext() { // 
This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer[position]; + final long recordPointer = sortBuffer.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer[position + 1]; + keyPrefix = sortBuffer.get(position + 1); position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d09c728a7a638a8f1ff2663762cee8b8a823fabb..d3137f5f31c25cfaaa2591201a26eca0091729d4 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -17,6 +17,9 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; /** @@ -26,14 +29,14 @@ import org.apache.spark.util.collection.SortDataFormat; * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. 
*/ -final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> { +final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @@ -44,37 +47,43 @@ final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefi } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { - reuse.recordPointer = data[pos * 2]; - reuse.keyPrefix = data[pos * 2 + 1]; + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data.get(pos * 2); + reuse.keyPrefix = data.get(pos * 2 + 1); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - long tempPointer = data[pos0 * 2]; - long tempKeyPrefix = data[pos0 * 2 + 1]; - data[pos0 * 2] = data[pos1 * 2]; - data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; - data[pos1 * 2] = tempPointer; - data[pos1 * 2 + 1] = tempKeyPrefix; + public void swap(LongArray data, int pos0, int pos1) { + long tempPointer = data.get(pos0 * 2); + long tempKeyPrefix = data.get(pos0 * 2 + 1); + data.set(pos0 * 2, data.get(pos1 * 2)); + data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1)); + data.set(pos1 * 2, tempPointer); + data.set(pos1 * 2 + 1, tempKeyPrefix); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos * 2] = src[srcPos * 2]; - dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos * 2, src.get(srcPos * 2)); + dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 
1)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 16, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 16, + length * 16); } @Override - public long[] allocate(int length) { + public LongArray allocate(int length) { assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; - return new long[length * 2]; + // This is used as temporary buffer, it's fine to allocate from JVM heap. + return new LongArray(MemoryBlock.fromLongArray(new long[length * 2])); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index dab7b0592cb4e4124f6b6e62b90f503c39e63f81..c731317395612d773f60fbea8935960bc1e473fc 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.memory; -import java.io.IOException; - import org.junit.Assert; import org.junit.Test; @@ -27,27 +25,6 @@ import org.apache.spark.unsafe.memory.MemoryBlock; public class TaskMemoryManagerSuite { - class TestMemoryConsumer extends MemoryConsumer { - TestMemoryConsumer(TaskMemoryManager memoryManager) { - super(memoryManager); - } - - @Override - public long spill(long size, MemoryConsumer trigger) throws IOException { - long used = getUsed(); - releaseMemory(used); - return used; - } - - void use(long size) { - acquireMemory(size); - } - - void free(long size) { - releaseMemory(size); - } - } - @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( diff --git 
a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java new file mode 100644 index 0000000000000000000000000000000000000000..8ae36427385093a363dfaeb0a68ce755f00955ac --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.memory; + +import java.io.IOException; + +public class TestMemoryConsumer extends MemoryConsumer { + public TestMemoryConsumer(TaskMemoryManager memoryManager) { + super(memoryManager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + long used = getUsed(); + free(used); + return used; + } + + void use(long size) { + long got = taskMemoryManager.acquireExecutionMemory(size, this); + used += got; + } + + void free(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory(size, this); + } +} + + diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 2293b1bbc113e30cf6372134753aef56bf98749a..faa5a863ee63037d2f063cbd821d5340779436f8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -25,13 +25,19 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; public class ShuffleInMemorySorterSuite { + final TestMemoryManager memoryManager = + new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; 
Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); @@ -40,7 +46,7 @@ public class ShuffleInMemorySorterSuite { @Test public void testSortingEmptyInput() { - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100); final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -63,7 +69,7 @@ public class ShuffleInMemorySorterSuite { new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -104,7 +110,7 @@ public class ShuffleInMemorySorterSuite { @Test public void testSortingManyNumbers() throws Exception { - ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index cfead0e5924b88b9062c440f96ab4cb2cc947984..11c3a7be388759048b4b5f795077dfe762f24d5c 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -390,7 +390,6 @@ public class UnsafeExternalSorterSuite { for (int i = 0; i < numRecordsPerPage * 10; i++) { 
insertNumber(sorter, i); newPeakMemory = sorter.getPeakMemoryUsedBytes(); - // The first page is pre-allocated on instantiation if (i % numRecordsPerPage == 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 642f6585f8a15982fc1df51109225b7b93723ec6..a203a09648ac0eded8d5243c4fd61c588ebcabf7 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -23,6 +23,7 @@ import org.junit.Test; import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -44,9 +45,11 @@ public class UnsafeInMemorySorterSuite { @Test public void testSortingEmptyInput() { - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0), + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, + memoryManager, mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -69,6 +72,7 @@ public class UnsafeInMemorySorterSuite { }; final TaskMemoryManager memoryManager = new TaskMemoryManager( new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final 
TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: @@ -102,7 +106,7 @@ public class UnsafeInMemorySorterSuite { return (int) prefix1 - (int) prefix2; } }; - UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator, prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index e2898ef2e215839823cfb99318ff95543c209b72..8c9b9c85e37fc265406d7aed03080c15e8bff074 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,8 +85,9 @@ public final class UnsafeKVExternalSorter { } else { // During spilling, the array in map will not be used, so we can borrow that and use it // as the underline array for in-memory sorter (it's always large enough). + // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.getArray()); + null, taskMemoryManager, recordComparator, prefixComparator, map.getArray()); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. 
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 74105050e419101269123445c28fc5086ece7b7c..1a3cdff638264d39ee18d75804833dad957e8d51 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -39,7 +39,6 @@ public final class LongArray { private final long length; public LongArray(MemoryBlock memory) { - assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; this.memory = memory; this.baseObj = memory.getBaseObject(); @@ -51,6 +50,14 @@ public final class LongArray { return memory; } + public Object getBaseObject() { + return baseObj; + } + + public long getBaseOffset() { + return baseOffset; + } + /** * Returns the number of elements this array can hold. */ @@ -58,6 +65,15 @@ public final class LongArray { return length; } + /** + * Fills this array with 0L. + */ + public void zeroOut() { + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); + } + } + /** * Sets the value at position {@code index}. */ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 5974cf91ff993cace8949c6611a39e4f3649b27e..fb8e53b3348f364853c838c61ca1851f1d5204df 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -34,5 +34,9 @@ public class LongArraySuite { Assert.assertEquals(2, arr.size()); Assert.assertEquals(1L, arr.get(0)); Assert.assertEquals(3L, arr.get(1)); + + arr.zeroOut(); + Assert.assertEquals(0L, arr.get(0)); + Assert.assertEquals(0L, arr.get(1)); } }