diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 968185bde78abfa33cc20c79a8599acfb8b47134..117745f9a9c00f63d31764f11b77eecdfad227d9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -101,7 +101,7 @@ public class UnsafeExternalSorterSuite { public void setUp() { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); - tempDir = new File(Utils.createTempDir$default$1()); + tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); @@ -143,13 +143,18 @@ public class UnsafeExternalSorterSuite { @After public void tearDown() { - long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - assertEquals(0L, leakedShuffleMemory); + try { + long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (shuffleMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + shuffleMemoryManager = null; + assertEquals(0L, leakedShuffleMemory); + } + assertEquals(0, leakedUnsafeMemory); + } finally { + Utils.deleteRecursively(tempDir); + tempDir = null; } - assertEquals(0, leakedUnsafeMemory); } private void assertSpillFilesWereCleanedUp() { @@ -234,7 +239,7 @@ public class UnsafeExternalSorterSuite { public void spillingOccursInResponseToMemoryPressure() throws Exception { shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2); final UnsafeExternalSorter sorter = newSorter(); - final int numRecords = 100000; + final int numRecords = (int) pageSizeBytes / 4; for (int i = 0; i <= numRecords; i++) { insertNumber(sorter, numRecords - i); }