Skip to content
Snippets Groups Projects
Commit b55499a4 authored by Josh Rosen, committed by Reynold Xin
Browse files

[SPARK-8932] Support copy() for UnsafeRows that do not use ObjectPools

We call Row.copy() in many places throughout SQL but UnsafeRow currently throws UnsupportedOperationException when copy() is called.

Supporting copying when ObjectPool is used may be difficult, since we may need to handle deep-copying of objects in the pool. In addition, this copy() method needs to produce a self-contained row object which may be passed around / buffered by downstream code which does not understand the UnsafeRow format.

In the long run, we'll need to figure out how to handle the ObjectPool corner cases, but this may be unnecessary if other changes are made. Therefore, in order to unblock my sort patch (#6444) I propose that we support copy() for the cases where UnsafeRow does not use an ObjectPool and continue to throw UnsupportedOperationException when an ObjectPool is used.

This patch accomplishes this by modifying UnsafeRow so that it knows the size of the row's backing data in order to be able to copy it into a byte array.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7306 from JoshRosen/SPARK-8932 and squashes the following commits:

338e6bf [Josh Rosen] Support copy for UnsafeRows that do not use ObjectPools.
parent a2908148
No related branches found
No related tags found
No related merge requests found
......@@ -120,9 +120,11 @@ public final class UnsafeFixedWidthAggregationMap {
this.bufferPool = new ObjectPool(initialCapacity);
InternalRow initRow = initProjection.apply(emptyRow);
this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
int emptyBufferSize = bufferConverter.getSizeRequirement(initRow);
this.emptyBuffer = new byte[emptyBufferSize];
int writtenLength = bufferConverter.writeRow(
initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize,
bufferPool);
assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
// re-use the empty buffer only when there is no object saved in pool.
reuseEmptyBuffer = bufferPool.size() == 0;
......@@ -142,6 +144,7 @@ public final class UnsafeFixedWidthAggregationMap {
groupingKey,
groupingKeyConversionScratchSpace,
PlatformDependent.BYTE_ARRAY_OFFSET,
groupingKeySize,
keyPool);
assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
......@@ -157,7 +160,7 @@ public final class UnsafeFixedWidthAggregationMap {
// There are some objects referenced by emptyBuffer, so generate a new one
InternalRow initRow = initProjection.apply(emptyRow);
bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
bufferPool);
groupingKeySize, bufferPool);
}
loc.putNewKey(
groupingKeyConversionScratchSpace,
......@@ -175,6 +178,7 @@ public final class UnsafeFixedWidthAggregationMap {
address.getBaseObject(),
address.getBaseOffset(),
bufferConverter.numFields(),
loc.getValueLength(),
bufferPool
);
return currentBuffer;
......@@ -214,12 +218,14 @@ public final class UnsafeFixedWidthAggregationMap {
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
keyConverter.numFields(),
loc.getKeyLength(),
keyPool
);
entry.value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
bufferConverter.numFields(),
loc.getValueLength(),
bufferPool
);
return entry;
......
......@@ -68,6 +68,9 @@ public final class UnsafeRow extends MutableRow {
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
/** The size of this row's backing data, in bytes */
private int sizeInBytes;
public int length() { return numFields; }
/** The width of the null tracking bit set, in bytes */
......@@ -95,14 +98,17 @@ public final class UnsafeRow extends MutableRow {
* @param baseObject the base object
* @param baseOffset the offset within the base object
* @param numFields the number of fields in this row
* @param sizeInBytes the size of this row's backing data, in bytes
* @param pool the object pool to hold arbitrary objects
*/
public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
public void pointTo(
Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) {
assert numFields >= 0 : "numFields should >= 0";
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.numFields = numFields;
this.sizeInBytes = sizeInBytes;
this.pool = pool;
}
......@@ -336,9 +342,31 @@ public final class UnsafeRow extends MutableRow {
}
}
/**
* Copies this row, returning a self-contained UnsafeRow that stores its data in an internal
* byte array rather than referencing data stored in a data page.
* <p>
* This method is only supported on UnsafeRows that do not use ObjectPools.
*/
@Override
public InternalRow copy() {
throw new UnsupportedOperationException();
if (pool != null) {
throw new UnsupportedOperationException(
"Copy is not supported for UnsafeRows that use object pools");
} else {
UnsafeRow rowCopy = new UnsafeRow();
final byte[] rowDataCopy = new byte[sizeInBytes];
PlatformDependent.copyMemory(
baseObject,
baseOffset,
rowDataCopy,
PlatformDependent.BYTE_ARRAY_OFFSET,
sizeInBytes
);
rowCopy.pointTo(
rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null);
return rowCopy;
}
}
@Override
......
......@@ -70,10 +70,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
* @param row the row to convert
* @param baseObject the base object of the destination address
* @param baseOffset the base offset of the destination address
* @param rowLengthInBytes the length calculated by `getSizeRequirement(row)`
* @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
*/
def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
def writeRow(
row: InternalRow,
baseObject: Object,
baseOffset: Long,
rowLengthInBytes: Int,
pool: ObjectPool): Int = {
unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool)
if (writers.length > 0) {
// zero-out the bitset
......
......@@ -44,19 +44,32 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(row)
assert(sizeRequired === 8 + (3 * 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
val numBytesWritten =
converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.pointTo(
buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
// We can copy UnsafeRows as long as they don't reference ObjectPools
val unsafeRowCopy = unsafeRow.copy()
assert(unsafeRowCopy.getLong(0) === 0)
assert(unsafeRowCopy.getLong(1) === 1)
assert(unsafeRowCopy.getInt(2) === 2)
unsafeRow.setLong(1, 3)
assert(unsafeRow.getLong(1) === 3)
unsafeRow.setInt(2, 4)
assert(unsafeRow.getInt(2) === 4)
// Mutating the original row should not have changed the copy
assert(unsafeRowCopy.getLong(0) === 0)
assert(unsafeRowCopy.getLong(1) === 1)
assert(unsafeRowCopy.getInt(2) === 2)
}
test("basic conversion with primitive, string and binary types") {
......@@ -73,12 +86,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
val numBytesWritten = converter.writeRow(
row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
val pool = new ObjectPool(10)
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
unsafeRow.pointTo(
buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
assert(unsafeRow.get(2) === "World".getBytes)
......@@ -96,6 +111,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
unsafeRow.update(2, "Hello World".getBytes)
assert(unsafeRow.get(2) === "Hello World".getBytes)
assert(pool.size === 2)
// We do not support copy() for UnsafeRows that reference ObjectPools
intercept[UnsupportedOperationException] {
unsafeRow.copy()
}
}
test("basic conversion with primitive, decimal and array") {
......@@ -111,12 +131,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(row)
assert(sizeRequired === 8 + (8 * 3))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
val numBytesWritten =
converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool)
assert(numBytesWritten === sizeRequired)
assert(pool.size === 2)
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
unsafeRow.pointTo(
buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.get(1) === Decimal(1))
assert(unsafeRow.get(2) === Array(2))
......@@ -142,11 +164,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(sizeRequired === 8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
val numBytesWritten =
converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.pointTo(
buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null)
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getString(1) === "Hello")
// Date is represented as Int in unsafeRow
......@@ -190,12 +214,14 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(
rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
sizeRequired, null)
assert(numBytesWritten === sizeRequired)
val createdFromNull = new UnsafeRow()
createdFromNull.pointTo(
createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
sizeRequired, null)
for (i <- 0 to fieldTypes.length - 1) {
assert(createdFromNull.isNullAt(i))
}
......@@ -233,10 +259,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val pool = new ObjectPool(1)
val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
converter.writeRow(
rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET,
sizeRequired, pool)
val setToNullAfterCreation = new UnsafeRow()
setToNullAfterCreation.pointTo(
setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length,
sizeRequired, pool)
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment