diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 4c63abb071e3b392730cfcb7e61cd5d8848d6c86..761f0447943e8e4ecd5033bdb287bf81679574e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -30,19 +30,18 @@ import org.apache.spark.unsafe.types.UTF8String;
 /**
  * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
  *
- * Each tuple has two parts: [offsets] [values]
+ * Each tuple has three parts: [numElements] [offsets] [values]
  *
- * In the `offsets` region, we store 4 bytes per element, represents the start address of this
- * element in `values` region. We can get the length of this element by subtracting next offset.
+ * The `numElements` is 4 bytes storing the number of elements of this array.
+ *
+ * In the `offsets` region, we store 4 bytes per element, which represents the relative offset
+ * (w.r.t. the base address of the array) of this element in the `values` region. We can get the
+ * length of this element by subtracting the next offset.
  * Note that offset can by negative which means this element is null.
 *
 * In the `values` region, we store the content of elements. As we can get length info, so elements
 * can be variable-length.
 *
- * Note that when we write out this array, we should write out the `numElements` at first 4 bytes,
- * then follows content. When we read in an array, we should read first 4 bytes as `numElements`
- * and take the rest as content.
- *
 * Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
 */
 // todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
@@ -54,11 +53,16 @@ public class UnsafeArrayData extends ArrayData {
   // The number of elements in this array
   private int numElements;
 
-  // The size of this array's backing data, in bytes
+  // The size of this array's backing data, in bytes.
+  // The 4-byte header of `numElements` is also included.
   private int sizeInBytes;
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
+  public int getSizeInBytes() { return sizeInBytes; }
+
   private int getElementOffset(int ordinal) {
-    return Platform.getInt(baseObject, baseOffset + ordinal * 4L);
+    return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
   }
 
   private int getElementSize(int offset, int ordinal) {
@@ -85,10 +89,6 @@ public class UnsafeArrayData extends ArrayData {
    */
   public UnsafeArrayData() { }
 
-  public Object getBaseObject() { return baseObject; }
-  public long getBaseOffset() { return baseOffset; }
-  public int getSizeInBytes() { return sizeInBytes; }
-
   @Override
   public int numElements() { return numElements; }
 
@@ -97,10 +97,13 @@ public class UnsafeArrayData extends ArrayData {
    *
    * @param baseObject the base object
    * @param baseOffset the offset within the base object
-   * @param sizeInBytes the size of this row's backing data, in bytes
+   * @param sizeInBytes the size of this array's backing data, in bytes
    */
-  public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+    // Read the number of elements from the first 4 bytes.
+    final int numElements = Platform.getInt(baseObject, baseOffset);
     assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+    this.numElements = numElements;
 
     this.baseObject = baseObject;
     this.baseOffset = baseOffset;
@@ -277,7 +280,9 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+    final UnsafeArrayData array = new UnsafeArrayData();
+    array.pointTo(baseObject, baseOffset + offset, size);
+    return array;
   }
 
   @Override
@@ -286,7 +291,9 @@ public class UnsafeArrayData extends ArrayData {
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
     final int size = getElementSize(offset, ordinal);
-    return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+    final UnsafeMapData map = new UnsafeMapData();
+    map.pointTo(baseObject, baseOffset + offset, size);
+    return map;
   }
 
   @Override
@@ -328,7 +335,7 @@ public class UnsafeArrayData extends ArrayData {
     final byte[] arrayDataCopy = new byte[sizeInBytes];
     Platform.copyMemory(
       baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
-    arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
+    arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
     return arrayCopy;
   }
 }
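To make the new self-describing layout concrete, here is a minimal sketch that lays out the two-element array [7, 8] by hand and reads it back through the patched pointTo. The demo class and the literal offsets are illustrative only; Platform and UnsafeArrayData are the in-tree classes shown above.

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
import org.apache.spark.unsafe.Platform;

public class UnsafeArrayLayoutDemo {
  public static void main(String[] args) {
    // [numElements=2][offset0=12][offset1=16][7][8] -> 4 + 2 * (4 + 4) = 20 bytes.
    final byte[] buf = new byte[20];
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET, 2);       // numElements header
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 4, 12);  // element 0 starts at byte 12
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 8, 16);  // element 1 starts at byte 16
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 12, 7);  // value of element 0
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 16, 8);  // value of element 1

    final UnsafeArrayData array = new UnsafeArrayData();
    // numElements is no longer a parameter; pointTo reads it from the first 4 bytes.
    array.pointTo(buf, Platform.BYTE_ARRAY_OFFSET, 20);
    assert array.numElements() == 2;
    assert array.getInt(0) == 7 && array.getInt(1) == 8;
  }
}
```

Note that getSizeInBytes here is 20 = 4 + (4 + 4) * 2, exactly the formula the updated UnsafeRowConverterSuite asserts below.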
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index e9dab9edb6bd17cd7a4354878488e7ec605d1632..5bebe2a96e391e66ddabd38a1c118aba167c8cbe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,41 +17,73 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
+import java.nio.ByteBuffer;
+
 import org.apache.spark.sql.types.MapData;
+import org.apache.spark.unsafe.Platform;
 
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
- *
- * Note that when we write out this map, we should write out the `numElements` at first 4 bytes,
- * and numBytes of key array at second 4 bytes, then follows key array content and value array
- * content without `numElements` header.
- * When we read in a map, we should read first 4 bytes as `numElements` and second 4 bytes as
- * numBytes of key array, and construct unsafe key array and value array with these 2 information.
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with an extra 4 bytes at
+ * the head to indicate the number of bytes of the unsafe key array.
+ * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
  */
+// TODO: Use a more efficient format which doesn't depend on unsafe array.
 public class UnsafeMapData extends MapData {
 
-  private final UnsafeArrayData keys;
-  private final UnsafeArrayData values;
-  // The number of elements in this array
-  private int numElements;
-  // The size of this array's backing data, in bytes
+  private Object baseObject;
+  private long baseOffset;
+
+  // The size of this map's backing data, in bytes.
+  // The 4-byte header of key array `numBytes` is also included, so it's actually equal to
+  // 4 + key array numBytes + value array numBytes.
   private int sizeInBytes;
 
+  public Object getBaseObject() { return baseObject; }
+  public long getBaseOffset() { return baseOffset; }
   public int getSizeInBytes() { return sizeInBytes; }
 
-  public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+  private final UnsafeArrayData keys;
+  private final UnsafeArrayData values;
+
+  /**
+   * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be usable until
+   * `pointTo()` has been called, since the value returned by this constructor is equivalent
+   * to a null pointer.
+   */
+  public UnsafeMapData() {
+    keys = new UnsafeArrayData();
+    values = new UnsafeArrayData();
+  }
+
+  /**
+   * Update this UnsafeMapData to point to different backing data.
+   *
+   * @param baseObject the base object
+   * @param baseOffset the offset within the base object
+   * @param sizeInBytes the size of this map's backing data, in bytes
+   */
+  public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+    // Read the numBytes of key array from the first 4 bytes.
+    final int keyArraySize = Platform.getInt(baseObject, baseOffset);
+    final int valueArraySize = sizeInBytes - keyArraySize - 4;
+    assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
+    assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";
+
+    keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
+    values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
 
     assert keys.numElements() == values.numElements();
-    this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
-    this.numElements = keys.numElements();
-    this.keys = keys;
-    this.values = values;
+
+    this.baseObject = baseObject;
+    this.baseOffset = baseOffset;
+    this.sizeInBytes = sizeInBytes;
   }
 
   @Override
   public int numElements() {
-    return numElements;
+    return keys.numElements();
   }
 
   @Override
@@ -64,8 +96,26 @@ public class UnsafeMapData extends MapData {
     return values;
   }
 
+  public void writeToMemory(Object target, long targetOffset) {
+    Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
+  }
+
+  public void writeTo(ByteBuffer buffer) {
+    assert(buffer.hasArray());
+    byte[] target = buffer.array();
+    int offset = buffer.arrayOffset();
+    int pos = buffer.position();
+    writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+    buffer.position(pos + sizeInBytes);
+  }
+
   @Override
   public UnsafeMapData copy() {
-    return new UnsafeMapData(keys.copy(), values.copy());
+    UnsafeMapData mapCopy = new UnsafeMapData();
+    final byte[] mapDataCopy = new byte[sizeInBytes];
+    Platform.copyMemory(
+      baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+    mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+    return mapCopy;
  }
 }
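The same exercise one level up: a hand-assembled single-entry map {7 -> 8} in the new [key array numBytes] [unsafe key array] [unsafe value array] layout. Again, the demo class and constants are illustrative, not part of the patch.

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeMapData;
import org.apache.spark.unsafe.Platform;

public class UnsafeMapLayoutDemo {
  public static void main(String[] args) {
    // Map {7 -> 8}: [key array numBytes=12][key array][value array] = 4 + 12 + 12 = 28 bytes.
    final byte[] buf = new byte[28];
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET, 12);      // numBytes of the key array
    // Key array [7]: [numElements=1][offset0=8][7]
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 4, 1);
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 8, 8);
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 12, 7);
    // Value array [8]: same shape, starting right after the key array
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 16, 1);
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 20, 8);
    Platform.putInt(buf, Platform.BYTE_ARRAY_OFFSET + 24, 8);

    final UnsafeMapData map = new UnsafeMapData();
    map.pointTo(buf, Platform.BYTE_ARRAY_OFFSET, 28);
    assert map.numElements() == 1;
    assert map.keyArray().getInt(0) == 7 && map.valueArray().getInt(0) == 8;
  }
}
```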
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
deleted file mode 100644
index 6c5fcbca63fd79f08d38c86aef0cc4ba41f27f0b..0000000000000000000000000000000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import org.apache.spark.unsafe.Platform;
-
-public class UnsafeReaders {
-
-  /**
-   * Reads in unsafe array according to the format described in `UnsafeArrayData`.
-   */
-  public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
-    // Read the number of elements from first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
-    final UnsafeArrayData array = new UnsafeArrayData();
-    // Skip the first 4 bytes.
-    array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
-    return array;
-  }
-
-  /**
-   * Reads in unsafe map according to the format described in `UnsafeMapData`.
-   */
-  public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
-    // Read the number of elements from first 4 bytes.
-    final int numElements = Platform.getInt(baseObject, baseOffset);
-    // Read the numBytes of key array in second 4 bytes.
-    final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4);
-    final int valueArraySize = numBytes - 8 - keyArraySize;
-
-    final UnsafeArrayData keyArray = new UnsafeArrayData();
-    keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
-
-    final UnsafeArrayData valueArray = new UnsafeArrayData();
-    valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
-
-    return new UnsafeMapData(keyArray, valueArray);
-  }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 36859fbab97449efa36f65dc776de5d3609d72ef..366615f6fe69fc96425414637005dcdbb202cd23 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -461,7 +461,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
     final int size = (int) (offsetAndSize & ((1L << 32) - 1));
-      return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+      final UnsafeArrayData array = new UnsafeArrayData();
+      array.pointTo(baseObject, baseOffset + offset, size);
+      return array;
    }
  }
 
@@ -473,7 +475,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
     final int size = (int) (offsetAndSize & ((1L << 32) - 1));
-      return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+      final UnsafeMapData map = new UnsafeMapData();
+      map.pointTo(baseObject, baseOffset + offset, size);
+      return map;
    }
  }
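Both accessors decode the long that UnsafeRow stores per variable-length field: the offset in the upper 32 bits and the size in the lower 32. A standalone check of that encoding, with arbitrary example values:

```java
public class OffsetAndSizeDemo {
  public static void main(String[] args) {
    final int offset = 32, size = 20;  // hypothetical field offset/size within a row
    final long offsetAndSize = ((long) offset << 32) | (long) size;
    // Decode exactly the way UnsafeRow.getArray()/getMap() do above.
    assert (int) (offsetAndSize >> 32) == offset;
    assert (int) (offsetAndSize & ((1L << 32) - 1)) == size;
  }
}
```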
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 138178ce99d853d1a510286d6e7048011b980c61..7f2a1cb07af0179b50d20a1f84438b0d9e0ddd5d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,17 +30,19 @@ import org.apache.spark.unsafe.types.UTF8String;
 
 public class UnsafeArrayWriter {
 
   private BufferHolder holder;
+  // The offset of the global buffer where we start to write this array.
   private int startingOffset;
 
   public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
-    // We need 4 bytes each element to store offset.
-    final int fixedSize = 4 * numElements;
+    // We need 4 bytes to store numElements and 4 bytes each element to store offset.
+    final int fixedSize = 4 + 4 * numElements;
 
     this.holder = holder;
     this.startingOffset = holder.cursor;
 
     holder.grow(fixedSize);
+    Platform.putInt(holder.buffer, holder.cursor, numElements);
     holder.cursor += fixedSize;
 
     // Grows the global buffer ahead for fixed size data.
@@ -48,7 +50,7 @@ public class UnsafeArrayWriter {
   }
 
   private long getElementOffset(int ordinal) {
-    return startingOffset + 4 * ordinal;
+    return startingOffset + 4 + 4 * ordinal;
   }
 
   public void setNullAt(int ordinal) {
@@ -132,20 +134,4 @@ public class UnsafeArrayWriter {
     // move the cursor forward.
     holder.cursor += 16;
   }
-
-
-
-  // If this array is already an UnsafeArray, we don't need to go through all elements, we can
-  // directly write it.
-  public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
-    final int numBytes = input.getSizeInBytes();
-
-    // grow the global buffer before writing data.
-    holder.grow(numBytes);
-
-    // Writes the array content to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-
-    holder.cursor += numBytes;
-  }
 }
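Since the numElements header now lives inside the array bytes, initialize() reserves it up front together with the offsets region. A small worked sizing example (the demo class and numbers are illustrative):

```java
public class ArrayWriterSizingDemo {
  public static void main(String[] args) {
    final int numElements = 3;
    final int fixedElementSize = 4;             // e.g. int elements
    // What initialize() reserves: the numElements header plus one 4-byte offset per element.
    final int fixedSize = 4 + 4 * numElements;  // 16 bytes
    // Plus the values themselves, grown ahead for fixed-size data.
    final int total = fixedSize + numElements * fixedElementSize;
    assert total == 28;                         // 4 + (4 + 4) * 3, as the updated suite asserts
  }
}
```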
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 8b7debd440031be801cef3bfa6c493add33e70a1..e1f5a05d1d446f4b8cc332b0fd690df42f5bd629 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -181,19 +181,4 @@ public class UnsafeRowWriter {
     // move the cursor forward.
     holder.cursor += 16;
   }
-
-
-
-  // If this struct is already an UnsafeRow, we don't need to go through all fields, we can
-  // directly write it.
-  public static void directWrite(BufferHolder holder, UnsafeRow input) {
-    // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
-    final int numBytes = input.getSizeInBytes();
-    // grow the global buffer before writing data.
-    holder.grow(numBytes);
-    // Write the bytes to the variable length portion.
-    input.writeToMemory(holder.buffer, holder.cursor);
-    // move the cursor forward.
-    holder.cursor += numBytes;
-  }
 }
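Both deleted directWrite helpers are subsumed by the single writeUnsafeData snippet added to GenerateUnsafeProjection below. A hand-written Java equivalent of what that snippet emits (the method name here is ours; BufferHolder is the existing codegen buffer):

```java
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;

public class DirectWriteDemo {
  // Hand-written equivalent of the code writeUnsafeData() now generates for any input
  // that is already in unsafe format (UnsafeRow, UnsafeArrayData, or UnsafeMapData).
  static void writeUnsafe(BufferHolder holder, UnsafeRow input) {
    final int sizeInBytes = input.getSizeInBytes();
    holder.grow(sizeInBytes);                           // make room in the shared buffer
    input.writeToMemory(holder.buffer, holder.cursor);  // one bulk copy, no per-field loop
    holder.cursor += sizeInBytes;
  }
}
```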
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1b957a508d10e2db49e6c2e22248fadea6285235..dbe92d6a8350200a3ad500e06a783a1f14d95cd6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -62,7 +62,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     s"""
       if ($input instanceof UnsafeRow) {
-        $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+        ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
       } else {
         ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
       }
@@ -164,8 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       ctx: CodeGenContext,
       input: String,
       elementType: DataType,
-      bufferHolder: String,
-      needHeader: Boolean = true): String = {
+      bufferHolder: String): String = {
     val arrayWriter = ctx.freshName("arrayWriter")
     ctx.addMutableState(arrayWriterClass, arrayWriter,
       s"this.$arrayWriter = new $arrayWriterClass();")
@@ -227,21 +226,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
       case _ => s"$arrayWriter.write($index, $element);"
     }
 
-    val writeHeader = if (needHeader) {
-      // If header is required, we need to write the number of elements into first 4 bytes.
-      s"""
-        $bufferHolder.grow(4);
-        Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
-        $bufferHolder.cursor += 4;
-      """
-    } else ""
-
     s"""
-      final int $numElements = $input.numElements();
-      $writeHeader
       if ($input instanceof UnsafeArrayData) {
-        $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+        ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
       } else {
+        final int $numElements = $input.numElements();
         $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
 
         for (int $index = 0; $index < $numElements; $index++) {
@@ -270,23 +259,40 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
     // Writes out unsafe map according to the format described in `UnsafeMapData`.
     s"""
-      final ArrayData $keys = $input.keyArray();
-      final ArrayData $values = $input.valueArray();
+      if ($input instanceof UnsafeMapData) {
+        ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+      } else {
+        final ArrayData $keys = $input.keyArray();
+        final ArrayData $values = $input.valueArray();
 
-      $bufferHolder.grow(8);
+        // Reserve 4 bytes to write the key array numBytes later.
+        $bufferHolder.grow(4);
+        $bufferHolder.cursor += 4;
 
-      // Write the numElements into first 4 bytes.
-      Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+        // Remember the current cursor so that we can write numBytes of key array later.
+        final int $tmpCursor = $bufferHolder.cursor;
 
-      $bufferHolder.cursor += 8;
-      // Remember the current cursor so that we can write numBytes of key array later.
-      final int $tmpCursor = $bufferHolder.cursor;
+        ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+        // Write the numBytes of key array into the first 4 bytes.
+        Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
 
-      ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
-      // Write the numBytes of key array into second 4 bytes.
-      Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+        ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+      }
+    """
+  }
 
-      ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+  /**
+   * If the input is already in unsafe format, we don't need to go through all elements/fields,
+   * we can directly write it.
+   */
+  private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = {
+    val sizeInBytes = ctx.freshName("sizeInBytes")
+    s"""
+      final int $sizeInBytes = $input.getSizeInBytes();
+      // grow the global buffer before writing data.
+      $bufferHolder.grow($sizeInBytes);
+      $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
+      $bufferHolder.cursor += $sizeInBytes;
     """
   }
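The notable move in the new map branch is the back-patch: reserve 4 bytes, write the key array, then fill the reserved slot once the key array's size is known. A sketch of that control flow (ArrayWriter is a hypothetical stand-in for the code writeArrayToBuffer generates; the class and method names are ours):

```java
import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
import org.apache.spark.unsafe.Platform;

public class MapWriteShapeDemo {
  /** Stand-in for generated array-writing code; assumed to advance holder.cursor. */
  interface ArrayWriter { void write(BufferHolder holder); }

  // Hand-written equivalent of the Java emitted for a non-unsafe MapData input.
  static void writeMap(BufferHolder holder, ArrayWriter keys, ArrayWriter values) {
    holder.grow(4);            // reserve 4 bytes for the key array numBytes
    holder.cursor += 4;
    final int tmpCursor = holder.cursor;
    keys.write(holder);        // key array, with its own numElements header
    // Back-patch the reserved slot now that the key array's size is known.
    Platform.putInt(holder.buffer, tmpCursor - 4, holder.cursor - tmpCursor);
    values.write(holder);      // value array follows immediately
  }
}
```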
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c991cd86d28c8ac0070c3017e5893b04ff8de2b3..c6aad34e972b598732db8104497c3f18e2d985a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -296,13 +296,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
   }
 
-  private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
-
-  private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
-
   private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
     assert(array.numElements == values.length)
-    assert(array.getSizeInBytes == (4 + 4) * values.length)
+    assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
     values.zipWithIndex.foreach {
       case (value, index) => assert(array.getInt(index) == value)
     }
@@ -315,7 +311,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     testArrayInt(map.keyArray, keys)
     testArrayInt(map.valueArray, values)
 
-    assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+    assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
   }
 
   test("basic conversion with array type") {
@@ -341,10 +337,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val nestedArray = unsafeArray2.getArray(0)
     testArrayInt(nestedArray, Seq(3, 4))
 
-    assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
+    assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
 
-    val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
-    val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
+    val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
+    val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
@@ -384,13 +380,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
       val nestedMap = valueArray.getMap(0)
       testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
-      assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
+      assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
     }
 
-    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
-    val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
-    val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+    val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
+    val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
   }
 
@@ -414,7 +410,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerArray = field1.getArray(0)
     testArrayInt(innerArray, Seq(1))
 
-    assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+    assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerArray.getSizeInBytes))
 
     val field2 = unsafeRow.getArray(1)
     assert(field2.numElements == 1)
@@ -427,10 +423,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(innerStruct.getLong(0) == 2L)
     }
 
-    assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
   test("basic conversion with struct and map") {
@@ -453,7 +449,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerMap = field1.getMap(0)
     testMapInt(innerMap, Seq(1), Seq(2))
 
-    assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
+    assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerMap.getSizeInBytes))
 
     val field2 = unsafeRow.getMap(1)
 
@@ -470,13 +466,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(innerStruct.getSizeInBytes == 8 + 8)
       assert(innerStruct.getLong(0) == 4L)
 
-      assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+      assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
     }
 
-    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
   }
 
   test("basic conversion with array and map") {
@@ -499,7 +495,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     val innerMap = field1.getMap(0)
     testMapInt(innerMap, Seq(1), Seq(2))
 
-    assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+    assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
 
     val field2 = unsafeRow.getMap(1)
     assert(field2.numElements == 1)
@@ -518,9 +514,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
     }
 
-    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
 
     assert(unsafeRow.getSizeInBytes ==
-      8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
+      8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
   }
 }
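The arraySizeInRow/mapSizeInRow helpers disappear because getSizeInBytes now includes all headers; embedding a nested array or map in a row only requires word alignment. Roughly what the suite's roundedSize helper computes (our assumed round-to-8 reimplementation, not the in-tree code):

```java
public class RoundedSizeDemo {
  // Assumed equivalent of the suite's roundedSize(): round up to a multiple of 8 (one word).
  static int roundedSize(int size) {
    return ((size + 7) / 8) * 8;
  }

  public static void main(String[] args) {
    final int arrayOfThreeInts = 4 + (4 + 4) * 3;  // 28 bytes, headers included
    assert roundedSize(arrayOfThreeInts) == 32;    // what the enclosing row accounts for
  }
}
```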
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2bc2c96b61634f630f34b9931ec7e6725dd62432..a41f04dd3b59a9d8849810e37dcee919c4cabbd5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -482,12 +482,14 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
   override def extract(buffer: ByteBuffer): UnsafeRow = {
     val sizeInBytes = buffer.getInt()
     assert(buffer.hasArray)
-    val base = buffer.array()
-    val offset = buffer.arrayOffset()
     val cursor = buffer.position()
     buffer.position(cursor + sizeInBytes)
     val unsafeRow = new UnsafeRow
-    unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
+    unsafeRow.pointTo(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numOfFields,
+      sizeInBytes)
     unsafeRow
   }
 
@@ -508,12 +510,11 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     val unsafeArray = getField(row, ordinal)
-    4 + 4 + unsafeArray.getSizeInBytes
+    4 + unsafeArray.getSizeInBytes
   }
 
   override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
-    buffer.putInt(4 + value.getSizeInBytes)
-    buffer.putInt(value.numElements())
+    buffer.putInt(value.getSizeInBytes)
     value.writeTo(buffer)
   }
 
@@ -522,10 +523,12 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
-    UnsafeReaders.readArray(
+    val array = new UnsafeArrayData
+    array.pointTo(
       buffer.array(),
       Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
       numBytes)
+    array
   }
 
   override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
@@ -545,15 +548,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
 
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     val unsafeMap = getField(row, ordinal)
-    12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+    4 + unsafeMap.getSizeInBytes
   }
 
   override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
-    buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
-    buffer.putInt(value.numElements())
-    buffer.putInt(value.keyArray().getSizeInBytes)
-    value.keyArray().writeTo(buffer)
-    value.valueArray().writeTo(buffer)
+    buffer.putInt(value.getSizeInBytes)
+    value.writeTo(buffer)
   }
 
   override def extract(buffer: ByteBuffer): UnsafeMapData = {
@@ -561,10 +561,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
     assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
-    UnsafeReaders.readMap(
+    val map = new UnsafeMapData
+    map.pointTo(
      buffer.array(),
      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
      numBytes)
+    map
   }
 
   override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
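With both changes in place, the columnar protocol for arrays and maps collapses to a length-prefixed blob. A condensed, standalone rearrangement of the append/extract pair above, for arrays (the demo class is ours; the calls are the same ones the patch uses):

```java
import java.nio.ByteBuffer;
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
import org.apache.spark.unsafe.Platform;

public class ArrayColumnRoundTripDemo {
  // Serialization: a 4-byte length prefix, then the array's raw bytes,
  // which now carry their own numElements header.
  static void append(UnsafeArrayData value, ByteBuffer buffer) {
    buffer.putInt(value.getSizeInBytes());
    value.writeTo(buffer);
  }

  static UnsafeArrayData extract(ByteBuffer buffer) {
    final int numBytes = buffer.getInt();
    assert buffer.hasArray();
    final int cursor = buffer.position();
    buffer.position(cursor + numBytes);
    final UnsafeArrayData array = new UnsafeArrayData();
    array.pointTo(
      buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, numBytes);
    return array;
  }
}
```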
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 0e6e1bcf72896e71e6da6ce84891bc993375b635..63bc39bfa0307fad3c0ade6878fe0199a208401d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -73,7 +73,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
     checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
     checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
     checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
-    checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
+    checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
     checkActualSize(STRUCT_TYPE, Row("hello"), 28)
   }
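The 25 to 29 bump is exactly the format change: each of the two backing arrays gains its own 4-byte numElements header (+8) while the map drops its separate numElements field (-4). The arithmetic for Map(1 -> "a"), spelled out:

```java
public class MapActualSizeDemo {
  public static void main(String[] args) {
    final int keyArray = 4 + 4 + 4;    // numElements + offset + int key          = 12
    final int valueArray = 4 + 4 + 1;  // numElements + offset + "a" (1 byte)     = 9
    final int mapSize = 4 + keyArray + valueArray;   // key array numBytes header = 25
    final int actualSize = 4 + mapSize;              // column's 4-byte length prefix
    assert actualSize == 29;
  }
}
```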