Commit be375fcb authored by Wenchen Fan, committed by Davies Liu

[SPARK-12879] [SQL] improve the unsafe row writing framework

As we begin to use the unsafe row writing framework (`BufferHolder` and `UnsafeRowWriter`) in more and more places (`UnsafeProjection`, `UnsafeRowParquetRecordReader`, `GenerateColumnAccessor`, etc.), we should document it more thoroughly and make it easier to use.

This PR abstracts the technique used in `UnsafeRowParquetRecordReader`: avoid unnecessary operations wherever possible. For example, instead of re-pointing the row to the buffer at the end of every write, we only need to update the row size; and if all fields are fixed-length, we can skip even the row-size update. With this abstraction, the technique can easily be applied in more places.
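
For reference, the per-record write pattern under the new API looks roughly like the sketch below (the two-field schema, values, and loop are illustrative, not part of the patch; the constructors and methods are the ones introduced in the diffs that follow):

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
import org.apache.spark.unsafe.types.UTF8String;

public class WriteLoopSketch {
  public static void main(String[] args) {
    // One row/holder/writer instance per writing program, reused across records.
    UnsafeRow row = new UnsafeRow(2);
    BufferHolder holder = new BufferHolder(row);           // fixed-size region derived from row.numFields()
    UnsafeRowWriter writer = new UnsafeRowWriter(holder, 2);

    for (int i = 0; i < 3; i++) {
      holder.reset();                                      // rewind the cursor to just past the fixed-size region
      writer.zeroOutNullBytes();                           // top-level writer: only the null bits need clearing
      writer.write(0, (long) i);                           // fixed-length field
      writer.write(1, UTF8String.fromString("row" + i));   // variable-length field; may grow the buffer
      row.setTotalSize(holder.totalSize());                // skippable when all fields are fixed-length
      // `row` already points at the buffer; no pointTo or copy is needed here.
    }
  }
}
```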

A local benchmark shows `UnsafeProjection` is up to 1.7x faster after this PR:
**old version**
```
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
unsafe projection:                 Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-------------------------------------------------------------------------------
single long                             2616.04           102.61         1.00 X
single nullable long                    3032.54            88.52         0.86 X
primitive types                         9121.05            29.43         0.29 X
nullable primitive types               12410.60            21.63         0.21 X
```

**new version**
```
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
unsafe projection:                 Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-------------------------------------------------------------------------------
single long                             1533.34           175.07         1.00 X
single nullable long                    2306.73           116.37         0.66 X
primitive types                         8403.93            31.94         0.18 X
nullable primitive types               12448.39            21.56         0.12 X
```

For a single non-nullable long (the best case), we get about a 1.7x speedup. Even when the field is nullable, we still get a 1.3x speedup. The other cases see a smaller boost, since the saved operations account for only a small fraction of the whole process. The benchmark code is included in this PR.
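
(Sanity-checking the numbers: each case processes iters × numRows = 16,384 × 16,384 ≈ 268.4 million rows, and 268.4M rows / 1.533 s ≈ 175 M/s, so the rate column is consistent with the time column; the headline speedup is 2616.04 / 1533.34 ≈ 1.71x and the nullable case is 3032.54 / 2306.73 ≈ 1.31x.)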

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10809 from cloud-fan/unsafe-projection.
parent 6f0f1d9e
**BufferHolder.java**

```diff
@@ -21,24 +21,40 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.unsafe.Platform;
 
 /**
- * A helper class to manage the row buffer when constructing unsafe rows.
+ * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and
+ * automatically re-point the unsafe row to it.
+ *
+ * This class can be used to build a one-pass unsafe row writing program, i.e. data will be written
+ * to the data buffer directly and no extra copy is needed. There should be only one instance of
+ * this class per writing program, so that the memory segment/data buffer can be reused. Note that
+ * for each incoming record, we should call `reset` on the BufferHolder instance before writing the
+ * record, so that the data buffer is reused.
+ *
+ * Generally we should call `UnsafeRow.setTotalSize` and pass in `BufferHolder.totalSize` to update
+ * the size of the result row, after writing a record to the buffer. However, we can skip this step
+ * if the fields of the row are all fixed-length, as the size of the result row is also fixed.
 */
 public class BufferHolder {
   public byte[] buffer;
   public int cursor = Platform.BYTE_ARRAY_OFFSET;
+  private final UnsafeRow row;
+  private final int fixedSize;
 
-  public BufferHolder() {
-    this(64);
+  public BufferHolder(UnsafeRow row) {
+    this(row, 64);
   }
 
-  public BufferHolder(int size) {
-    buffer = new byte[size];
+  public BufferHolder(UnsafeRow row, int initialSize) {
+    this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
+    this.buffer = new byte[fixedSize + initialSize];
+    this.row = row;
+    this.row.pointTo(buffer, buffer.length);
   }
 
   /**
-   * Grows the buffer to at least neededSize. If row is non-null, points the row to the buffer.
+   * Grows the buffer by at least neededSize and points the row to the buffer.
    */
-  public void grow(int neededSize, UnsafeRow row) {
+  public void grow(int neededSize) {
     final int length = totalSize() + neededSize;
     if (buffer.length < length) {
       // This will not happen frequently, because the buffer is re-used.
@@ -50,22 +66,12 @@ public class BufferHolder {
         Platform.BYTE_ARRAY_OFFSET,
         totalSize());
       buffer = tmp;
-      if (row != null) {
-        row.pointTo(buffer, length * 2);
-      }
+      row.pointTo(buffer, buffer.length);
     }
   }
 
-  public void grow(int neededSize) {
-    grow(neededSize, null);
-  }
-
   public void reset() {
-    cursor = Platform.BYTE_ARRAY_OFFSET;
-  }
-
-  public void resetTo(int offset) {
-    assert(offset <= buffer.length);
-    cursor = Platform.BYTE_ARRAY_OFFSET + offset;
+    cursor = Platform.BYTE_ARRAY_OFFSET + fixedSize;
   }
 
   public int totalSize() {
```
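The `fixedSize` computed in the new constructor mirrors the `UnsafeRow` layout: a word-aligned null bitset followed by one 8-byte slot per field. A quick worked example (the bitset arithmetic below restates what `UnsafeRow.calculateBitSetWidthInBytes` computes; the 7-field count comes from the benchmark schema further down):

```java
int numFields = 7;                              // e.g. the benchmark's "primitive types" schema
int bitSetWidth = ((numFields + 63) / 64) * 8;  // one bit per field, rounded up to 8-byte words -> 8
int fixedSize = bitSetWidth + 8 * numFields;    // 8 + 56 = 64 bytes of fixed-length region
// new BufferHolder(row) then allocates fixedSize + 64 (the default initialSize) = 128 bytes up front.
```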
**UnsafeRowWriter.java**

```diff
@@ -26,38 +26,56 @@ import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
- * A helper class to write data into global row buffer using `UnsafeRow` format,
- * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ * A helper class to write data into global row buffer using `UnsafeRow` format.
+ *
+ * It remembers the offset of the row buffer where it starts to write, and moves the cursor of
+ * the row buffer while writing. If new data (the input record if this is the outermost writer,
+ * or a nested struct if this is an inner writer) comes in, the starting cursor of the row buffer
+ * may change, so we need to call `UnsafeRowWriter.reset` before writing, to update the
+ * `startingOffset` and clear out the null bits.
+ *
+ * Note that if this is the outermost writer, which means we will always write from the very
+ * beginning of the global row buffer, we don't need to update `startingOffset` and can just call
+ * `zeroOutNullBytes` before writing new data.
 */
 public class UnsafeRowWriter {
 
-  private BufferHolder holder;
+  private final BufferHolder holder;
   // The offset of the global buffer where we start to write this row.
   private int startingOffset;
-  private int nullBitsSize;
-  private UnsafeRow row;
+  private final int nullBitsSize;
+  private final int fixedSize;
 
-  public void initialize(BufferHolder holder, int numFields) {
+  public UnsafeRowWriter(BufferHolder holder, int numFields) {
     this.holder = holder;
-    this.startingOffset = holder.cursor;
     this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+    this.fixedSize = nullBitsSize + 8 * numFields;
+    this.startingOffset = holder.cursor;
+  }
+
+  /**
+   * Resets the `startingOffset` according to the current cursor of the row buffer, and clears
+   * out the null bits. This should be called before we write a new nested struct to the row
+   * buffer.
+   */
+  public void reset() {
+    this.startingOffset = holder.cursor;
 
     // grow the global buffer to make sure it has enough space to write fixed-length data.
-    final int fixedSize = nullBitsSize + 8 * numFields;
-    holder.grow(fixedSize, row);
+    holder.grow(fixedSize);
     holder.cursor += fixedSize;
 
-    // zero-out the null bits region
+    zeroOutNullBytes();
+  }
+
+  /**
+   * Clears out the null bits. This should be called before we write a new row to the row buffer.
+   */
+  public void zeroOutNullBytes() {
     for (int i = 0; i < nullBitsSize; i += 8) {
       Platform.putLong(holder.buffer, startingOffset + i, 0L);
     }
   }
 
-  public void initialize(UnsafeRow row, BufferHolder holder, int numFields) {
-    initialize(holder, numFields);
-    this.row = row;
-  }
-
   private void zeroOutPaddingBytes(int numBytes) {
     if ((numBytes & 0x07) > 0) {
       Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
@@ -98,7 +116,7 @@ public class UnsafeRowWriter {
     if (remainder > 0) {
       final int paddingBytes = 8 - remainder;
-      holder.grow(paddingBytes, row);
+      holder.grow(paddingBytes);
 
       for (int i = 0; i < paddingBytes; i++) {
         Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
@@ -161,7 +179,7 @@ public class UnsafeRowWriter {
       }
     } else {
       // grow the global buffer before writing data.
-      holder.grow(16, row);
+      holder.grow(16);
 
       // zero-out the bytes
       Platform.putLong(holder.buffer, holder.cursor, 0L);
@@ -193,7 +211,7 @@ public class UnsafeRowWriter {
     final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
     // grow the global buffer before writing data.
-    holder.grow(roundedSize, row);
+    holder.grow(roundedSize);
 
     zeroOutPaddingBytes(numBytes);
@@ -214,7 +232,7 @@ public class UnsafeRowWriter {
     final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
     // grow the global buffer before writing data.
-    holder.grow(roundedSize, row);
+    holder.grow(roundedSize);
 
     zeroOutPaddingBytes(numBytes);
@@ -230,7 +248,7 @@ public class UnsafeRowWriter {
   public void write(int ordinal, CalendarInterval input) {
     // grow the global buffer before writing data.
-    holder.grow(16, row);
+    holder.grow(16);
 
     // Write the months and microseconds fields of Interval to the variable length portion.
     Platform.putLong(holder.buffer, holder.cursor, input.months);
```
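To see why inner writers call `reset` while the outermost writer only needs `zeroOutNullBytes`, here is a hand-written sketch of roughly what the generated code does for a row of shape `(long, struct(long, string))`. The variable names are illustrative, and `setOffsetAndSize` is assumed to be the existing writer helper that codegen uses to link a nested structure into its parent's fixed-length slot:

```java
UnsafeRow row = new UnsafeRow(2);
BufferHolder holder = new BufferHolder(row);
UnsafeRowWriter outerWriter = new UnsafeRowWriter(holder, 2);  // startingOffset: buffer start, never moves
UnsafeRowWriter innerWriter = new UnsafeRowWriter(holder, 2);  // startingOffset: re-read at each reset()

holder.reset();                          // cursor -> just past the outer row's fixed-size region
outerWriter.zeroOutNullBytes();          // top level: clearing null bits is enough
outerWriter.write(0, 1L);                // outer fixed-length field

final int structStart = holder.cursor;   // the struct lands in the outer row's variable-length section
innerWriter.reset();                     // captures startingOffset from the cursor, clears inner null bits
innerWriter.write(0, 2L);
innerWriter.write(1, UTF8String.fromString("x"));
// record the struct's (offset, size) in the outer row's fixed-length slot 1
outerWriter.setOffsetAndSize(1, structStart, holder.cursor - structStart);

row.setTotalSize(holder.totalSize());
```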
**GenerateUnsafeProjection.scala**

```diff
@@ -43,9 +43,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case _ => false
   }
 
-  private val rowWriterClass = classOf[UnsafeRowWriter].getName
-  private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
-
   // TODO: if the nullability of field is correct, we can use it to save null check.
   private def writeStructToBuffer(
       ctx: CodegenContext,
@@ -73,9 +70,27 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       row: String,
       inputs: Seq[ExprCode],
       inputTypes: Seq[DataType],
-      bufferHolder: String): String = {
+      bufferHolder: String,
+      isTopLevel: Boolean = false): String = {
+    val rowWriterClass = classOf[UnsafeRowWriter].getName
     val rowWriter = ctx.freshName("rowWriter")
-    ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
+    ctx.addMutableState(rowWriterClass, rowWriter,
+      s"this.$rowWriter = new $rowWriterClass($bufferHolder, ${inputs.length});")
+
+    val resetWriter = if (isTopLevel) {
+      // The top-level row writer always writes to the beginning of the global buffer holder,
+      // which means its fixed-size region is always in the same position, so we don't need to
+      // call `reset` to set up its fixed-size region every time.
+      if (inputs.map(_.isNull).forall(_ == "false")) {
+        // If all fields are non-nullable, the null bits never change, so we don't need to
+        // clear them out every time.
+        ""
+      } else {
+        s"$rowWriter.zeroOutNullBytes();"
+      }
+    } else {
+      s"$rowWriter.reset();"
+    }
 
     val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
       case ((input, dataType), index) =>
@@ -122,11 +137,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
             $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
           """
 
-        case _ if ctx.isPrimitiveType(dt) =>
-          s"""
-            $rowWriter.write($index, ${input.value});
-          """
-
         case t: DecimalType =>
           s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});"
@@ -153,7 +163,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     }
 
     s"""
-      $rowWriter.initialize($bufferHolder, ${inputs.length});
+      $resetWriter
      ${ctx.splitExpressions(row, writeFields)}
     """.trim
   }
@@ -164,6 +174,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       input: String,
       elementType: DataType,
       bufferHolder: String): String = {
+    val arrayWriterClass = classOf[UnsafeArrayWriter].getName
     val arrayWriter = ctx.freshName("arrayWriter")
     ctx.addMutableState(arrayWriterClass, arrayWriter,
       s"this.$arrayWriter = new $arrayWriterClass();")
@@ -288,22 +299,43 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
     val exprTypes = expressions.map(_.dataType)
 
+    val numVarLenFields = exprTypes.count {
+      case dt if UnsafeRow.isFixedLength(dt) => false
+      // TODO: consider large decimal and interval type
+      case _ => true
+    }
+
     val result = ctx.freshName("result")
     ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
 
-    val bufferHolder = ctx.freshName("bufferHolder")
+    val holder = ctx.freshName("holder")
     val holderClass = classOf[BufferHolder].getName
-    ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+    ctx.addMutableState(holderClass, holder,
+      s"this.$holder = new $holderClass($result, ${numVarLenFields * 32});")
+
+    val resetBufferHolder = if (numVarLenFields == 0) {
+      ""
+    } else {
+      s"$holder.reset();"
+    }
+    val updateRowSize = if (numVarLenFields == 0) {
+      ""
+    } else {
+      s"$result.setTotalSize($holder.totalSize());"
+    }
 
     // Evaluate all the subexpression.
     val evalSubexpr = ctx.subexprFunctions.mkString("\n")
 
+    val writeExpressions =
+      writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, holder, isTopLevel = true)
+
     val code =
       s"""
-        $bufferHolder.reset();
+        $resetBufferHolder
         $evalSubexpr
-        ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
-        $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
+        $writeExpressions
+        $updateRowSize
       """
     ExprCode(code, "false", result)
   }
```
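Concretely, for a top-level projection whose output fields are all non-nullable and fixed-length, the generated body now collapses to just the field writes, since `$resetBufferHolder`, `$resetWriter`, and `$updateRowSize` all render as empty strings (illustrative expansion; real codegen uses fresh names):

```java
// Generated consume body for schema (long, double), both non-nullable:
//   $resetBufferHolder -> ""   because numVarLenFields == 0
//   $resetWriter       -> ""   because all inputs are non-null and the writer is top-level
//   $updateRowSize     -> ""   because the row size is fixed
rowWriter.write(0, value_0);
rowWriter.write(1, value_1);
// `result` was pointed at the holder's buffer once, when the BufferHolder was constructed.
```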
**UnsafeProjectionBenchmark.scala** (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types._
import org.apache.spark.util.Benchmark

/**
 * Benchmark [[UnsafeProjection]] for fixed-length/primitive-type fields.
 */
object UnsafeProjectionBenchmark {

  def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
    val generator = RandomDataGenerator.forType(schema, nullable = false).get
    val encoder = RowEncoder(schema)
    (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
  }

  def main(args: Array[String]) {
    val iters = 1024 * 16
    val numRows = 1024 * 16

    val benchmark = new Benchmark("unsafe projection", iters * numRows)

    val schema1 = new StructType().add("l", LongType, false)
    val attrs1 = schema1.toAttributes
    val rows1 = generateRows(schema1, numRows)
    val projection1 = UnsafeProjection.create(attrs1, attrs1)

    benchmark.addCase("single long") { _ =>
      for (_ <- 1 to iters) {
        var sum = 0L
        var i = 0
        while (i < numRows) {
          sum += projection1(rows1(i)).getLong(0)
          i += 1
        }
      }
    }

    val schema2 = new StructType().add("l", LongType, true)
    val attrs2 = schema2.toAttributes
    val rows2 = generateRows(schema2, numRows)
    val projection2 = UnsafeProjection.create(attrs2, attrs2)

    benchmark.addCase("single nullable long") { _ =>
      for (_ <- 1 to iters) {
        var sum = 0L
        var i = 0
        while (i < numRows) {
          sum += projection2(rows2(i)).getLong(0)
          i += 1
        }
      }
    }

    val schema3 = new StructType()
      .add("boolean", BooleanType, false)
      .add("byte", ByteType, false)
      .add("short", ShortType, false)
      .add("int", IntegerType, false)
      .add("long", LongType, false)
      .add("float", FloatType, false)
      .add("double", DoubleType, false)
    val attrs3 = schema3.toAttributes
    val rows3 = generateRows(schema3, numRows)
    val projection3 = UnsafeProjection.create(attrs3, attrs3)

    benchmark.addCase("7 primitive types") { _ =>
      for (_ <- 1 to iters) {
        var sum = 0L
        var i = 0
        while (i < numRows) {
          sum += projection3(rows3(i)).getLong(0)
          i += 1
        }
      }
    }

    val schema4 = new StructType()
      .add("boolean", BooleanType, true)
      .add("byte", ByteType, true)
      .add("short", ShortType, true)
      .add("int", IntegerType, true)
      .add("long", LongType, true)
      .add("float", FloatType, true)
      .add("double", DoubleType, true)
    val attrs4 = schema4.toAttributes
    val rows4 = generateRows(schema4, numRows)
    val projection4 = UnsafeProjection.create(attrs4, attrs4)

    benchmark.addCase("7 nullable primitive types") { _ =>
      for (_ <- 1 to iters) {
        var sum = 0L
        var i = 0
        while (i < numRows) {
          sum += projection4(rows4(i)).getLong(0)
          i += 1
        }
      }
    }

    /*
    Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
    unsafe projection:                 Avg Time(ms)    Avg Rate(M/s)  Relative Rate
    -------------------------------------------------------------------------------
    single long                             1533.34           175.07         1.00 X
    single nullable long                    2306.73           116.37         0.66 X
    primitive types                         8403.93            31.94         0.18 X
    nullable primitive types               12448.39            21.56         0.12 X
    */
    benchmark.run()
  }
}
```
**UnsafeRowParquetRecordReader.java**

```diff
@@ -73,11 +73,6 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
    */
   private boolean containsVarLenFields;
 
-  /**
-   * The number of bytes in the fixed length portion of the row.
-   */
-  private int fixedSizeBytes;
-
   /**
    * For each request column, the reader to read this column.
    * columnsReaders[i] populated the UnsafeRow's attribute at i.
@@ -266,19 +261,13 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
     /**
      * Initialize rows and rowWriters. These objects are reused across all rows in the relation.
      */
-    int rowByteSize = UnsafeRow.calculateBitSetWidthInBytes(requestedSchema.getFieldCount());
-    rowByteSize += 8 * requestedSchema.getFieldCount();
-    fixedSizeBytes = rowByteSize;
-    rowByteSize += numVarLenFields * DEFAULT_VAR_LEN_SIZE;
     containsVarLenFields = numVarLenFields > 0;
 
     rowWriters = new UnsafeRowWriter[rows.length];
     for (int i = 0; i < rows.length; ++i) {
       rows[i] = new UnsafeRow(requestedSchema.getFieldCount());
-      rowWriters[i] = new UnsafeRowWriter();
-      BufferHolder holder = new BufferHolder(rowByteSize);
-      rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount());
-      rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length);
+      BufferHolder holder = new BufferHolder(rows[i], numVarLenFields * DEFAULT_VAR_LEN_SIZE);
+      rowWriters[i] = new UnsafeRowWriter(holder, requestedSchema.getFieldCount());
     }
   }
@@ -295,7 +284,7 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
     if (containsVarLenFields) {
       for (int i = 0; i < rowWriters.length; ++i) {
-        rowWriters[i].holder().resetTo(fixedSizeBytes);
+        rowWriters[i].holder().reset();
       }
     }
```
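Note the simplification in the last hunk: the removed `resetTo(fixedSizeBytes)` and the new `reset()` are equivalent here, because `BufferHolder` now derives the fixed-length size from the row it was constructed with, and `reset()` always places the cursor just past that region, so the Parquet reader no longer has to track `fixedSizeBytes` itself.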
**GenerateColumnAccessor.scala**

```diff
@@ -132,8 +132,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       private ByteOrder nativeOrder = null;
       private byte[][] buffers = null;
       private UnsafeRow unsafeRow = new UnsafeRow($numFields);
-      private BufferHolder bufferHolder = new BufferHolder();
-      private UnsafeRowWriter rowWriter = new UnsafeRowWriter();
+      private BufferHolder bufferHolder = new BufferHolder(unsafeRow);
+      private UnsafeRowWriter rowWriter = new UnsafeRowWriter(bufferHolder, $numFields);
       private MutableUnsafeRow mutableRow = null;
 
       private int currentRow = 0;
@@ -181,9 +181,9 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       public InternalRow next() {
         currentRow += 1;
         bufferHolder.reset();
-        rowWriter.initialize(bufferHolder, $numFields);
+        rowWriter.zeroOutNullBytes();
         ${extractors.mkString("\n")}
-        unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize());
+        unsafeRow.setTotalSize(bufferHolder.totalSize());
         return unsafeRow;
       }
     }"""
```
**TextRelation (text data source)**

```diff
@@ -98,16 +98,15 @@ private[sql] class TextRelation(
     sqlContext.sparkContext.hadoopRDD(
       conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text])
       .mapPartitions { iter =>
-        val bufferHolder = new BufferHolder
-        val unsafeRowWriter = new UnsafeRowWriter
         val unsafeRow = new UnsafeRow(1)
+        val bufferHolder = new BufferHolder(unsafeRow)
+        val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1)
 
         iter.map { case (_, line) =>
           // Writes to an UnsafeRow directly
           bufferHolder.reset()
-          unsafeRowWriter.initialize(bufferHolder, 1)
           unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
-          unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize())
+          unsafeRow.setTotalSize(bufferHolder.totalSize())
           unsafeRow
         }
       }
```