From 81498dd5c86ca51d2fb351c8ef52cbb28e6844f4 Mon Sep 17 00:00:00 2001 From: Davies Liu <davies@databricks.com> Date: Wed, 4 Nov 2015 21:30:21 -0800 Subject: [PATCH] [SPARK-11425] [SPARK-11486] Improve hybrid aggregation After aggregation, the dataset could be smaller than inputs, so it's better to do hash based aggregation for all inputs, then using sort based aggregation to merge them. Author: Davies Liu <davies@databricks.com> Closes #9383 from davies/fix_switch. --- .../spark/unsafe/map/BytesToBytesMap.java | 46 +++-- .../unsafe/sort/UnsafeExternalSorter.java | 39 ++-- .../unsafe/sort/UnsafeInMemorySorter.java | 15 +- .../UnsafeFixedWidthAggregationMap.java | 9 +- .../sql/execution/UnsafeKVExternalSorter.java | 23 ++- .../TungstenAggregationIterator.scala | 171 +++++------------- .../UnsafeFixedWidthAggregationMapSuite.scala | 64 +++---- .../execution/AggregationQuerySuite.scala | 12 +- 8 files changed, 165 insertions(+), 214 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 07241c827c..6656fd1d0b 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -20,6 +20,7 @@ package org.apache.spark.unsafe.map; import javax.annotation.Nullable; import java.io.File; import java.io.IOException; +import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; @@ -638,7 +639,11 @@ public final class BytesToBytesMap extends MemoryConsumer { assert (valueLength % 8 == 0); assert(longArray != null); - if (numElements == MAX_CAPACITY || !canGrowArray) { + + if (numElements == MAX_CAPACITY + // The map could be reused from last spill (because of no enough memory to grow), + // then we don't try to grow again if hit the `growthThreshold`. + || !canGrowArray && numElements > growthThreshold) { return false; } @@ -730,25 +735,18 @@ public final class BytesToBytesMap extends MemoryConsumer { } /** - * Free the memory used by longArray. + * Free all allocated memory associated with this map, including the storage for keys and values + * as well as the hash map array itself. + * + * This method is idempotent and can be called multiple times. */ - public void freeArray() { + public void free() { updatePeakMemoryUsed(); if (longArray != null) { long used = longArray.memoryBlock().size(); longArray = null; releaseMemory(used); } - } - - /** - * Free all allocated memory associated with this map, including the storage for keys and values - * as well as the hash map array itself. - * - * This method is idempotent and can be called multiple times. - */ - public void free() { - freeArray(); Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { MemoryBlock dataPage = dataPagesIterator.next(); @@ -833,6 +831,28 @@ public final class BytesToBytesMap extends MemoryConsumer { return dataPages.size(); } + /** + * Returns the underline long[] of longArray. + */ + public long[] getArray() { + assert(longArray != null); + return (long[]) longArray.memoryBlock().getBaseObject(); + } + + /** + * Reset this map to initialized state. + */ + public void reset() { + numElements = 0; + Arrays.fill(getArray(), 0); + while (dataPages.size() > 0) { + MemoryBlock dataPage = dataPages.removeLast(); + freePage(dataPage); + } + currentPage = null; + pageCursor = 0; + } + /** * Grows the size of the hash table and re-hash everything. */ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 509fb0a044..cba043bc48 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -79,9 +79,13 @@ public final class UnsafeExternalSorter extends MemoryConsumer { PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - UnsafeInMemorySorter inMemorySorter) { - return new UnsafeExternalSorter(taskMemoryManager, blockManager, + UnsafeInMemorySorter inMemorySorter) throws IOException { + UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + sorter.spill(Long.MAX_VALUE, sorter); + // The external sorter will be used to insert records, in-memory sorter is not needed. + sorter.inMemSorter = null; + return sorter; } public static UnsafeExternalSorter create( @@ -124,7 +128,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer { acquireMemory(inMemSorter.getMemoryUsage()); } else { this.inMemSorter = existingInMemorySorter; - // will acquire after free the map } this.peakMemoryUsedBytes = getMemoryUsage(); @@ -157,12 +160,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer { */ @Override public long spill(long size, MemoryConsumer trigger) throws IOException { - assert(inMemSorter != null); if (trigger != this) { if (readingIterator != null) { return readingIterator.spill(); - } else { - } return 0L; // this should throw exception } @@ -388,25 +388,38 @@ public final class UnsafeExternalSorter extends MemoryConsumer { inMemSorter.insertRecord(recordAddress, prefix); } + /** + * Merges another UnsafeExternalSorters into this one, the other one will be emptied. + * + * @throws IOException + */ + public void merge(UnsafeExternalSorter other) throws IOException { + other.spill(); + spillWriters.addAll(other.spillWriters); + // remove them from `spillWriters`, or the files will be deleted in `cleanupResources`. + other.spillWriters.clear(); + other.cleanupResources(); + } + /** * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { - assert(inMemSorter != null); - readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); - int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0); if (spillWriters.isEmpty()) { + assert(inMemSorter != null); + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); return readingIterator; } else { final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); + new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); } - spillWriters.clear(); - spillMerger.addSpillIfNotEmpty(readingIterator); - + if (inMemSorter != null) { + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + spillMerger.addSpillIfNotEmpty(readingIterator); + } return spillMerger.getSortedIterator(); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 1480f0681e..d57213b9b8 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,9 +19,9 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; -import org.apache.spark.memory.TaskMemoryManager; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records @@ -77,13 +77,20 @@ public final class UnsafeInMemorySorter { */ private int pos = 0; + public UnsafeInMemorySorter( + final TaskMemoryManager memoryManager, + final RecordComparator recordComparator, + final PrefixComparator prefixComparator, + int initialSize) { + this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]); + } + public UnsafeInMemorySorter( final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - int initialSize) { - assert (initialSize > 0); - this.array = new long[initialSize * 2]; + long[] array) { + this.array = array; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index d4b6d75b4d..a2f99d566d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -236,16 +236,13 @@ public final class UnsafeFixedWidthAggregationMap { /** * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]] - * that can be used to insert more records to do external sorting. * - * The only memory that is allocated is the address/prefix array, 16 bytes per record. - * - * Note that this destroys the map, and as a result, the map cannot be used anymore after this. + * Note that the map will be reset for inserting new records, and the returned sorter can NOT be used + * to insert records. */ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { - UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( + return new UnsafeKVExternalSorter( groupingKeySchema, aggregationBufferSchema, SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); - return sorter; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 845f2ae685..e2898ef2e2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -83,11 +83,10 @@ public final class UnsafeKVExternalSorter { /* initialSize */ 4096, pageSizeBytes); } else { - // The memory needed for UnsafeInMemorySorter should be less than longArray in map. - map.freeArray(); - // The memory used by UnsafeInMemorySorter will be counted later (end of this block) + // During spilling, the array in map will not be used, so we can borrow that and use it + // as the underline array for in-memory sorter (it's always large enough). final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements())); + taskMemoryManager, recordComparator, prefixComparator, map.getArray()); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. @@ -123,10 +122,9 @@ public final class UnsafeKVExternalSorter { pageSizeBytes, inMemSorter); - sorter.spill(); - map.free(); - // counting the memory used UnsafeInMemorySorter - taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter); + // reset the map, so we can re-use it to insert new records. the inMemSorter will not used + // anymore, so the underline array could be used by map again. + map.reset(); } } @@ -142,6 +140,15 @@ public final class UnsafeKVExternalSorter { value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix); } + /** + * Merges another UnsafeKVExternalSorter into `this`, the other one will be emptied. + * + * @throws IOException + */ + public void merge(UnsafeKVExternalSorter other) throws IOException { + sorter.merge(other.sorter); + } + /** * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 713a4db0cd..ce8d592c36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -34,14 +34,18 @@ import org.apache.spark.sql.types.StructType * * This iterator first uses hash-based aggregation to process input rows. It uses * a hash map to store groups and their corresponding aggregation buffers. If we - * this map cannot allocate memory from memory manager, - * it switches to sort-based aggregation. The process of the switch has the following step: + * this map cannot allocate memory from memory manager, it spill the map into disk + * and create a new one. After processed all the input, then merge all the spills + * together using external sorter, and do sort-based aggregation. + * + * The process has the following step: + * - Step 0: Do hash-based aggregation. * - Step 1: Sort all entries of the hash map based on values of grouping expressions and * spill them to disk. - * - Step 2: Create a external sorter based on the spilled sorted map entries. - * - Step 3: Redirect all input rows to the external sorter. - * - Step 4: Get a sorted [[KVIterator]] from the external sorter. - * - Step 5: Initialize sort-based aggregation. + * - Step 2: Create a external sorter based on the spilled sorted map entries and reset the map. + * - Step 3: Get a sorted [[KVIterator]] from the external sorter. + * - Step 4: Repeat step 0 until no more input. + * - Step 5: Initialize sort-based aggregation on the sorted iterator. * Then, this iterator works in the way of sort-based aggregation. * * The code of this class is organized as follows: @@ -488,9 +492,10 @@ class TungstenAggregationIterator( // The function used to read and process input rows. When processing input rows, // it first uses hash-based aggregation by putting groups and their buffers in - // hashMap. If we could not allocate more memory for the map, we switch to - // sort-based aggregation (by calling switchToSortBasedAggregation). - private def processInputs(): Unit = { + // hashMap. If there is not enough memory, it will multiple hash-maps, spilling + // after each becomes full then using sort to merge these spills, finally do sort + // based aggregation. + private def processInputs(fallbackStartsAt: Int): Unit = { if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. @@ -502,44 +507,40 @@ class TungstenAggregationIterator( processRow(buffer, newInput) } } else { - while (!sortBased && inputIter.hasNext) { + var i = 0 + while (inputIter.hasNext) { val newInput = inputIter.next() numInputRows += 1 val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + var buffer: UnsafeRow = null + if (i < fallbackStartsAt) { + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + } if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, newInput) - } else { - processRow(buffer, newInput) + val sorter = hashMap.destructAndCreateExternalSorter() + if (externalSorter == null) { + externalSorter = sorter + } else { + externalSorter.merge(sorter) + } + i = 0 + buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) + if (buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation") + } } + processRow(buffer, newInput) + i += 1 } - } - } - // This function is only used for testing. It basically the same as processInputs except - // that it switch to sort-based aggregation after `fallbackStartsAt` input rows have - // been processed. - private def processInputsWithControlledFallback(fallbackStartsAt: Int): Unit = { - var i = 0 - while (!sortBased && inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - val buffer: UnsafeRow = if (i < fallbackStartsAt) { - hashMap.getAggregationBufferFromUnsafeRow(groupingKey) - } else { - null - } - if (buffer == null) { - // buffer == null means that we could not allocate more memory. - // Now, we need to spill the map and switch to sort-based aggregation. - switchToSortBasedAggregation(groupingKey, newInput) - } else { - processRow(buffer, newInput) + if (externalSorter != null) { + val sorter = hashMap.destructAndCreateExternalSorter() + externalSorter.merge(sorter) + hashMap.free() + + switchToSortBasedAggregation() } - i += 1 } } @@ -561,88 +562,8 @@ class TungstenAggregationIterator( /** * Switch to sort-based aggregation when the hash-based approach is unable to acquire memory. */ - private def switchToSortBasedAggregation(firstKey: UnsafeRow, firstInput: InternalRow): Unit = { + private def switchToSortBasedAggregation(): Unit = { logInfo("falling back to sort based aggregation.") - // Step 1: Get the ExternalSorter containing sorted entries of the map. - externalSorter = hashMap.destructAndCreateExternalSorter() - - // Step 2: If we have aggregate function with mode Partial or Complete, - // we need to process input rows to get aggregation buffer. - // So, later in the sort-based aggregation iterator, we can do merge. - // If aggregate functions are with mode Final and PartialMerge, - // we just need to project the aggregation buffer from an input row. - val needsProcess = aggregationMode match { - case (Some(Partial), None) => true - case (None, Some(Complete)) => true - case (Some(Final), Some(Complete)) => true - case _ => false - } - - // Note: Since we spill the sorter's contents immediately after creating it, we must insert - // something into the sorter here to ensure that we acquire at least a page of memory. - // This is done through `externalSorter.insertKV`, which will trigger the page allocation. - // Otherwise, children operators may steal the window of opportunity and starve our sorter. - - if (needsProcess) { - // First, we create a buffer. - val buffer = createNewAggregationBuffer() - - // Process firstKey and firstInput. - // Initialize buffer. - buffer.copyFrom(initialAggregationBuffer) - processRow(buffer, firstInput) - externalSorter.insertKV(firstKey, buffer) - - // Process the rest of input rows. - while (inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - buffer.copyFrom(initialAggregationBuffer) - processRow(buffer, newInput) - externalSorter.insertKV(groupingKey, buffer) - } - } else { - // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. - // We need to project the aggregation buffer part from an input row. - val buffer = createNewAggregationBuffer() - // In principle, we could use `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` to - // extract the aggregation buffer. In practice, however, we extract it positionally by relying - // on it being present at the end of the row. The reason for this relates to how the different - // aggregates handle input binding. - // - // ImperativeAggregate uses field numbers and field number offsets to manipulate its buffers, - // so its correctness does not rely on attribute bindings. When we fall back to sort-based - // aggregation, these field number offsets (mutableAggBufferOffset and inputAggBufferOffset) - // need to be updated and any internal state in the aggregate functions themselves must be - // reset, so we call withNewMutableAggBufferOffset and withNewInputAggBufferOffset to reset - // this state and update the offsets. - // - // The updated ImperativeAggregate will have different attribute ids for its - // aggBufferAttributes and inputAggBufferAttributes. This isn't a problem for the actual - // ImperativeAggregate evaluation, but it means that - // `allAggregateFunctions.flatMap(_.inputAggBufferAttributes)` will no longer match the - // attributes in `originalInputAttributes`, which is why we can't use those attributes here. - // - // For more details, see the discussion on PR #9038. - val bufferExtractor = newMutableProjection( - originalInputAttributes.drop(initialInputBufferOffset), - originalInputAttributes)() - bufferExtractor.target(buffer) - - // Insert firstKey and its buffer. - bufferExtractor(firstInput) - externalSorter.insertKV(firstKey, buffer) - - // Insert the rest of input rows. - while (inputIter.hasNext) { - val newInput = inputIter.next() - numInputRows += 1 - val groupingKey = groupProjection.apply(newInput) - bufferExtractor(newInput) - externalSorter.insertKV(groupingKey, buffer) - } - } // Set aggregationMode, processRow, and generateOutput for sort-based aggregation. val newAggregationMode = aggregationMode match { @@ -762,15 +683,7 @@ class TungstenAggregationIterator( /** * Start processing input rows. */ - testFallbackStartsAt match { - case None => - processInputs() - case Some(fallbackStartsAt) => - // This is the testing path. processInputsWithControlledFallback is same as processInputs - // except that it switches to sort-based aggregation after `fallbackStartsAt` input rows - // have been processed. - processInputsWithControlledFallback(fallbackStartsAt) - } + processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue)) // If we did not switch to sort-based aggregation in processInputs, // we pre-load the first key-value pair from the map (to make hasNext idempotent). diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index a38623623a..7ceaee38d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -170,9 +170,6 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("test external sorting") { - // Memory consumption in the beginning of the task. - val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() - val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -189,35 +186,33 @@ class UnsafeFixedWidthAggregationMapSuite buf.setInt(0, keyString.length) assert(buf != null) } - - // Convert the map into a sorter val sorter = map.destructAndCreateExternalSorter() // Add more keys to the sorter and make sure the results come out sorted. val additionalKeys = randomStrings(1024) - val keyConverter = UnsafeProjection.create(groupKeySchema) - val valueConverter = UnsafeProjection.create(aggBufferSchema) - additionalKeys.zipWithIndex.foreach { case (str, i) => - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) if ((i % 100) == 0) { - memoryManager.markExecutionAsOutOfMemoryOnce() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) val out = new scala.collection.mutable.ArrayBuffer[String] val iter = sorter.sortedIterator() while (iter.next()) { - assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) - out += iter.getKey.getString(0) + // At here, we also test if copy is correct. + val key = iter.getKey.copy() + val value = iter.getValue.copy() + assert(key.getString(0).length === value.getInt(0)) + out += key.getString(0) } assert(out === (keys ++ additionalKeys).sorted) - map.free() } @@ -232,25 +227,21 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES, false // disable perf metrics ) - - // Convert the map into a sorter val sorter = map.destructAndCreateExternalSorter() // Add more keys to the sorter and make sure the results come out sorted. val additionalKeys = randomStrings(1024) - val keyConverter = UnsafeProjection.create(groupKeySchema) - val valueConverter = UnsafeProjection.create(aggBufferSchema) - additionalKeys.zipWithIndex.foreach { case (str, i) => - val k = InternalRow(UTF8String.fromString(str)) - val v = InternalRow(str.length) - sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + buf.setInt(0, str.length) if ((i % 100) == 0) { - memoryManager.markExecutionAsOutOfMemoryOnce() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) val out = new scala.collection.mutable.ArrayBuffer[String] val iter = sorter.sortedIterator() @@ -262,16 +253,12 @@ class UnsafeFixedWidthAggregationMapSuite out += key.getString(0) } - assert(out === (additionalKeys).sorted) - + assert(out === additionalKeys.sorted) map.free() } testWithMemoryLeakDetection("test external sorting with empty records") { - // Memory consumption in the beginning of the task. - val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() - val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, StructType(Nil), @@ -281,7 +268,6 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES, false // disable perf metrics ) - (1 to 10).foreach { i => val buf = map.getAggregationBuffer(UnsafeRow.createFromByteArray(0, 0)) assert(buf != null) @@ -292,13 +278,15 @@ class UnsafeFixedWidthAggregationMapSuite // Add more keys to the sorter and make sure the results come out sorted. (1 to 4096).foreach { i => - sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) + map.getAggregationBufferFromUnsafeRow(UnsafeRow.createFromByteArray(0, 0)) if ((i % 100) == 0) { - memoryManager.markExecutionAsOutOfMemoryOnce() - sorter.closeCurrentPage() + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) } } + val sorter2 = map.destructAndCreateExternalSorter() + sorter.merge(sorter2) var count = 0 val iter = sorter.sortedIterator() @@ -309,9 +297,8 @@ class UnsafeFixedWidthAggregationMapSuite count += 1 } - // 1 record was from the map and 4096 records were explicitly inserted. - assert(count === 4097) - + // 1 record per map, spilled 42 times. + assert(count === 42) map.free() } @@ -345,6 +332,7 @@ class UnsafeFixedWidthAggregationMapSuite var sorter: UnsafeKVExternalSorter = null try { sorter = map.destructAndCreateExternalSorter() + map.free() } finally { if (sorter != null) { sorter.cleanupResources() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 74061db0f2..ea80060e37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -22,13 +22,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types._ import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { @@ -702,6 +701,13 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } } + test("no aggregation function (SPARK-11486)") { + val df = sqlContext.range(20).selectExpr("id", "repeat(id, 1) as s") + .groupBy("s").count() + .groupBy().count() + checkAnswer(df, Row(20) :: Nil) + } + test("udaf with all data types") { val struct = StructType( -- GitLab