Commit 6a1c864a authored by Davies Liu, committed by Davies Liu

[SPARK-12295] [SQL] external spilling for window functions

This PR manages the memory used by window functions (the buffered rows) and also enables external spilling.

After this PR, we can run window functions on a partition with hundreds of millions of rows using only 1 GB of memory.

Author: Davies Liu <davies@databricks.com>

Closes #10605 from davies/unsafe_window.
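The core idea: buffer each partition's rows in memory up to a small threshold, then hand everything to an UnsafeExternalSorter created with null comparators (so records keep insertion order) that can spill to disk under memory pressure. A minimal sketch of that buffer-then-spill pattern, with plain Scala collections standing in for the Spark internals (Spillable, PartitionBuffer, and newSpillable are hypothetical names; only the 4096 threshold comes from the patch):

    import scala.collection.mutable.ArrayBuffer

    // Stand-in for UnsafeExternalSorter used purely as a spillable record store.
    trait Spillable[T] { def insert(row: T): Unit }

    class PartitionBuffer[T](threshold: Int, newSpillable: () => Spillable[T]) {
      private val rows = ArrayBuffer.empty[T]
      private var spillable: Spillable[T] = _

      def add(row: T): Unit =
        if (spillable == null) {
          rows += row
          if (rows.length >= threshold) {    // 4096 in the patch
            spillable = newSpillable()       // switch to the external store
            rows.foreach(spillable.insert)   // migrate what was buffered so far
            rows.clear()
          }
        } else {
          spillable.insert(row)              // already external: insert directly
        }
    }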
parent 84e77a15
@@ -45,7 +45,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
   private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);

+  @Nullable
   private final PrefixComparator prefixComparator;
+  @Nullable
   private final RecordComparator recordComparator;
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
@@ -431,7 +433,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
       this.upstream = inMemIterator;
-      this.numRecords = inMemIterator.numRecordsLeft();
+      this.numRecords = inMemIterator.getNumRecords();
     }

+    public int getNumRecords() {
+      return numRecords;
+    }
+
     public long spill() throws IOException {
@@ -558,13 +564,23 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     private final Queue<UnsafeSorterIterator> iterators;
     private UnsafeSorterIterator current;
+    private int numRecords;

     public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
       assert iterators.size() > 0;
+      this.numRecords = 0;
+      for (UnsafeSorterIterator iter: iterators) {
+        this.numRecords += iter.getNumRecords();
+      }
       this.iterators = iterators;
       this.current = iterators.remove();
     }

+    @Override
+    public int getNumRecords() {
+      return numRecords;
+    }
+
     @Override
     public boolean hasNext() {
       while (!current.hasNext() && !iterators.isEmpty()) {
@@ -575,6 +591,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
     @Override
     public void loadNext() throws IOException {
+      while (!current.hasNext() && !iterators.isEmpty()) {
+        current = iterators.remove();
+      }
       current.loadNext();
     }
@@ -19,6 +19,8 @@ package org.apache.spark.util.collection.unsafe.sort;
 import java.util.Comparator;

+import org.apache.avro.reflect.Nullable;
+
 import org.apache.spark.memory.MemoryConsumer;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.unsafe.Platform;
@@ -66,7 +68,9 @@ public final class UnsafeInMemorySorter {
   private final MemoryConsumer consumer;
   private final TaskMemoryManager memoryManager;
+  @Nullable
   private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
+  @Nullable
   private final Comparator<RecordPointerAndKeyPrefix> sortComparator;

   /**
@@ -98,10 +102,11 @@ public final class UnsafeInMemorySorter {
       LongArray array) {
     this.consumer = consumer;
     this.memoryManager = memoryManager;
-    this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
     if (recordComparator != null) {
+      this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
       this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
     } else {
+      this.sorter = null;
       this.sortComparator = null;
     }
     this.array = array;
@@ -190,12 +195,13 @@ public final class UnsafeInMemorySorter {
     }

     @Override
-    public boolean hasNext() {
-      return position / 2 < numRecords;
+    public int getNumRecords() {
+      return numRecords;
     }

-    public int numRecordsLeft() {
-      return numRecords - position / 2;
+    @Override
+    public boolean hasNext() {
+      return position / 2 < numRecords;
     }

     @Override
@@ -227,7 +233,7 @@ public final class UnsafeInMemorySorter {
    * {@code next()} will return the same mutable object.
    */
   public SortedIterator getSortedIterator() {
-    if (sortComparator != null) {
+    if (sorter != null) {
       sorter.sort(array, 0, pos / 2, sortComparator);
     }
     return new SortedIterator(pos / 2);
@@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator {
   public abstract int getRecordLength();

   public abstract long getKeyPrefix();
+
+  public abstract int getNumRecords();
 }
@@ -23,6 +23,7 @@ import java.util.PriorityQueue;
 final class UnsafeSorterSpillMerger {

+  private int numRecords = 0;
   private final PriorityQueue<UnsafeSorterIterator> priorityQueue;

   public UnsafeSorterSpillMerger(
@@ -59,6 +60,7 @@ final class UnsafeSorterSpillMerger {
       // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
       spillReader.loadNext();
       priorityQueue.add(spillReader);
+      numRecords += spillReader.getNumRecords();
     }
   }
@@ -67,6 +69,11 @@ final class UnsafeSorterSpillMerger {
       private UnsafeSorterIterator spillReader;

+      @Override
+      public int getNumRecords() {
+        return numRecords;
+      }
+
       @Override
       public boolean hasNext() {
         return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
@@ -38,6 +38,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
   // Variables that change with every record read:
   private int recordLength;
   private long keyPrefix;
+  private int numRecords;
   private int numRecordsRemaining;

   private byte[] arr = new byte[1024 * 1024];
@@ -53,13 +54,18 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
     try {
       this.in = blockManager.wrapForCompression(blockId, bs);
       this.din = new DataInputStream(this.in);
-      numRecordsRemaining = din.readInt();
+      numRecords = numRecordsRemaining = din.readInt();
     } catch (IOException e) {
       Closeables.close(bs, /* swallowIOException = */ true);
       throw e;
     }
   }

+  @Override
+  public int getNumRecords() {
+    return numRecords;
+  }
+
   @Override
   public boolean hasNext() {
     return (numRecordsRemaining > 0);
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.execution

+import java.util
+
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
@@ -26,6 +28,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+import org.apache.spark.{SparkEnv, TaskContext}

 /**
  * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -283,23 +287,26 @@ case class Window(
     val grouping = UnsafeProjection.create(partitionSpec, child.output)

     // Manage the stream and the grouping.
-    var nextRow: InternalRow = EmptyRow
-    var nextGroup: InternalRow = EmptyRow
+    var nextRow: UnsafeRow = null
+    var nextGroup: UnsafeRow = null
     var nextRowAvailable: Boolean = false
     private[this] def fetchNextRow() {
       nextRowAvailable = stream.hasNext
       if (nextRowAvailable) {
-        nextRow = stream.next()
+        nextRow = stream.next().asInstanceOf[UnsafeRow]
         nextGroup = grouping(nextRow)
       } else {
-        nextRow = EmptyRow
-        nextGroup = EmptyRow
+        nextRow = null
+        nextGroup = null
       }
     }
     fetchNextRow()

     // Manage the current partition.
-    val rows = ArrayBuffer.empty[InternalRow]
+    val rows = ArrayBuffer.empty[UnsafeRow]
+    val inputFields = child.output.length
+    var sorter: UnsafeExternalSorter = null
+    var rowBuffer: RowBuffer = null
     val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType))
     val frames = factories.map(_(windowFunctionResult))
     val numFrames = frames.length
@@ -307,27 +314,63 @@ case class Window(
       // Collect all the rows in the current partition.
       // Before we start to fetch new input rows, make a copy of nextGroup.
       val currentGroup = nextGroup.copy()
-      rows.clear()
+
+      // clear last partition
+      if (sorter != null) {
+        // the last sorter of this task will be cleaned up via task completion listener
+        sorter.cleanupResources()
+        sorter = null
+      } else {
+        rows.clear()
+      }
+
       while (nextRowAvailable && nextGroup == currentGroup) {
-        rows += nextRow.copy()
+        if (sorter == null) {
+          rows += nextRow.copy()
+
+          if (rows.length >= 4096) {
+            // We will not sort the rows, so prefixComparator and recordComparator are null.
+            sorter = UnsafeExternalSorter.create(
+              TaskContext.get().taskMemoryManager(),
+              SparkEnv.get.blockManager,
+              TaskContext.get(),
+              null,
+              null,
+              1024,
+              SparkEnv.get.memoryManager.pageSizeBytes)
+            rows.foreach { r =>
+              sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
+            }
+            rows.clear()
+          }
+        } else {
+          sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
+            nextRow.getSizeInBytes, 0)
+        }
         fetchNextRow()
       }
+      if (sorter != null) {
+        rowBuffer = new ExternalRowBuffer(sorter, inputFields)
+      } else {
+        rowBuffer = new ArrayRowBuffer(rows)
+      }

       // Setup the frames.
       var i = 0
       while (i < numFrames) {
-        frames(i).prepare(rows)
+        frames(i).prepare(rowBuffer.copy())
         i += 1
       }

       // Setup iteration
       rowIndex = 0
-      rowsSize = rows.size
+      rowsSize = rowBuffer.size()
     }

     // Iteration
     var rowIndex = 0
-    var rowsSize = 0
+    var rowsSize = 0L
     override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable

     val join = new JoinedRow
@@ -340,13 +383,14 @@ case class Window(
         if (rowIndex < rowsSize) {
           // Get the results for the window frames.
           var i = 0
+          val current = rowBuffer.next()
           while (i < numFrames) {
-            frames(i).write()
+            frames(i).write(rowIndex, current)
             i += 1
           }

           // 'Merge' the input row with the window function result
-          join(rows(rowIndex), windowFunctionResult)
+          join(current, windowFunctionResult)
           rowIndex += 1

           // Return the projection.
@@ -362,14 +406,18 @@ case class Window(
  * Function for comparing boundary values.
  */
 private[execution] abstract class BoundOrdering {
-  def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int
+  def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int
 }

 /**
  * Compare the input index to the bound of the output index.
  */
 private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering {
-  override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int =
+  override def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int =
     inputIndex - (outputIndex + offset)
 }
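For a ROWS frame this bound comparison is pure index arithmetic: row inputIndex belongs to the frame of output row outputIndex as long as the result stays at or below zero (for an upper bound). A minimal sketch of that arithmetic, with plain Ints standing in for InternalRow (rowBoundCompare is a hypothetical name; the formula is the one above):

    // Upper bound "1 FOLLOWING" corresponds to offset = 1.
    def rowBoundCompare(offset: Int, inputIndex: Int, outputIndex: Int): Int =
      inputIndex - (outputIndex + offset)

    assert(rowBoundCompare(1, 3, 2) == 0)  // row 3 is the last row in the frame of output row 2
    assert(rowBoundCompare(1, 4, 2) > 0)   // row 4 falls outside that frame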
@@ -380,8 +428,100 @@ private[execution] final case class RangeBoundOrdering(
     ordering: Ordering[InternalRow],
     current: Projection,
     bound: Projection) extends BoundOrdering {
-  override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int =
-    ordering.compare(current(input(inputIndex)), bound(input(outputIndex)))
+  override def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int =
+    ordering.compare(current(inputRow), bound(outputRow))
 }

+/**
+ * The interface of row buffer for a partition
+ */
+private[execution] abstract class RowBuffer {
+
+  /** Number of rows. */
+  def size(): Int
+
+  /** Return next row in the buffer, null if no more left. */
+  def next(): InternalRow
+
+  /** Skip the next `n` rows. */
+  def skip(n: Int): Unit
+
+  /** Return a new RowBuffer that has the same rows. */
+  def copy(): RowBuffer
+}
+
+/**
+ * A row buffer based on ArrayBuffer (the number of rows is limited)
+ */
+private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
+
+  private[this] var cursor: Int = -1
+
+  /** Number of rows. */
+  def size(): Int = buffer.length
+
+  /** Return next row in the buffer, null if no more left. */
+  def next(): InternalRow = {
+    cursor += 1
+    if (cursor < buffer.length) {
+      buffer(cursor)
+    } else {
+      null
+    }
+  }
+
+  /** Skip the next `n` rows. */
+  def skip(n: Int): Unit = {
+    cursor += n
+  }
+
+  /** Return a new RowBuffer that has the same rows. */
+  def copy(): RowBuffer = {
+    new ArrayRowBuffer(buffer)
+  }
+}
+
+/**
+ * An external buffer of rows based on UnsafeExternalSorter
+ */
+private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
+  extends RowBuffer {
+
+  private[this] val iter: UnsafeSorterIterator = sorter.getIterator
+
+  private[this] val currentRow = new UnsafeRow(numFields)
+
+  /** Number of rows. */
+  def size(): Int = iter.getNumRecords()
+
+  /** Return next row in the buffer, null if no more left. */
+  def next(): InternalRow = {
+    if (iter.hasNext) {
+      iter.loadNext()
+      currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+      currentRow
+    } else {
+      null
+    }
+  }
+
+  /** Skip the next `n` rows. */
+  def skip(n: Int): Unit = {
+    var i = 0
+    while (i < n && iter.hasNext) {
+      iter.loadNext()
+      i += 1
+    }
+  }
+
+  /** Return a new RowBuffer that has the same rows. */
+  def copy(): RowBuffer = {
+    new ExternalRowBuffer(sorter, numFields)
+  }
+}
+
 /**
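The copy() contract above is what lets several frames scan the same partition independently: each frame gets its own cursor over the shared backing store, whether the rows live in an ArrayBuffer or in a spilled sorter. A small stand-alone illustration of that contract (generic element type instead of InternalRow, Option instead of null; Buffer and ArrayBufferLike are hypothetical names):

    trait Buffer[T] {
      def size: Int
      def next(): Option[T]   // None once exhausted
      def copy(): Buffer[T]   // fresh cursor over the same rows
    }

    class ArrayBufferLike[T](rows: Vector[T]) extends Buffer[T] {
      private var cursor = -1
      def size: Int = rows.length
      def next(): Option[T] = { cursor += 1; rows.lift(cursor) }
      def copy(): Buffer[T] = new ArrayBufferLike(rows)
    }

    val partition = new ArrayBufferLike(Vector("a", "b", "c"))
    val frame1 = partition.copy()
    val frame2 = partition.copy()
    frame1.next() // Some("a"): advancing frame1...
    frame2.next() // Some("a"): ...does not advance frame2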
@@ -395,12 +535,12 @@ private[execution] abstract class WindowFunctionFrame {
   *
   * @param rows to calculate the frame results for.
   */
-  def prepare(rows: ArrayBuffer[InternalRow]): Unit
+  def prepare(rows: RowBuffer): Unit

   /**
    * Write the current results to the target row.
    */
-  def write(): Unit
+  def write(index: Int, current: InternalRow): Unit
 }

 /**
@@ -421,14 +561,11 @@ private[execution] final class OffsetWindowFunctionFrame(
     offset: Int) extends WindowFunctionFrame {

   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null

   /** Index of the input row currently used for output. */
   private[this] var inputIndex = 0

-  /** Index of the current output row. */
-  private[this] var outputIndex = 0
-
   /** Row used when there is no valid input. */
   private[this] val emptyRow = new GenericInternalRow(inputSchema.size)
@@ -463,22 +600,26 @@ private[execution] final class OffsetWindowFunctionFrame(
     newMutableProjection(boundExpressions, Nil)().target(target)
   }

-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
-    inputIndex = offset
-    outputIndex = 0
+    // drain the first few rows if offset is larger than zero
+    inputIndex = 0
+    while (inputIndex < offset) {
+      input.next()
+      inputIndex += 1
+    }
   }

-  override def write(): Unit = {
-    val size = input.size
-    if (inputIndex >= 0 && inputIndex < size) {
-      join(input(inputIndex), input(outputIndex))
+  override def write(index: Int, current: InternalRow): Unit = {
+    if (inputIndex >= 0 && inputIndex < input.size) {
+      val r = input.next()
+      join(r, current)
     } else {
-      join(emptyRow, input(outputIndex))
+      join(emptyRow, current)
     }
     projection(join)
     inputIndex += 1
-    outputIndex += 1
   }
 }
@@ -498,7 +639,13 @@ private[execution] final class SlidingWindowFunctionFrame(
     ubound: BoundOrdering) extends WindowFunctionFrame {

   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null
+
+  /** The next row from `input`. */
+  private[this] var nextRow: InternalRow = null
+
+  /** The rows within current sliding window. */
+  private[this] val buffer = new util.ArrayDeque[InternalRow]()

   /** Index of the first input row with a value greater than the upper bound of the current
    * output row. */
@@ -508,33 +655,32 @@ private[execution] final class SlidingWindowFunctionFrame(
    * current output row. */
   private[this] var inputLowIndex = 0

-  /** Index of the row we are currently writing. */
-  private[this] var outputIndex = 0
-
   /** Prepare the frame for calculating a new partition. Reset all variables. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
+    nextRow = rows.next()
     inputHighIndex = 0
     inputLowIndex = 0
-    outputIndex = 0
+    buffer.clear()
   }

   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
-    var bufferUpdated = outputIndex == 0
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0

     // Add all rows to the buffer for which the input row value is equal to or less than
     // the output row upper bound.
-    while (inputHighIndex < input.size &&
-        ubound.compare(input, inputHighIndex, outputIndex) <= 0) {
+    while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
+      buffer.add(nextRow.copy())
+      nextRow = input.next()
       inputHighIndex += 1
       bufferUpdated = true
     }

     // Drop all rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    while (inputLowIndex < inputHighIndex &&
-        lbound.compare(input, inputLowIndex, outputIndex) < 0) {
+    while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
+      buffer.remove()
       inputLowIndex += 1
       bufferUpdated = true
     }
@@ -542,12 +688,12 @@ private[execution] final class SlidingWindowFunctionFrame(
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
       processor.initialize(input.size)
-      processor.update(input, inputLowIndex, inputHighIndex)
+      val iter = buffer.iterator()
+      while (iter.hasNext) {
+        processor.update(iter.next())
+      }
       processor.evaluate(target)
     }
-
-    // Move to the next row.
-    outputIndex += 1
   }
 }
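The deque-based frame above is the classic two-pointer sliding window: every row is added to the buffer once and removed once, so index-based lookups into the partition disappear and spilled rows can be consumed strictly forward. A self-contained sketch of the same bookkeeping, assuming integer rows and a fixed ROWS frame of [i-1, i+1] around each output row i (slidingSums and the bounds are illustrative, not from the patch):

    import java.util.ArrayDeque

    def slidingSums(rows: IndexedSeq[Int], lo: Int = -1, hi: Int = 1): IndexedSeq[Int] = {
      val buffer = new ArrayDeque[Int]()
      var high = 0  // next input row to admit into the window
      var low = 0   // first input row still inside the window
      rows.indices.map { i =>
        // Admit rows up to the upper bound of output row i.
        while (high < rows.length && high <= i + hi) { buffer.add(rows(high)); high += 1 }
        // Evict rows below the lower bound of output row i.
        while (low < high && low < i + lo) { buffer.remove(); low += 1 }
        var sum = 0
        val it = buffer.iterator()
        while (it.hasNext) sum += it.next()
        sum
      }
    }

    // slidingSums(Vector(1, 2, 3, 4)) == Vector(3, 6, 9, 7)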
@@ -567,13 +713,18 @@ private[execution] final class UnboundedWindowFunctionFrame(
   processor: AggregateProcessor) extends WindowFunctionFrame {

   /** Prepare the frame for calculating a new partition. Process all rows eagerly. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
-    processor.initialize(rows.size)
-    processor.update(rows, 0, rows.size)
+  override def prepare(rows: RowBuffer): Unit = {
+    val size = rows.size()
+    processor.initialize(size)
+    var i = 0
+    while (i < size) {
+      processor.update(rows.next())
+      i += 1
+    }
   }

   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
+  override def write(index: Int, current: InternalRow): Unit = {
     // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate
     // for each row.
     processor.evaluate(target)
@@ -600,31 +751,32 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
     ubound: BoundOrdering) extends WindowFunctionFrame {

   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null
+
+  /** The next row from `input`. */
+  private[this] var nextRow: InternalRow = null

   /** Index of the first input row with a value greater than the upper bound of the current
    * output row. */
   private[this] var inputIndex = 0

-  /** Index of the row we are currently writing. */
-  private[this] var outputIndex = 0
-
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
+    nextRow = rows.next()
     inputIndex = 0
-    outputIndex = 0
     processor.initialize(input.size)
   }

   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
-    var bufferUpdated = outputIndex == 0
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0

     // Add all rows to the aggregates for which the input row value is equal to or less than
     // the output row upper bound.
-    while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) {
-      processor.update(input(inputIndex))
+    while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
+      processor.update(nextRow)
+      nextRow = input.next()
       inputIndex += 1
       bufferUpdated = true
     }
@@ -633,9 +785,6 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
     if (bufferUpdated) {
       processor.evaluate(target)
     }
-
-    // Move to the next row.
-    outputIndex += 1
   }
 }
@@ -661,29 +810,31 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
     lbound: BoundOrdering) extends WindowFunctionFrame {

   /** Rows of the partition currently being processed. */
-  private[this] var input: ArrayBuffer[InternalRow] = null
+  private[this] var input: RowBuffer = null

   /** Index of the first input row with a value equal to or greater than the lower bound of the
    * current output row. */
   private[this] var inputIndex = 0

-  /** Index of the row we are currently writing. */
-  private[this] var outputIndex = 0
-
   /** Prepare the frame for calculating a new partition. */
-  override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+  override def prepare(rows: RowBuffer): Unit = {
     input = rows
     inputIndex = 0
-    outputIndex = 0
   }

   /** Write the frame columns for the current row to the given target row. */
-  override def write(): Unit = {
-    var bufferUpdated = outputIndex == 0
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0
+
+    // Duplicate the input to have a new iterator
+    val tmp = input.copy()

     // Drop all rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) {
+    tmp.skip(inputIndex)
+    var nextRow = tmp.next()
+    while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
+      nextRow = tmp.next()
       inputIndex += 1
       bufferUpdated = true
     }
@@ -691,12 +842,12 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
     // Only recalculate and update when the buffer changes.
     if (bufferUpdated) {
       processor.initialize(input.size)
-      processor.update(input, inputIndex, input.size)
+      while (nextRow != null) {
+        processor.update(nextRow)
+        nextRow = tmp.next()
+      }
       processor.evaluate(target)
     }
-
-    // Move to the next row.
-    outputIndex += 1
   }
 }
@@ -825,15 +976,6 @@ private[execution] final class AggregateProcessor(
     }
   }

-  /** Bulk update the given buffer. */
-  def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = {
-    var i = begin
-    while (i < end) {
-      update(input(i))
-      i += 1
-    }
-  }
-
   /** Evaluate buffer. */
   def evaluate(target: MutableRow): Unit =
     evaluateProjection.target(target)(buffer)