diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 07241c827c2ae5e5017a2b6c3bbb999a668be07f..6656fd1d0bc5973e985820bb9928ade33a697df8 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -20,6 +20,7 @@ package org.apache.spark.unsafe.map;
 import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Iterator;
 import java.util.LinkedList;

@@ -638,7 +639,11 @@ public final class BytesToBytesMap extends MemoryConsumer {
       assert (valueLength % 8 == 0);
       assert(longArray != null);
-      if (numElements == MAX_CAPACITY || !canGrowArray) {
+
+      if (numElements == MAX_CAPACITY
+        // The map could be reused from the last spill (because there was not enough memory to grow),
+        // so we don't try to grow again once we hit the `growthThreshold`.
+        || !canGrowArray && numElements > growthThreshold) {
         return false;
       }

@@ -730,25 +735,18 @@ public final class BytesToBytesMap extends MemoryConsumer {
   }

   /**
-   * Free the memory used by longArray.
+   * Free all allocated memory associated with this map, including the storage for keys and values
+   * as well as the hash map array itself.
+   *
+   * This method is idempotent and can be called multiple times.
    */
-  public void freeArray() {
+  public void free() {
     updatePeakMemoryUsed();
     if (longArray != null) {
       long used = longArray.memoryBlock().size();
       longArray = null;
       releaseMemory(used);
     }
-  }
-
-  /**
-   * Free all allocated memory associated with this map, including the storage for keys and values
-   * as well as the hash map array itself.
-   *
-   * This method is idempotent and can be called multiple times.
-   */
-  public void free() {
-    freeArray();
     Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
     while (dataPagesIterator.hasNext()) {
       MemoryBlock dataPage = dataPagesIterator.next();
@@ -833,6 +831,28 @@ public final class BytesToBytesMap extends MemoryConsumer {
     return dataPages.size();
   }

+  /**
+   * Returns the underlying long[] of longArray.
+   */
+  public long[] getArray() {
+    assert(longArray != null);
+    return (long[]) longArray.memoryBlock().getBaseObject();
+  }
+
+  /**
+   * Resets this map to its initial state.
+   */
+  public void reset() {
+    numElements = 0;
+    Arrays.fill(getArray(), 0);
+    while (dataPages.size() > 0) {
+      MemoryBlock dataPage = dataPages.removeLast();
+      freePage(dataPage);
+    }
+    currentPage = null;
+    pageCursor = 0;
+  }
+
   /**
    * Grows the size of the hash table and re-hash everything.
    */
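To make the intent of the new getArray()/reset()/free() trio concrete, here is a minimal sketch of the lifecycle they are meant to support (a hedged illustration, not code from this patch; the surrounding sort-and-spill logic is elided and only methods visible in the diff above are called):

    long[] pointerArray = map.getArray(); // borrow the map's array while its records are being spilled
    // ... an external sorter sorts and spills the map's records using `pointerArray` ...
    map.reset();                          // zero the array, free all data pages, numElements back to 0
    // ... the same map instance can now accept new records ...
    map.free();                           // still required once the task is done; free() is idempotent, so a second call is safe
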
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 509fb0a044c0c43e220f5cb03e176093bda1c576..cba043bc48cc877a382c0b89a1a7a9fd109f5aa9 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -79,9 +79,13 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       PrefixComparator prefixComparator,
       int initialSize,
       long pageSizeBytes,
-      UnsafeInMemorySorter inMemorySorter) {
-    return new UnsafeExternalSorter(taskMemoryManager, blockManager,
+      UnsafeInMemorySorter inMemorySorter) throws IOException {
+    UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
       taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+    sorter.spill(Long.MAX_VALUE, sorter);
+    // The external sorter will be used to insert records; the in-memory sorter is not needed.
+    sorter.inMemSorter = null;
+    return sorter;
   }

   public static UnsafeExternalSorter create(
@@ -124,7 +128,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
       acquireMemory(inMemSorter.getMemoryUsage());
     } else {
       this.inMemSorter = existingInMemorySorter;
-      // will acquire after free the map
     }

     this.peakMemoryUsedBytes = getMemoryUsage();
@@ -157,12 +160,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
    */
   @Override
   public long spill(long size, MemoryConsumer trigger) throws IOException {
-    assert(inMemSorter != null);
     if (trigger != this) {
       if (readingIterator != null) {
         return readingIterator.spill();
-      } else {
-
       }
       return 0L; // this should throw exception
     }
@@ -388,25 +388,38 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     inMemSorter.insertRecord(recordAddress, prefix);
   }

+  /**
+   * Merges another UnsafeExternalSorter into this one; the other sorter will be emptied.
+   *
+   * @throws IOException
+   */
+  public void merge(UnsafeExternalSorter other) throws IOException {
+    other.spill();
+    spillWriters.addAll(other.spillWriters);
+    // Remove them from the other's `spillWriters`, or the files would be deleted by `cleanupResources`.
+    other.spillWriters.clear();
+    other.cleanupResources();
+  }
+
   /**
    * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
    * after consuming this iterator.
    */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
-    assert(inMemSorter != null);
-    readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
-    int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0);
     if (spillWriters.isEmpty()) {
+      assert(inMemSorter != null);
+      readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
       return readingIterator;
     } else {
       final UnsafeSorterSpillMerger spillMerger =
-        new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
+        new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
       for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
         spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
       }
-      spillWriters.clear();
-      spillMerger.addSpillIfNotEmpty(readingIterator);
-
+      if (inMemSorter != null) {
+        readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+        spillMerger.addSpillIfNotEmpty(readingIterator);
+      }
       return spillMerger.getSortedIterator();
     }
   }
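A small usage sketch for the new merge() (hedged: the names sorterA/sorterB are illustrative, and the setup of the task memory manager, block manager and comparators is omitted):

    // Both sorters belong to the same task and were created via one of the factory methods above.
    sorterA.merge(sorterB);                              // spills sorterB, then takes over its spill files
    UnsafeSorterIterator sorted = sorterA.getSortedIterator();
    // sorterB was emptied and cleaned up inside merge(); only sorterA still needs cleanupResources().
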
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 1480f0681ed9ca9e60772a2b7c984243cec37464..d57213b9b8bfc839208da0fc1ee51119ea714d61 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,9 +19,9 @@ package org.apache.spark.util.collection.unsafe.sort;

 import java.util.Comparator;

+import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.collection.Sorter;
-import org.apache.spark.memory.TaskMemoryManager;

 /**
  * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -77,13 +77,20 @@ public final class UnsafeInMemorySorter {
    */
   private int pos = 0;

+  public UnsafeInMemorySorter(
+      final TaskMemoryManager memoryManager,
+      final RecordComparator recordComparator,
+      final PrefixComparator prefixComparator,
+      int initialSize) {
+    this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
+  }
+
   public UnsafeInMemorySorter(
       final TaskMemoryManager memoryManager,
       final RecordComparator recordComparator,
       final PrefixComparator prefixComparator,
-      int initialSize) {
-    assert (initialSize > 0);
-    this.array = new long[initialSize * 2];
+      long[] array) {
+    this.array = array;
     this.memoryManager = memoryManager;
     this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index d4b6d75b4d981b0d24a8755e42fd13b60c229eb3..a2f99d566d4711caef021ddbe9230d6dd6b68283 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -236,16 +236,13 @@ public final class UnsafeFixedWidthAggregationMap {

   /**
    * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]]
-   * that can be used to insert more records to do external sorting.
    *
-   * The only memory that is allocated is the address/prefix array, 16 bytes per record.
-   *
-   * Note that this destroys the map, and as a result, the map cannot be used anymore after this.
+   * Note that the map will be reset so that new records can be inserted into it, and the returned
+   * sorter can NOT be used to insert records.
    */
   public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
-    UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
+    return new UnsafeKVExternalSorter(
       groupingKeySchema, aggregationBufferSchema,
       SparkEnv.get().blockManager(), map.getPageSizeBytes(), map);
-    return sorter;
   }
 }
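A rough sketch of the caller pattern this new contract enables (hedged: the map/schema setup and the input source are omitted, `rows` and `extractGroupingKey` are illustrative stand-ins, and only methods that appear in this diff are called on the map and sorter):

    UnsafeKVExternalSorter sorter = null;
    while (rows.hasNext()) {
      UnsafeRow key = extractGroupingKey(rows.next());             // hypothetical helper
      UnsafeRow buffer = map.getAggregationBufferFromUnsafeRow(key);
      if (buffer == null) {                                        // the map could not allocate memory
        UnsafeKVExternalSorter spilled = map.destructAndCreateExternalSorter();  // sort, spill, reset the map
        if (sorter == null) {
          sorter = spilled;
        } else {
          sorter.merge(spilled);                                   // accumulate all spills in one sorter
        }
        buffer = map.getAggregationBufferFromUnsafeRow(key);       // retry against the now-empty map
      }
      // ... update the aggregation buffer in place ...
    }

New records always go back into the reused map rather than into the returned sorter, which is exactly the pattern TungstenAggregationIterator adopts later in this diff.
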
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 845f2ae6859b70fca0e98001588c771152d61ec0..e2898ef2e215839823cfb99318ff95543c209b72 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -83,11 +83,10 @@ public final class UnsafeKVExternalSorter {
         /* initialSize */ 4096,
         pageSizeBytes);
     } else {
-      // The memory needed for UnsafeInMemorySorter should be less than longArray in map.
-      map.freeArray();
-      // The memory used by UnsafeInMemorySorter will be counted later (end of this block)
+      // During spilling, the map's array will not be used, so we can borrow it and use it
+      // as the underlying array for the in-memory sorter (it is always large enough).
       final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
-        taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
+        taskMemoryManager, recordComparator, prefixComparator, map.getArray());

       // We cannot use the destructive iterator here because we are reusing the existing memory
       // pages in BytesToBytesMap to hold records during sorting.
@@ -123,10 +122,9 @@ public final class UnsafeKVExternalSorter {
         pageSizeBytes,
         inMemSorter);

-      sorter.spill();
-      map.free();
-      // counting the memory used UnsafeInMemorySorter
-      taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
+      // Reset the map so that we can re-use it to insert new records. The inMemSorter will not be
+      // used anymore, so its underlying array can be used by the map again.
+      map.reset();
     }
   }

@@ -142,6 +140,15 @@ public final class UnsafeKVExternalSorter {
       value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
   }

+  /**
+   * Merges another UnsafeKVExternalSorter into `this`; the other sorter will be emptied.
+   *
+   * @throws IOException
+   */
+  public void merge(UnsafeKVExternalSorter other) throws IOException {
+    sorter.merge(other.sorter);
+  }
+
   /**
    * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
    * after consuming this iterator.
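Why the borrowed array is "always large enough": a back-of-the-envelope check. It relies on the map layout at the time of this change (two longs per slot, holding the record address and its hashcode), which is stated here as an assumption rather than shown in the diff:

    long[] borrowed = map.getArray();   // 2 longs per map slot, and capacity >= numElements
    // The in-memory sorter needs 2 longs per record (record pointer + key prefix), so:
    assert borrowed.length >= 2L * map.numElements();
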
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 713a4db0cd59bc255f135d82c19b8eac8abfda69..ce8d592c368eeb61f827421026e98c8df487d79b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -34,14 +34,18 @@ import org.apache.spark.sql.types.StructType
 *
 * This iterator first uses hash-based aggregation to process input rows. It uses
 * a hash map to store groups and their corresponding aggregation buffers. If we
- * this map cannot allocate memory from memory manager,
- * it switches to sort-based aggregation. The process of the switch has the following step:
+ * find that this map cannot allocate memory from the memory manager, we spill the map to disk
+ * and create a new one. After all the input has been processed, all the spills are merged
+ * together using an external sorter, and we do sort-based aggregation.
+ *
+ * The process has the following steps:
+ * - Step 0: Do hash-based aggregation.
 * - Step 1: Sort all entries of the hash map based on values of grouping expressions and
 *   spill them to disk.
- * - Step 2: Create a external sorter based on the spilled sorted map entries.
- * - Step 3: Redirect all input rows to the external sorter.
- * - Step 4: Get a sorted [[KVIterator]] from the external sorter.
- * - Step 5: Initialize sort-based aggregation.
+ * - Step 2: Create an external sorter based on the spilled, sorted map entries and reset the map.
+ * - Step 3: Get a sorted [[KVIterator]] from the external sorter.
+ * - Step 4: Repeat step 0 until there is no more input.
+ * - Step 5: Initialize sort-based aggregation on the sorted iterator.
 * Then, this iterator works in the way of sort-based aggregation.
 *
 * The code of this class is organized as follows:
@@ -488,9 +492,10 @@ class TungstenAggregationIterator(

   // The function used to read and process input rows. When processing input rows,
   // it first uses hash-based aggregation by putting groups and their buffers in
-  // hashMap. If we could not allocate more memory for the map, we switch to
-  // sort-based aggregation (by calling switchToSortBasedAggregation).
-  private def processInputs(): Unit = {
+  // hashMap. If there is not enough memory, it will use multiple hash maps, spilling
+  // each one to disk after it becomes full, then merge these spills with an external
+  // sorter, and finally do sort-based aggregation.
+  private def processInputs(fallbackStartsAt: Int): Unit = {
     if (groupingExpressions.isEmpty) {
       // If there is no grouping expressions, we can just reuse the same buffer over and over again.
       // Note that it would be better to eliminate the hash map entirely in the future.
@@ -502,44 +507,40 @@ class TungstenAggregationIterator(
         processRow(buffer, newInput)
       }
     } else {
-      while (!sortBased && inputIter.hasNext) {
+      var i = 0
+      while (inputIter.hasNext) {
         val newInput = inputIter.next()
         numInputRows += 1
         val groupingKey = groupProjection.apply(newInput)
-        val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+        var buffer: UnsafeRow = null
+        if (i < fallbackStartsAt) {
+          buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+        }
         if (buffer == null) {
-          // buffer == null means that we could not allocate more memory.
-          // Now, we need to spill the map and switch to sort-based aggregation.
-          switchToSortBasedAggregation(groupingKey, newInput)
-        } else {
-          processRow(buffer, newInput)
+          val sorter = hashMap.destructAndCreateExternalSorter()
+          if (externalSorter == null) {
+            externalSorter = sorter
+          } else {
+            externalSorter.merge(sorter)
+          }
+          i = 0
+          buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
+          if (buffer == null) {
+            // failed to allocate the first page of the new map
+            throw new OutOfMemoryError("Not enough memory for aggregation")
+          }
         }
+        processRow(buffer, newInput)
+        i += 1
       }
-    }
-  }

-  // This function is only used for testing. It basically the same as processInputs except
-  // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have
-  // been processed.
-  private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = {
-    var i = 0
-    while (!sortBased && inputIter.hasNext) {
-      val newInput = inputIter.next()
-      numInputRows += 1
-      val groupingKey = groupProjection.apply(newInput)
-      val buffer: UnsafeRow = if (i < fallbackStartsAt) {
-        hashMap.getAggregationBufferFromUnsafeRow(groupingKey)
-      } else {
-        null
-      }
-      if (buffer == null) {
-        // buffer == null means that we could not allocate more memory.
-        // Now, we need to spill the map and switch to sort-based aggregation.
-        switchToSortBasedAggregation(groupingKey, newInput)
-      } else {
-        processRow(buffer, newInput)
+      if (externalSorter != null) {
+        val sorter = hashMap.destructAndCreateExternalSorter()
+        externalSorter.merge(sorter)
+        hashMap.free()
+
+        switchToSortBasedAggregation()
       }
-      i += 1
     }
   }

@@ -561,88 +562,8 @@ class TungstenAggregationIterator(
   /**
    * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory.
    */
-  private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = {
+  private def switchToSortBasedAggregation(): Unit = {
     logInfo("falling back to sort based aggregation.")
-    // Step 1: Get the ExternalSorter containing sorted entries of the map.
-    externalSorter = hashMap.destructAndCreateExternalSorter()
-
-    // Step 2: If we have aggregate function with mode Partial or Complete,
-    // we need to process input rows to get aggregation buffer.
-    // So, later in the sort-based aggregation iterator, we can do merge.
-    // If aggregate functions are with mode Final and PartialMerge,
-    // we just need to project the aggregation buffer from an input row.
-    val needsProcess = aggregationMode match {
-      case (Some(Partial), None) => true
-      case (None, Some(Complete)) => true
-      case (Some(Final), Some(Complete)) => true
-      case _ => false
-    }
-
-    // Note: Since we spill the sorter's contents immediately after creating it, we must insert
-    // something into the sorter here to ensure that we acquire at least a page of memory.
-    // This is done through `externalSorter.insertKV`, which will trigger the page allocation.
-    // Otherwise, children operators may steal the window of opportunity and starve our sorter.
-
-    if (needsProcess) {
-      // First, we create a buffer.
-      val buffer = createNewAggregationBuffer()
-
-      // Process firstKey and firstInput.
-      // Initialize buffer.
-      buffer.copyFrom(initialAggregationBuffer)
-      processRow(buffer, firstInput)
-      externalSorter.insertKV(firstKey, buffer)
-
-      // Process the rest of input rows.
-      while (inputIter.hasNext) {
-        val newInput = inputIter.next()
-        numInputRows += 1
-        val groupingKey = groupProjection.apply(newInput)
-        buffer.copyFrom(initialAggregationBuffer)
-        processRow(buffer, newInput)
-        externalSorter.insertKV(groupingKey, buffer)
-      }
-    } else {
-      // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer.
-      // We need to project the aggregation buffer part from an input row.
-      val buffer = createNewAggregationBuffer()
-      // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to
-      // extract the aggregation buffer. In practice, however, we extract it positionally by relying
-      // on it being present at the end of the row. The reason for this relates to how the different
-      // aggregates handle input binding.
-      //
-      // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers,
-      // so its correctness does not rely on attribute bindings. When we fall back to sort-based
When we fall back to sort-based - // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset) - // need to be updated and any internal state in the aggregate functions themselves must be - // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset - // this state and update the offsets. - // - // The updated ImperativeAggregate will have different attribute ids for its - // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual - // ImperativeAggregate evaluation, but it means that - // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the - // attributes in `originalInputAttributes`, which is why we can't use those attributes here. - // - // For more details, see the discussion on PR #9038. - val bufferExtractor = newMutableProjection( - originalInputAttributes.drop(initialInputBufferOffset), - originalInputAttributes)() - bufferExtractor.target(buffer) - - // Insert firstKey and its buffer. - bufferExtractor(firstInput) - externalSorter.insertKV(firstKey, buffer) - - // Insert the rest of input rows. - while (inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - bufferExtractor(newInput) - externalSorter.insertKV(groupingKey, buffer) - } - } // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. val newAggregationMode = aggregationMode match { @@ -762,15 +683,7 @@ class TungstenAggregationIterator( /** * Start processing input rows. */ - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue)) // If we did not switch to sort-based aggregation in processInputs, // we pre-load the first key-value pair from the map (to make hasNext idempotent). diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index a38623623a441c1357ede81307466ec441dbe35f..7ceaee38d131bb53d6d2a544bf956aedd14a1e37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -170,9 +170,6 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("test external sorting") { - // Memory consumption in the beginning of the task. - val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() - val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -189,35 +186,33 @@ class UnsafeFixedWidthAggregationMapSuite buf.setInt(0, keyString.length) assert(buf != null) } - - // Convert the map into a sorter val sorter = map.destructAndCreateExternalSorter() // Add more keys to the sorter and make sure the results come out sorted. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index a38623623a441c1357ede81307466ec441dbe35f..7ceaee38d131bb53d6d2a544bf956aedd14a1e37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -170,9 +170,6 @@ class UnsafeFixedWidthAggregationMapSuite
   }

   testWithMemoryLeakDetection("test external sorting") {
-    // Memory consumption in the beginning of the task.
-    val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
-
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       aggBufferSchema,
@@ -189,35 +186,33 @@ class UnsafeFixedWidthAggregationMapSuite
       buf.setInt(0, keyString.length)
       assert(buf != null)
     }
-
-    // Convert the map into a sorter
     val sorter = map.destructAndCreateExternalSorter()

     // Add more keys to the sorter and make sure the results come out sorted.
     val additionalKeys = randomStrings(1024)
-    val keyConverter = UnsafeProjection.create(groupKeySchema)
-    val valueConverter = UnsafeProjection.create(aggBufferSchema)
-
     additionalKeys.zipWithIndex.foreach { case (str, i) =>
-      val k = InternalRow(UTF8String.fromString(str))
-      val v = InternalRow(str.length)
-      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+      val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
+      buf.setInt(0, str.length)

       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemoryOnce()
-        sorter.closeCurrentPage()
+        val sorter2 = map.destructAndCreateExternalSorter()
+        sorter.merge(sorter2)
       }
     }
+    val sorter2 = map.destructAndCreateExternalSorter()
+    sorter.merge(sorter2)

     val out = new scala.collection.mutable.ArrayBuffer[String]
     val iter = sorter.sortedIterator()
     while (iter.next()) {
-      assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
-      out += iter.getKey.getString(0)
+      // Here we also test that copy() works correctly.
+      val key = iter.getKey.copy()
+      val value = iter.getValue.copy()
+      assert(key.getString(0).length === value.getInt(0))
+      out += key.getString(0)
     }

     assert(out === (keys ++ additionalKeys).sorted)
-
     map.free()
   }

@@ -232,25 +227,21 @@ class UnsafeFixedWidthAggregationMapSuite
       PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
-
-    // Convert the map into a sorter
     val sorter = map.destructAndCreateExternalSorter()

     // Add more keys to the sorter and make sure the results come out sorted.
     val additionalKeys = randomStrings(1024)
-    val keyConverter = UnsafeProjection.create(groupKeySchema)
-    val valueConverter = UnsafeProjection.create(aggBufferSchema)
-
     additionalKeys.zipWithIndex.foreach { case (str, i) =>
-      val k = InternalRow(UTF8String.fromString(str))
-      val v = InternalRow(str.length)
-      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+      val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
+      buf.setInt(0, str.length)

       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemoryOnce()
-        sorter.closeCurrentPage()
+        val sorter2 = map.destructAndCreateExternalSorter()
+        sorter.merge(sorter2)
       }
     }
+    val sorter2 = map.destructAndCreateExternalSorter()
+    sorter.merge(sorter2)

     val out = new scala.collection.mutable.ArrayBuffer[String]
     val iter = sorter.sortedIterator()
@@ -262,16 +253,12 @@ class UnsafeFixedWidthAggregationMapSuite
       out += key.getString(0)
     }

-    assert(out === (additionalKeys).sorted)
-
+    assert(out === additionalKeys.sorted)
     map.free()
   }

   testWithMemoryLeakDetection("test external sorting with empty records") {
-    // Memory consumption in the beginning of the task.
-    val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
-
     val map = new UnsafeFixedWidthAggregationMap(
       emptyAggregationBuffer,
       StructType(Nil),
@@ -281,7 +268,6 @@ class UnsafeFixedWidthAggregationMapSuite
       PAGE_SIZE_BYTES,
       false // disable perf metrics
     )
-
     (1 to 10).foreach { i =>
       val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0))
       assert(buf != null)
@@ -292,13 +278,15 @@ class UnsafeFixedWidthAggregationMapSuite

     // Add more keys to the sorter and make sure the results come out sorted.
     (1 to 4096).foreach { i =>
-      sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
+      map.getAggregationBufferFromUnsafeRow(UnsafeRow.createFromByteArray(0, 0))

       if ((i % 100) == 0) {
-        memoryManager.markExecutionAsOutOfMemoryOnce()
-        sorter.closeCurrentPage()
+        val sorter2 = map.destructAndCreateExternalSorter()
+        sorter.merge(sorter2)
       }
     }
+    val sorter2 = map.destructAndCreateExternalSorter()
+    sorter.merge(sorter2)

     var count = 0
     val iter = sorter.sortedIterator()
@@ -309,9 +297,8 @@ class UnsafeFixedWidthAggregationMapSuite
       count += 1
     }

-    // 1 record was from the map and 4096 records were explicitly inserted.
-    assert(count === 4097)
-
+    // 1 record per spilled map: 1 initial spill + 40 in the loop + 1 final = 42.
+    assert(count === 42)
     map.free()
   }

@@ -345,6 +332,7 @@ class UnsafeFixedWidthAggregationMapSuite
     var sorter: UnsafeKVExternalSorter = null
     try {
       sorter = map.destructAndCreateExternalSorter()
+      map.free()
     } finally {
       if (sorter != null) {
         sorter.cleanupResources()
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 74061db0f28af51247343b7c373bb62c5c173a76..ea80060e370e02f1d2e4a1338f7b80c56f7d240f 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -22,13 +22,12 @@ import scala.collection.JavaConverters._
 import org.apache.spark.SparkException
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types._
 import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
 import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types._

 class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {

@@ -702,6 +701,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     }
   }

+  test("no aggregation function (SPARK-11486)") {
+    val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s")
+      .groupBy("s").count()
+      .groupBy().count()
+    checkAnswer(df, Row(20) :: Nil)
+  }
+
   test("udaf with all data types") {
     val struct = StructType(