diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 3513960b418135697fc3c79ff703c611c74b5e1e..3d80df227151d1118c508a998815019502aa75a0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -270,8 +270,8 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    final UnsafeRow row = new UnsafeRow();
-    row.pointTo(baseObject, baseOffset + offset, numFields, size);
+    final UnsafeRow row = new UnsafeRow(numFields);
+    row.pointTo(baseObject, baseOffset + offset, size);
     return row;
   }
 
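This UnsafeArrayData hunk is the template for the whole patch: the field count moves out of every pointTo() call and into the UnsafeRow constructor, leaving pointTo() with only the memory-binding arguments. A minimal before/after sketch, using the same names as the hunk above:

    // Before: schema rebound on every call
    //   UnsafeRow row = new UnsafeRow();
    //   row.pointTo(baseObject, baseOffset + offset, numFields, size);
    // After: field count is fixed at construction; pointTo() only rebinds memory
    UnsafeRow row = new UnsafeRow(numFields);
    row.pointTo(baseObject, baseOffset + offset, size);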
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index b6979d0c82977070d33a7fa7360d3e71a5c3a96e..7492b88c471a4447e7bc3393bc3f74ce2c2f5fe9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,11 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import java.io.Externalizable;
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-import java.io.OutputStream;
+import java.io.*;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.nio.ByteBuffer;
@@ -30,26 +26,12 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 
-import org.apache.spark.sql.types.ArrayType;
-import org.apache.spark.sql.types.BinaryType;
-import org.apache.spark.sql.types.BooleanType;
-import org.apache.spark.sql.types.ByteType;
-import org.apache.spark.sql.types.CalendarIntervalType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DateType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.DecimalType;
-import org.apache.spark.sql.types.DoubleType;
-import org.apache.spark.sql.types.FloatType;
-import org.apache.spark.sql.types.IntegerType;
-import org.apache.spark.sql.types.LongType;
-import org.apache.spark.sql.types.MapType;
-import org.apache.spark.sql.types.NullType;
-import org.apache.spark.sql.types.ShortType;
-import org.apache.spark.sql.types.StringType;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.types.TimestampType;
-import org.apache.spark.sql.types.UserDefinedType;
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.KryoSerializable;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
@@ -57,23 +39,9 @@ import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 import org.apache.spark.unsafe.types.CalendarInterval;
 import org.apache.spark.unsafe.types.UTF8String;
 
-import static org.apache.spark.sql.types.DataTypes.BooleanType;
-import static org.apache.spark.sql.types.DataTypes.ByteType;
-import static org.apache.spark.sql.types.DataTypes.DateType;
-import static org.apache.spark.sql.types.DataTypes.DoubleType;
-import static org.apache.spark.sql.types.DataTypes.FloatType;
-import static org.apache.spark.sql.types.DataTypes.IntegerType;
-import static org.apache.spark.sql.types.DataTypes.LongType;
-import static org.apache.spark.sql.types.DataTypes.NullType;
-import static org.apache.spark.sql.types.DataTypes.ShortType;
-import static org.apache.spark.sql.types.DataTypes.TimestampType;
+import static org.apache.spark.sql.types.DataTypes.*;
 import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;
 
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.KryoSerializable;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
-
 /**
  * An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
  *
@@ -167,8 +135,16 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
   /**
    * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
    * since the value returned by this constructor is equivalent to a null pointer.
+   *
+   * @param numFields the number of fields in this row
    */
-  public UnsafeRow() { }
+  public UnsafeRow(int numFields) {
+    this.numFields = numFields;
+    this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
+  }
+
+  // for serializer
+  public UnsafeRow() {}
 
   public Object getBaseObject() { return baseObject; }
   public long getBaseOffset() { return baseOffset; }
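With numFields fixed at construction, bitSetWidthInBytes is computed once per row object instead of on every pointTo() call. A sketch of the width arithmetic, matching the formula the bitset suite later in this patch uses (calculateBitSetWidthInBytes already exists in UnsafeRow; this is illustrative, not the patch's text):

    // One null-tracking bit per field, rounded up to whole 8-byte words.
    public static int calculateBitSetWidthInBytes(int numFields) {
      return ((numFields + 63) / 64) * 8;  // 1..64 fields -> 8 bytes, 65..128 -> 16
    }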
@@ -182,15 +158,12 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
    *
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
-   * @param numFields the number of fields in this row
    * @param sizeInBytes the size of this row's backing data, in bytes
    */
-  public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) {
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
     assert numFields >= 0 : "numFields (" + numFields + ") should >= 0";
-    this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
-    this.numFields = numFields;
     this.sizeInBytes = sizeInBytes;
   }
 
@@ -198,23 +171,12 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
    * Update this UnsafeRow to point to the underlying byte array.
    *
    * @param buf byte array to point to
-   * @param numFields the number of fields in this row
-   * @param sizeInBytes the number of bytes valid in the byte array
-   */
-  public void pointTo(byte[] buf, int numFields, int sizeInBytes) {
-    pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
-  }
-
-  /**
-   * Updates this UnsafeRow preserving the number of fields.
-   * @param buf byte array to point to
    * @param sizeInBytes the number of bytes valid in the byte array
    */
   public void pointTo(byte[] buf, int sizeInBytes) {
-    pointTo(buf, numFields, sizeInBytes);
+    pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
   }
 
-
   public void setNotNullAt(int i) {
     assertIndexIsValid(i);
     BitSetMethods.unset(baseObject, baseOffset, i);
@@ -489,8 +451,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
       final int size = (int) offsetAndSize;
-      final UnsafeRow row = new UnsafeRow();
-      row.pointTo(baseObject, baseOffset + offset, numFields, size);
+      final UnsafeRow row = new UnsafeRow(numFields);
+      row.pointTo(baseObject, baseOffset + offset, size);
       return row;
     }
   }
@@ -529,7 +491,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
    */
   @Override
   public UnsafeRow copy() {
-    UnsafeRow rowCopy = new UnsafeRow();
+    UnsafeRow rowCopy = new UnsafeRow(numFields);
     final byte[] rowDataCopy = new byte[sizeInBytes];
     Platform.copyMemory(
       baseObject,
@@ -538,7 +500,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
       Platform.BYTE_ARRAY_OFFSET,
       sizeInBytes
     );
-    rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes);
+    rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return rowCopy;
   }
 
@@ -547,8 +509,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
    * The returned row is invalid until we call copyFrom on it.
    */
   public static UnsafeRow createFromByteArray(int numBytes, int numFields) {
-    final UnsafeRow row = new UnsafeRow();
-    row.pointTo(new byte[numBytes], numFields, numBytes);
+    final UnsafeRow row = new UnsafeRow(numFields);
+    row.pointTo(new byte[numBytes], numBytes);
     return row;
   }
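The getStruct() hunk above also shows how nested structs are addressed: a single long packs the field's offset and byte size, so the child row can be constructed with the caller-supplied field count and pointed into the parent's buffer. The unpacking, taken from the context lines above:

    // A struct-typed field stores one long: offset in the upper 32 bits,
    // size in bytes in the lower 32 bits.
    final long offsetAndSize = getLong(ordinal);
    final int offset = (int) (offsetAndSize >> 32);  // upper 32 bits
    final int size = (int) offsetAndSize;            // lower 32 bits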
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 352002b3499a23d4430e9ff951b414de8cd13f45..27ae62f1212f66f37b0c1f460397ea17c32065ff 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -26,10 +26,9 @@ import com.google.common.annotations.VisibleForTesting;
 
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
-import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator;
 import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
@@ -123,7 +122,7 @@ final class UnsafeExternalRowSorter {
     return new AbstractScalaRowIterator<UnsafeRow>() {
 
       private final int numFields = schema.length();
-      private UnsafeRow row = new UnsafeRow();
+      private UnsafeRow row = new UnsafeRow(numFields);
 
       @Override
       public boolean hasNext() {
@@ -137,7 +136,6 @@
           row.pointTo(
             sortedIterator.getBaseObject(),
             sortedIterator.getBaseOffset(),
-            numFields,
             sortedIterator.getRecordLength());
           if (!hasNext()) {
             UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page
@@ -173,19 +171,21 @@ final class UnsafeExternalRowSorter {
   private static final class RowComparator extends RecordComparator {
     private final Ordering<InternalRow> ordering;
     private final int numFields;
-    private final UnsafeRow row1 = new UnsafeRow();
-    private final UnsafeRow row2 = new UnsafeRow();
+    private final UnsafeRow row1;
+    private final UnsafeRow row2;
 
     public RowComparator(Ordering<InternalRow> ordering, int numFields) {
       this.numFields = numFields;
+      this.row1 = new UnsafeRow(numFields);
+      this.row2 = new UnsafeRow(numFields);
       this.ordering = ordering;
     }
 
     @Override
     public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
       // TODO: Why are the sizes -1?
-      row1.pointTo(baseObj1, baseOff1, numFields, -1);
-      row2.pointTo(baseObj2, baseOff2, numFields, -1);
+      row1.pointTo(baseObj1, baseOff1, -1);
+      row2.pointTo(baseObj2, baseOff2, -1);
       return ordering.compare(row1, row2);
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index c1defe12b0b91d8813341afa47d1c91480471358..d0e031f27990c70c6b0cc48c64b1864e2dee607e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -289,7 +289,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val exprTypes = expressions.map(_.dataType)
 
     val result = ctx.freshName("result")
-    ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();")
+    ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});")
    val bufferHolder = ctx.freshName("bufferHolder")
    val holderClass = classOf[BufferHolder].getName
    ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
@@ -303,7 +303,7 @@
       $subexprReset
       ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
 
-      $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
+      $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
     """
     GeneratedExpressionCode(code, "false", result)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index da602d9b4bce1a6c15a553f280dd85c904524692..c9ff357bf34763c2fda9e0e348909fb47f71213a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -165,7 +165,7 @@
       |
       |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} {
       |  private byte[] buf = new byte[64];
-      |  private UnsafeRow out = new UnsafeRow();
+      |  private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size});
       |
       |  public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) {
       |    // row1: ${schema1.size} fields, $bitset1Words words in bitset
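Both code generators now bake the field count into the emitted constructor call, so a generated class carries a correctly sized reusable row from the start. Roughly what GenerateUnsafeProjection emits for a two-column schema (an illustrative sketch of generated output, not verbatim):

    // Sketch, assuming expressions.length == 2:
    private UnsafeRow result = new UnsafeRow(2);
    private BufferHolder bufferHolder = new BufferHolder();
    // ... per input row: write fields into the buffer, then rebind:
    // result.pointTo(bufferHolder.buffer, bufferHolder.totalSize());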
@@ -188,7 +188,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
       |    $copyVariableLengthRow2
       |    $updateOffset
       |
-      |    out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction);
+      |    out.pointTo(buf, sizeInBytes - $sizeReduction);
       |
       |    return out;
       |  }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
index 796d60032e1a6fc99e7f1d3de4f98177c150dff3..f8342214d9ae08a0cff80da8aa6d9fc1e83306c7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -90,13 +90,13 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
   }
 
   private def createUnsafeRow(numFields: Int): UnsafeRow = {
-    val row = new UnsafeRow
+    val row = new UnsafeRow(numFields)
     val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
     // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle.
     // This way we can test the joiner when the input UnsafeRows are not the entire arrays.
     val offset = numFields * 8
     val buf = new Array[Byte](sizeInBytes + offset)
-    row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
+    row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, sizeInBytes)
     row
   }
 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index a2f99d566d4711caef021ddbe9230d6dd6b68283..6bf9d7bd0367c4657946948804436f0c1153bf27 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -61,7 +61,7 @@ public final class UnsafeFixedWidthAggregationMap {
   /**
    * Re-used pointer to the current aggregation buffer
    */
-  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+  private final UnsafeRow currentAggregationBuffer;
 
   private final boolean enablePerfMetrics;
 
@@ -98,6 +98,7 @@
       long pageSizeBytes,
       boolean enablePerfMetrics) {
     this.aggregationBufferSchema = aggregationBufferSchema;
+    this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length());
     this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
     this.groupingKeySchema = groupingKeySchema;
     this.map =
@@ -147,7 +148,6 @@
     currentAggregationBuffer.pointTo(
       address.getBaseObject(),
       address.getBaseOffset(),
-      aggregationBufferSchema.length(),
       loc.getValueLength()
     );
     return currentAggregationBuffer;
@@ -165,8 +165,8 @@
       private final BytesToBytesMap.MapIterator mapLocationIterator =
         map.destructiveIterator();
 
-      private final UnsafeRow key = new UnsafeRow();
-      private final UnsafeRow value = new UnsafeRow();
+      private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length());
+      private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());
 
       @Override
       public boolean next() {
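The aggregation map keeps one reusable UnsafeRow per schema, sized from that schema's length at construction, and re-points it at each map entry instead of allocating per lookup. The pattern, condensed from the hunks above:

    // Constructed once, with the schema's field count:
    UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length());
    // Re-bound to each entry; no per-record allocation:
    value.pointTo(address.getBaseObject(), address.getBaseOffset(), loc.getValueLength());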
@@ -177,13 +177,11 @@
           key.pointTo(
             keyAddress.getBaseObject(),
             keyAddress.getBaseOffset(),
-            groupingKeySchema.length(),
             loc.getKeyLength()
           );
           value.pointTo(
             valueAddress.getBaseObject(),
             valueAddress.getBaseOffset(),
-            aggregationBufferSchema.length(),
             loc.getValueLength()
           );
           return true;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 8c9b9c85e37fc265406d7aed03080c15e8bff074..0da26bf376a6a30992438040e9f5ffc496a520b5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -94,7 +94,7 @@ public final class UnsafeKVExternalSorter {
       // The only new memory we are allocating is the pointer/prefix array.
       BytesToBytesMap.MapIterator iter = map.iterator();
       final int numKeyFields = keySchema.size();
-      UnsafeRow row = new UnsafeRow();
+      UnsafeRow row = new UnsafeRow(numKeyFields);
       while (iter.hasNext()) {
         final BytesToBytesMap.Location loc = iter.next();
         final Object baseObject = loc.getKeyAddress().getBaseObject();
@@ -107,7 +107,7 @@
         long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
 
         // Compute prefix
-        row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+        row.pointTo(baseObject, baseOffset, loc.getKeyLength());
         final long prefix = prefixComputer.computePrefix(row);
         inMemSorter.insertRecord(address, prefix);
 
@@ -194,12 +194,14 @@
   private static final class KVComparator extends RecordComparator {
     private final BaseOrdering ordering;
-    private final UnsafeRow row1 = new UnsafeRow();
-    private final UnsafeRow row2 = new UnsafeRow();
+    private final UnsafeRow row1;
+    private final UnsafeRow row2;
     private final int numKeyFields;
 
     public KVComparator(BaseOrdering ordering, int numKeyFields) {
       this.numKeyFields = numKeyFields;
+      this.row1 = new UnsafeRow(numKeyFields);
+      this.row2 = new UnsafeRow(numKeyFields);
       this.ordering = ordering;
     }
 
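RowComparator and KVComparator follow the same shape: both rows are sized once in the constructor, and compare() passes -1 for sizeInBytes because, as the comment in the next hunk notes, the ordering never consults the record length. A condensed sketch:

    // Safe because ordering.compare() reads fields via base object/offset only
    // and never looks at sizeInBytes:
    row1.pointTo(baseObj1, baseOff1, -1);
    row2.pointTo(baseObj2, baseOff2, -1);
    return ordering.compare(row1, row2);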
@@ -207,17 +209,15 @@
     public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
       // Note that since ordering doesn't need the total length of the record, we just pass -1
       // into the row.
-      row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
-      row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+      row1.pointTo(baseObj1, baseOff1 + 4, -1);
+      row2.pointTo(baseObj2, baseOff2 + 4, -1);
       return ordering.compare(row1, row2);
     }
   }
 
   public class KVSorterIterator extends KVIterator<UnsafeRow, UnsafeRow> {
-    private UnsafeRow key = new UnsafeRow();
-    private UnsafeRow value = new UnsafeRow();
-    private final int numKeyFields = keySchema.size();
-    private final int numValueFields = valueSchema.size();
+    private UnsafeRow key = new UnsafeRow(keySchema.size());
+    private UnsafeRow value = new UnsafeRow(valueSchema.size());
     private final UnsafeSorterIterator underlying;
 
     private KVSorterIterator(UnsafeSorterIterator underlying) {
@@ -237,8 +237,8 @@ public final class UnsafeKVExternalSorter {
       // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
       int keyLen = Platform.getInt(baseObj, recordOffset);
       int valueLen = recordLen - keyLen - 4;
-      key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
-      value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+      key.pointTo(baseObj, recordOffset + 4, keyLen);
+      value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
 
       return true;
     } else {
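KVSorterIterator depends on the packed record layout spelled out in the comment above: each record carries its key length as a 4-byte prefix, and the value occupies whatever remains. In sketch form:

    // Layout: [keyLen: 4 bytes][key: keyLen bytes][value: recordLen - keyLen - 4 bytes]
    int keyLen = Platform.getInt(baseObj, recordOffset);
    int valueLen = recordLen - keyLen - 4;
    key.pointTo(baseObj, recordOffset + 4, keyLen);
    value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);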
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
index 0cc4566c9cddeba25538e3843dca2df2263c10d1..a6758bddfa7d03670be349d643ed0874e25b87fa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java
@@ -21,35 +21,28 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
 
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
-import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.types.UTF8String;
-
-import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL;
-import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL;
-import static org.apache.parquet.column.ValuesType.VALUES;
-
 import org.apache.hadoop.mapreduce.InputSplit;
 import org.apache.hadoop.mapreduce.TaskAttemptContext;
 import org.apache.parquet.Preconditions;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.Dictionary;
 import org.apache.parquet.column.Encoding;
-import org.apache.parquet.column.page.DataPage;
-import org.apache.parquet.column.page.DataPageV1;
-import org.apache.parquet.column.page.DataPageV2;
-import org.apache.parquet.column.page.DictionaryPage;
-import org.apache.parquet.column.page.PageReadStore;
-import org.apache.parquet.column.page.PageReader;
+import org.apache.parquet.column.page.*;
 import org.apache.parquet.column.values.ValuesReader;
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.OriginalType;
 import org.apache.parquet.schema.PrimitiveType;
 import org.apache.parquet.schema.Type;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
+import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.UTF8String;
+
+import static org.apache.parquet.column.ValuesType.*;
+
 
 /**
  * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs.
 *
@@ -181,12 +174,11 @@ public class UnsafeRowParquetRecordReader extends SpecificParquetRecordReaderBas
     rowWriters = new UnsafeRowWriter[rows.length];
 
     for (int i = 0; i < rows.length; ++i) {
-      rows[i] = new UnsafeRow();
+      rows[i] = new UnsafeRow(requestedSchema.getFieldCount());
       rowWriters[i] = new UnsafeRowWriter();
       BufferHolder holder = new BufferHolder(rowByteSize);
       rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount());
-      rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(),
-        holder.buffer.length);
+      rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length);
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 7e981268de3925d71d5add9e8a6bb7d696c38073..4730647c4be9c9b3f5009af6f947c8a44a0e4f8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -94,7 +94,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
       private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in))
       // 1024 is a default buffer size; this buffer will grow to accommodate larger rows
       private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024)
-      private[this] var row: UnsafeRow = new UnsafeRow()
+      private[this] var row: UnsafeRow = new UnsafeRow(numFields)
       private[this] var rowTuple: (Int, UnsafeRow) = (0, row)
       private[this] val EOF: Int = -1
 
@@ -117,7 +117,7 @@
             rowBuffer = new Array[Byte](rowSize)
           }
           ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
-          row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
+          row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize)
           rowSize = readSize()
           if (rowSize == EOF) { // We are returning the last row in this stream
             dIn.close()
@@ -152,7 +152,7 @@
           rowBuffer = new Array[Byte](rowSize)
         }
         ByteStreams.readFully(dIn, rowBuffer, 0, rowSize)
-        row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize)
+        row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize)
         row.asInstanceOf[T]
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index c9f2329db4b6da6d56eb8196c05bb0e6ad875330..9c908b2877e795495ac67726cf5593ba8531706a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -574,11 +574,10 @@ private[columnar] case class STRUCT(dataType: StructType)
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + sizeInBytes)
-    val unsafeRow = new UnsafeRow
+    val unsafeRow = new UnsafeRow(numOfFields)
     unsafeRow.pointTo(
       buffer.array(),
       Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
-      numOfFields,
       sizeInBytes)
     unsafeRow
   }
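The serializer shows the same reuse discipline on the read path: one growable byte buffer and one UnsafeRow(numFields) serve the whole stream, with each record read fully into the buffer and the row re-pointed at it. A Java rendering of the Scala loop body above (a sketch, not the patch's text):

    // Buffer grows monotonically to fit the largest row seen so far.
    if (rowBuffer.length < rowSize) {
      rowBuffer = new byte[rowSize];
    }
    ByteStreams.readFully(dIn, rowBuffer, 0, rowSize);
    // Rebind the single reusable row; numFields was fixed at construction.
    row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize);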
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
index eaafc96e4d2e72203f4a37f798a04c231163b930..b208425ffc3c3e02cd70b6f22accae97864f48c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala
@@ -131,7 +131,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       private ByteOrder nativeOrder = null;
       private byte[][] buffers = null;
 
-      private UnsafeRow unsafeRow = new UnsafeRow();
+      private UnsafeRow unsafeRow = new UnsafeRow($numFields);
       private BufferHolder bufferHolder = new BufferHolder();
       private UnsafeRowWriter rowWriter = new UnsafeRowWriter();
       private MutableUnsafeRow mutableRow = null;
@@ -183,7 +183,7 @@
         bufferHolder.reset();
         rowWriter.initialize(bufferHolder, $numFields);
         ${extractors.mkString("\n")}
-        unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize());
+        unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize());
         return unsafeRow;
       }
     }"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
index 4a1cbe4c38fa20fd325f1a7173e0330f6116506d..41fcb11d84bff3eea2ae5da9a7227a5745a631da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala
@@ -101,14 +101,14 @@ private[sql] class TextRelation(
       .mapPartitions { iter =>
         val bufferHolder = new BufferHolder
         val unsafeRowWriter = new UnsafeRowWriter
-        val unsafeRow = new UnsafeRow
+        val unsafeRow = new UnsafeRow(1)
 
         iter.map { case (_, line) =>
           // Writes to an UnsafeRow directly
           bufferHolder.reset()
           unsafeRowWriter.initialize(bufferHolder, 1)
           unsafeRowWriter.write(0, line.getBytes, 0, line.getLength)
-          unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize())
+          unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize())
           unsafeRow
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index fa2bc7672131c2d80aeaffbad1d8f745053fb616..81bfe4e67ca739c8022ed6b4e0994c9b41eaf5e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -56,15 +56,14 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
     // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
     def createIter(): Iterator[UnsafeRow] = {
       val iter = sorter.getIterator
-      val unsafeRow = new UnsafeRow
+      val unsafeRow = new UnsafeRow(numFieldsOfRight)
       new Iterator[UnsafeRow] {
         override def hasNext: Boolean = {
           iter.hasNext
         }
         override def next(): UnsafeRow = {
           iter.loadNext()
-          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight,
-            iter.getRecordLength)
+          unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
           unsafeRow
         }
       }
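The text source is the simplest instance: a single-column row constructed once with UnsafeRow(1) and recycled for every line via the writer/holder pair. Sketched in Java, with bytes/numBytes standing in for line.getBytes/line.getLength from the Scala hunk above:

    // Per-line cycle:
    bufferHolder.reset();                          // rewind the shared buffer
    unsafeRowWriter.initialize(bufferHolder, 1);   // one field
    unsafeRowWriter.write(0, bytes, 0, numBytes);  // write the line into field 0
    unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize());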
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8c7099ab5a34de7832f1ce731a76685be1c3a076..c6f56cfaed22cca6c9789e23fb51c8d86d24347e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -245,8 +245,8 @@ private[joins] final class UnsafeHashedRelation(
         val sizeInBytes = Platform.getInt(base, offset + 4)
         offset += 8
 
-        val row = new UnsafeRow
-        row.pointTo(base, offset, numFields, sizeInBytes)
+        val row = new UnsafeRow(numFields)
+        row.pointTo(base, offset, sizeInBytes)
         buffer += row
         offset += sizeInBytes
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
index 00f1526576cc5ccd309d8188d26ebad073a9cb1b..a32763db054f3da1c1c643e34bf6d36ac5ea8643 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala
@@ -34,8 +34,8 @@ class UnsafeRowSuite extends SparkFunSuite {
   test("UnsafeRow Java serialization") {
     // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
     val data = new Array[Byte](1024)
-    val row = new UnsafeRow
-    row.pointTo(data, 1, 16)
+    val row = new UnsafeRow(1)
+    row.pointTo(data, 16)
     row.setLong(0, 19285)
 
     val ser = new JavaSerializer(new SparkConf).newInstance()
@@ -47,8 +47,8 @@
   test("UnsafeRow Kryo serialization") {
     // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data
     val data = new Array[Byte](1024)
-    val row = new UnsafeRow
-    row.pointTo(data, 1, 16)
+    val row = new UnsafeRow(1)
+    row.pointTo(data, 16)
     row.setLong(0, 19285)
 
     val ser = new KryoSerializer(new SparkConf).newInstance()
@@ -86,11 +86,10 @@
       offheapRowPage.getBaseOffset,
       arrayBackedUnsafeRow.getSizeInBytes
     )
-    val offheapUnsafeRow: UnsafeRow = new UnsafeRow()
+    val offheapUnsafeRow: UnsafeRow = new UnsafeRow(3)
     offheapUnsafeRow.pointTo(
       offheapRowPage.getBaseObject,
       offheapRowPage.getBaseOffset,
-      3, // num fields
       arrayBackedUnsafeRow.getSizeInBytes
     )
     assert(offheapUnsafeRow.getBaseObject === null)
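The 16 bytes the serialization tests hand to pointTo() follow directly from the fixed-width row layout: one 8-byte null-bitset word (enough for up to 64 fields) plus one 8-byte slot for the single long field. A sketch of the arithmetic, using the same formula as the bitset suite earlier in the patch:

    int numFields = 1;
    int bitsetBytes = ((numFields + 63) / 64) * 8;  // 8
    int fixedBytes = numFields * 8;                 // 8
    int sizeInBytes = bitsetBytes + fixedBytes;     // 16, the size passed to pointTo()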