Skip to content
Snippets Groups Projects
Commit 487d409e authored by Davies Liu's avatar Davies Liu Committed by Reynold Xin
Browse files

[SPARK-11243][SQL] zero out padding bytes in UnsafeRow

For a nested StructType, the underlying buffer may have been used for other data before, so we should zero out the padding bytes for primitive types that occupy fewer than 8 bytes.

cc cloud-fan

Author: Davies Liu <davies@databricks.com>

Closes #9217 from davies/zero_out.
parent 16dc9f34
No related branches found
No related tags found
No related merge requests found
......@@ -100,19 +100,27 @@ public class UnsafeRowWriter {
}
/**
 * Writes a boolean into the field at {@code ordinal}.
 *
 * A long zero is stored at the field's offset first, so every byte of that
 * 8-byte span the boolean itself does not overwrite ends up as zero rather
 * than whatever the buffer held before.
 *
 * @param ordinal index of the field to write
 * @param value   the boolean to store
 */
public void write(int ordinal, boolean value) {
  // Resolve the offset once and reuse it for both stores.
  final long offset = getFieldOffset(ordinal);
  // Clear the span before writing the narrower value.
  Platform.putLong(holder.buffer, offset, 0L);
  Platform.putBoolean(holder.buffer, offset, value);
}
/**
 * Writes a byte into the field at {@code ordinal}.
 *
 * A long zero is stored at the field's offset first, so every byte of that
 * 8-byte span the value itself does not overwrite ends up as zero rather
 * than whatever the buffer held before.
 *
 * @param ordinal index of the field to write
 * @param value   the byte to store
 */
public void write(int ordinal, byte value) {
  // Resolve the offset once and reuse it for both stores.
  final long offset = getFieldOffset(ordinal);
  // Clear the span before writing the narrower value.
  Platform.putLong(holder.buffer, offset, 0L);
  Platform.putByte(holder.buffer, offset, value);
}
/**
 * Writes a short into the field at {@code ordinal}.
 *
 * A long zero is stored at the field's offset first, so every byte of that
 * 8-byte span the value itself does not overwrite ends up as zero rather
 * than whatever the buffer held before.
 *
 * @param ordinal index of the field to write
 * @param value   the short to store
 */
public void write(int ordinal, short value) {
  // Resolve the offset once and reuse it for both stores.
  final long offset = getFieldOffset(ordinal);
  // Clear the span before writing the narrower value.
  Platform.putLong(holder.buffer, offset, 0L);
  Platform.putShort(holder.buffer, offset, value);
}
/**
 * Writes an int into the field at {@code ordinal}.
 *
 * A long zero is stored at the field's offset first, so every byte of that
 * 8-byte span the value itself does not overwrite ends up as zero rather
 * than whatever the buffer held before.
 *
 * @param ordinal index of the field to write
 * @param value   the int to store
 */
public void write(int ordinal, int value) {
  // Resolve the offset once and reuse it for both stores.
  final long offset = getFieldOffset(ordinal);
  // Clear the span before writing the narrower value.
  Platform.putLong(holder.buffer, offset, 0L);
  Platform.putInt(holder.buffer, offset, value);
}
public void write(int ordinal, long value) {
......@@ -123,7 +131,9 @@ public class UnsafeRowWriter {
if (Float.isNaN(value)) {
value = Float.NaN;
}
Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value);
final long offset = getFieldOffset(ordinal);
Platform.putLong(holder.buffer, offset, 0L);
Platform.putFloat(holder.buffer, offset, value);
}
public void write(int ordinal, double value) {
......
......@@ -99,4 +99,24 @@ class GeneratedProjectionSuite extends SparkFunSuite {
val row2 = safeProj(unsafeRow)
assert(row2 === row)
}
// Projects the same nested struct through one UnsafeProjection three times (after an empty
// string, after a long string, and after an empty string again) and checks that the copied
// results compare equal each time.
test("padding bytes should be zeroed out") {
  // Primitive types narrower than 8 bytes, followed by two variable-length types.
  val innerTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, FloatType, BinaryType,
    StringType)
  val innerSchema = StructType(innerTypes.map(t => StructField("", t, true)))
  val projection = UnsafeProjection.create(Array[DataType](StringType, innerSchema))

  val nested = InternalRow(false, 1.toByte, 2.toShort, 3, 4.0f, "".getBytes,
    UTF8String.fromString(""))

  // Baseline: the struct preceded by an empty string.
  val shortOuter = InternalRow(UTF8String.fromString(""), nested)
  val firstShort = projection(shortOuter).copy()

  // Same struct, but preceded by a long string in the outer row.
  val longOuter = InternalRow(UTF8String.fromString("a_long_string").repeat(10), nested)
  val afterLong = projection(longOuter).copy()
  assert(firstShort.getStruct(1, 7) === afterLong.getStruct(1, 7))

  // Projecting the short row again must reproduce the baseline exactly.
  val secondShort = projection(shortOuter).copy()
  assert(firstShort === secondShort)
  assert(firstShort.getStruct(1, 7) === secondShort.getStruct(1, 7))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment