diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 501dff090313cef44b0b3a2d15dda5d4a607e56e..da9538b3f13cc2f4d448ac16447e3058afe0e634 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 
-import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -256,7 +255,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public InternalRow getStruct(int ordinal, int numFields) {
+  public UnsafeRow getStruct(int ordinal, int numFields) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
@@ -267,7 +266,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public ArrayData getArray(int ordinal) {
+  public UnsafeArrayData getArray(int ordinal) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
@@ -276,7 +275,7 @@ public class UnsafeArrayData extends ArrayData {
   }
 
   @Override
-  public MapData getMap(int ordinal) {
+  public UnsafeMapData getMap(int ordinal) {
     assertIndexIsValid(ordinal);
     final int offset = getElementOffset(ordinal);
     if (offset < 0) return null;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 46216054ab38bee3de846632b8bfa27bf14b3240..e9dab9edb6bd17cd7a4354878488e7ec605d1632 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,18 +17,23 @@
 
 package org.apache.spark.sql.catalyst.expressions;
 
-import org.apache.spark.sql.types.ArrayData;
 import org.apache.spark.sql.types.MapData;
 
 /**
  * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
  *
  * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
+ *
+ * Note that when we write out this map, we should write out the `numElements` in the first 4
+ * bytes and the numBytes of the key array in the next 4 bytes, followed by the key array
+ * content and the value array content, each without its own `numElements` header.
+ * When we read in a map, we should read the first 4 bytes as `numElements` and the next 4 bytes
+ * as the numBytes of the key array, and construct the unsafe key and value arrays from them.
 */
public class UnsafeMapData extends MapData {

-  public final UnsafeArrayData keys;
-  public final UnsafeArrayData values;
+  private final UnsafeArrayData keys;
+  private final UnsafeArrayData values;
   // The number of elements in this array
   private int numElements;
   // The size of this array's backing data, in bytes
@@ -50,12 +55,12 @@ public class UnsafeMapData extends MapData {
   }
 
   @Override
-  public ArrayData keyArray() {
+  public UnsafeArrayData keyArray() {
     return keys;
   }
 
   @Override
-  public ArrayData valueArray() {
+  public UnsafeArrayData valueArray() {
     return values;
   }
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
index 7b03185a30e3c48bdca66291ae658c8545edae7c..6c5fcbca63fd79f08d38c86aef0cc4ba41f27f0b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
@@ -21,6 +21,9 @@ import org.apache.spark.unsafe.Platform;
 
 public class UnsafeReaders {
 
+  /**
+   * Reads in an unsafe array according to the format described in `UnsafeArrayData`.
+   */
   public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
     // Read the number of elements from first 4 bytes.
     final int numElements = Platform.getInt(baseObject, baseOffset);
@@ -30,6 +33,9 @@ public class UnsafeReaders {
     return array;
   }
 
+  /**
+   * Reads in an unsafe map according to the format described in `UnsafeMapData`.
+   */
   public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
     // Read the number of elements from first 4 bytes.
     final int numElements = Platform.getInt(baseObject, baseOffset);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c020045c311ac343bfa1bd1f33c098c25ec6a31..e8ac2999c2d29569507ccdb6e66966c66e70e91c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -446,7 +446,7 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public ArrayData getArray(int ordinal) {
+  public UnsafeArrayData getArray(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
     } else {
@@ -458,7 +458,7 @@ public final class UnsafeRow extends MutableRow {
   }
 
   @Override
-  public MapData getMap(int ordinal) {
+  public UnsafeMapData getMap(int ordinal) {
     if (isNullAt(ordinal)) {
       return null;
     } else {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 2f43db68a750e757686eae249d6b4e18b422d791..0f1e0202aace15a303e83040b747ad7b8ef00c91 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -233,8 +233,8 @@ public class UnsafeRowWriters {
     public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
       final long offset = target.getBaseOffset() + cursor;
 
-      final UnsafeArrayData keyArray = input.keys;
-      final UnsafeArrayData valueArray = input.values;
+      final UnsafeArrayData keyArray = input.keyArray();
+      final UnsafeArrayData valueArray = input.valueArray();
       final int keysNumBytes = keyArray.getSizeInBytes();
       final int valuesNumBytes = valueArray.getSizeInBytes();
       final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
index cd83695fca033566f7eb26fcf170fbe0eec9500f..ce2d9c4ffbf9f1860db507fffccc4bc985c29dd1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
@@ -168,8 +168,8 @@ public class UnsafeWriters {
   }
 
   public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
-    final UnsafeArrayData keyArray = input.keys;
-    final UnsafeArrayData valueArray = input.values;
+    final UnsafeArrayData keyArray = input.keyArray();
+    final UnsafeArrayData valueArray = input.valueArray();
     final int keysNumBytes = keyArray.getSizeInBytes();
     final int valuesNumBytes = valueArray.getSizeInBytes();
     final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
new file mode 100644
index 0000000000000000000000000000000000000000..9c9468678065d43ca299ba39ec767c9931d03d0e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A helper class to manage the row buffer used in `GenerateUnsafeProjection`.
+ *
+ * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables
+ * public for ease of use.
+ */
+public class BufferHolder {
+  public byte[] buffer = new byte[64];
+  public int cursor = Platform.BYTE_ARRAY_OFFSET;
+
+  public void grow(int neededSize) {
+    final int length = totalSize() + neededSize;
+    if (buffer.length < length) {
+      // This will not happen frequently, because the buffer is re-used.
+      final byte[] tmp = new byte[length * 2];
+      Platform.copyMemory(
+        buffer,
+        Platform.BYTE_ARRAY_OFFSET,
+        tmp,
+        Platform.BYTE_ARRAY_OFFSET,
+        totalSize());
+      buffer = tmp;
+    }
+  }
+
+  public void reset() {
+    cursor = Platform.BYTE_ARRAY_OFFSET;
+  }
+
+  public int totalSize() {
+    return cursor - Platform.BYTE_ARRAY_OFFSET;
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..138178ce99d853d1a510286d6e7048011b980c61
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeArrayData` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeArrayWriter {
+
+  private BufferHolder holder;
+  // The offset of the global buffer where we start to write this array.
+  private int startingOffset;
+
+  public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
+    // We need 4 bytes per element to store its offset.
+    final int fixedSize = 4 * numElements;
+
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    // Grows the global buffer ahead for fixed size data.
+    holder.grow(fixedElementSize * numElements);
+  }
+
+  private long getElementOffset(int ordinal) {
+    return startingOffset + 4 * ordinal;
+  }
+
+  public void setNullAt(int ordinal) {
+    final int relativeOffset = holder.cursor - startingOffset;
+    // Writes a negative offset value to represent a null element.
+    Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
+  }
+
+  public void setOffset(int ordinal) {
+    final int relativeOffset = holder.cursor - startingOffset;
+    Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
+  }
+
+  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+      setOffset(ordinal);
+      holder.cursor += 8;
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      assert bytes.length <= 16;
+      holder.grow(bytes.length);
+
+      // Write the bytes to the variable length portion.
+      Platform.copyMemory(
+        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+      setOffset(ordinal);
+      holder.cursor += bytes.length;
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += numBytes;
+  }
+
+  public void write(int ordinal, byte[] input) {
+    // grow the global buffer before writing data.
+    holder.grow(input.length);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(
+      input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += input.length;
+  }
+
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffset(ordinal);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+  // If this array is already an UnsafeArray, we don't need to go through all elements; we can
+  // directly write it.
+  public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
+    final int numBytes = input.getSizeInBytes();
+
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+
+    // Writes the array content to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    holder.cursor += numBytes;
+  }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
new file mode 100644
index 0000000000000000000000000000000000000000..8b7debd440031be801cef3bfa6c493add33e70a1
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeRow` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriter {
+
+  private BufferHolder holder;
+  // The offset of the global buffer where we start to write this row.
+  private int startingOffset;
+  private int nullBitsSize;
+
+  public void initialize(BufferHolder holder, int numFields) {
+    this.holder = holder;
+    this.startingOffset = holder.cursor;
+    this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+
+    // grow the global buffer to make sure it has enough space to write fixed-length data.
+    final int fixedSize = nullBitsSize + 8 * numFields;
+    holder.grow(fixedSize);
+    holder.cursor += fixedSize;
+
+    // zero-out the null bits region
+    for (int i = 0; i < nullBitsSize; i += 8) {
+      Platform.putLong(holder.buffer, startingOffset + i, 0L);
+    }
+  }
+
+  private void zeroOutPaddingBytes(int numBytes) {
+    if ((numBytes & 0x07) > 0) {
+      Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+    }
+  }
+
+  public void setNullAt(int ordinal) {
+    BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+    Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+  }
+
+  public long getFieldOffset(int ordinal) {
+    return startingOffset + nullBitsSize + 8 * ordinal;
+  }
+
+  public void setOffsetAndSize(int ordinal, long size) {
+    setOffsetAndSize(ordinal, holder.cursor, size);
+  }
+
+  public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+    final long relativeOffset = currentCursor - startingOffset;
+    final long fieldOffset = getFieldOffset(ordinal);
+    final long offsetAndSize = (relativeOffset << 32) | size;
+
+    Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+  }
+
+  // Do word alignment for this row and grow the row buffer if needed.
+  // todo: remove this after we make unsafe array data word-aligned.
+  public void alignToWords(int numBytes) {
+    final int remainder = numBytes & 0x07;
+
+    if (remainder > 0) {
+      final int paddingBytes = 8 - remainder;
+      holder.grow(paddingBytes);
+
+      for (int i = 0; i < paddingBytes; i++) {
+        Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
+        holder.cursor++;
+      }
+    }
+  }
+
+  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+    // make sure Decimal object has the same scale as DecimalType
+    if (input.changePrecision(precision, scale)) {
+      Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+    } else {
+      setNullAt(ordinal);
+    }
+  }
+
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // zero-out the bytes
+    Platform.putLong(holder.buffer, holder.cursor, 0L);
+    Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+    // Make sure Decimal object has the same scale as DecimalType.
+    // Note that we may pass in null Decimal object to set null for it.
+    if (input == null || !input.changePrecision(precision, scale)) {
+      BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+      // keep the offset for future update
+      setOffsetAndSize(ordinal, 0L);
+    } else {
+      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+      assert bytes.length <= 16;
+
+      // Write the bytes to the variable length portion.
+      Platform.copyMemory(
+        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+      setOffsetAndSize(ordinal, bytes.length);
+    }
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+  public void write(int ordinal, UTF8String input) {
+    final int numBytes = input.numBytes();
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+
+  public void write(int ordinal, byte[] input) {
+    final int numBytes = input.length;
+    final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+    // grow the global buffer before writing data.
+    holder.grow(roundedSize);
+
+    zeroOutPaddingBytes(numBytes);
+
+    // Write the bytes to the variable length portion.
+    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+
+    setOffsetAndSize(ordinal, numBytes);
+
+    // move the cursor forward.
+    holder.cursor += roundedSize;
+  }
+
+  public void write(int ordinal, CalendarInterval input) {
+    // grow the global buffer before writing data.
+    holder.grow(16);
+
+    // Write the months and microseconds fields of Interval to the variable length portion.
+    Platform.putLong(holder.buffer, holder.cursor, input.months);
+    Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+    setOffsetAndSize(ordinal, 16);
+
+    // move the cursor forward.
+    holder.cursor += 16;
+  }
+
+  // If this struct is already an UnsafeRow, we don't need to go through all fields; we can
+  // directly write it.
+  public static void directWrite(BufferHolder holder, UnsafeRow input) {
+    // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
+    final int numBytes = input.getSizeInBytes();
+    // grow the global buffer before writing data.
+    holder.grow(numBytes);
+    // Write the bytes to the variable length portion.
+    input.writeToMemory(holder.buffer, holder.cursor);
+    // move the cursor forward.
+    holder.cursor += numBytes;
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index da3103b4ebb6b7242f73392b898030d7f04f8dae..9a2878113304cdd7ea88461858c28a8c409d61d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -272,6 +272,7 @@ class CodeGenContext {
    * 64kb code size limit in JVM
    *
    * @param row the variable name of row that is used by expressions
+   * @param expressions the code snippets that evaluate the expressions.
    */
   def splitExpressions(row: String, expressions: Seq[String]): String = {
     val blocks = new ArrayBuffer[String]()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 55562facf9652fcf40ca4b11e30780a1a2333dea..99bf50a84571b727815828d5900c12a71f02bf5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -393,10 +393,292 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case _ => input
   }
 
+  private val rowWriterClass = classOf[UnsafeRowWriter].getName
+  private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+
+  // TODO: if the nullability of field is correct, we can use it to save null check.
+  private def writeStructToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      fieldTypes: Seq[DataType],
+      bufferHolder: String): String = {
+    val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
+      val fieldName = ctx.freshName("fieldName")
+      val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
+      val isNull = s"$input.isNullAt($i)"
+      GeneratedExpressionCode(code, isNull, fieldName)
+    }
+
+    s"""
+      if ($input instanceof UnsafeRow) {
+        $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+      } else {
+        ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
+      }
+    """
+  }
+
+  private def writeExpressionsToBuffer(
+      ctx: CodeGenContext,
+      row: String,
+      inputs: Seq[GeneratedExpressionCode],
+      inputTypes: Seq[DataType],
+      bufferHolder: String): String = {
+    val rowWriter = ctx.freshName("rowWriter")
+    ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
+
+    val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
+      case ((input, dt), index) =>
+        val tmpCursor = ctx.freshName("tmpCursor")
+
+        val setNull = dt match {
+          case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+            // Can't call setNullAt() for DecimalType with precision larger than 18.
+            s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});"
+          case _ => s"$rowWriter.setNullAt($index);"
+        }
+
+        val writeField = dt match {
+          case t: StructType =>
+            s"""
+              // Remember the current cursor so that we can calculate how many bytes are
+              // written later.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeStructToBuffer(ctx, input.primitive, t.map(_.dataType), bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+            """
+
+          case a @ ArrayType(et, _) =>
+            s"""
+              // Remember the current cursor so that we can calculate how many bytes are
+              // written later.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeArrayToBuffer(ctx, input.primitive, et, bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+              $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+            """
+
+          case m @ MapType(kt, vt, _) =>
+            s"""
+              // Remember the current cursor so that we can calculate how many bytes are
+              // written later.
+              final int $tmpCursor = $bufferHolder.cursor;
+              ${writeMapToBuffer(ctx, input.primitive, kt, vt, bufferHolder)}
+              $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+              $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+            """
+
+          case _ if ctx.isPrimitiveType(dt) =>
+            val fieldOffset = ctx.freshName("fieldOffset")
+            s"""
+              final long $fieldOffset = $rowWriter.getFieldOffset($index);
+              Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L);
+              ${writePrimitiveType(ctx, input.primitive, dt, s"$bufferHolder.buffer", fieldOffset)}
+            """
+
+          case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+            s"$rowWriter.writeCompactDecimal($index, ${input.primitive}, " +
+              s"${t.precision}, ${t.scale});"
+
+          case t: DecimalType =>
+            s"$rowWriter.write($index, ${input.primitive}, ${t.precision}, ${t.scale});"
+
+          case NullType => ""
+
+          case _ => s"$rowWriter.write($index, ${input.primitive});"
+        }
+
+        s"""
+          ${input.code}
+          if (${input.isNull}) {
+            $setNull
+          } else {
+            $writeField
+          }
+        """
+    }
+
+    s"""
+      $rowWriter.initialize($bufferHolder, ${inputs.length});
+      ${ctx.splitExpressions(row, writeFields)}
+    """
+  }
+
+  // TODO: if the nullability of array element is correct, we can use it to save null check.
+  private def writeArrayToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      elementType: DataType,
+      bufferHolder: String,
+      needHeader: Boolean = true): String = {
+    val arrayWriter = ctx.freshName("arrayWriter")
+    ctx.addMutableState(arrayWriterClass, arrayWriter,
+      s"this.$arrayWriter = new $arrayWriterClass();")
+    val numElements = ctx.freshName("numElements")
+    val index = ctx.freshName("index")
+    val element = ctx.freshName("element")
+
+    val jt = ctx.javaType(elementType)
+
+    val fixedElementSize = elementType match {
+      case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
+      case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+      case _ => 0
+    }
+
+    val writeElement = elementType match {
+      case t: StructType =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
+        """
+
+      case a @ ArrayType(et, _) =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
+        """
+
+      case m @ MapType(kt, vt, _) =>
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
+        """
+
+      case _ if ctx.isPrimitiveType(elementType) =>
+        // Should we do word align?
+        val dataSize = elementType.defaultSize
+
+        s"""
+          $arrayWriter.setOffset($index);
+          ${writePrimitiveType(ctx, element, elementType,
+            s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
+          $bufferHolder.cursor += $dataSize;
+        """
+
+      case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+        s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});"
+
+      case t: DecimalType =>
+        s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
+
+      case NullType => ""
+
+      case _ => s"$arrayWriter.write($index, $element);"
+    }
+
+    val writeHeader = if (needHeader) {
+      // If a header is required, we need to write the number of elements into the first 4 bytes.
+      s"""
+        $bufferHolder.grow(4);
+        Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
+        $bufferHolder.cursor += 4;
+      """
+    } else ""
+
+    s"""
+      final int $numElements = $input.numElements();
+      $writeHeader
+      if ($input instanceof UnsafeArrayData) {
+        $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+      } else {
+        $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
+
+        for (int $index = 0; $index < $numElements; $index++) {
+          if ($input.isNullAt($index)) {
+            $arrayWriter.setNullAt($index);
+          } else {
+            final $jt $element = ${ctx.getValue(input, elementType, index)};
+            $writeElement
+          }
+        }
+      }
+    """
+  }
+
+  // TODO: if the nullability of value element is correct, we can use it to save null check.
+  private def writeMapToBuffer(
+      ctx: CodeGenContext,
+      input: String,
+      keyType: DataType,
+      valueType: DataType,
+      bufferHolder: String): String = {
+    val keys = ctx.freshName("keys")
+    val values = ctx.freshName("values")
+    val tmpCursor = ctx.freshName("tmpCursor")
+
+    // Writes out an unsafe map according to the format described in `UnsafeMapData`.
+    s"""
+      final ArrayData $keys = $input.keyArray();
+      final ArrayData $values = $input.valueArray();
+
+      $bufferHolder.grow(8);
+
+      // Write the numElements into the first 4 bytes.
+      Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+
+      $bufferHolder.cursor += 8;
+      // Remember the current cursor so that we can write the numBytes of the key array later.
+      final int $tmpCursor = $bufferHolder.cursor;
+
+      ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
+      // Write the numBytes of the key array into the second 4 bytes.
+      Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+
+      ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+    """
+  }
+
+  private def writePrimitiveType(
+      ctx: CodeGenContext,
+      input: String,
+      dt: DataType,
+      buffer: String,
+      offset: String) = {
+    assert(ctx.isPrimitiveType(dt))
+
+    val putMethod = s"put${ctx.primitiveTypeName(dt)}"
+
+    dt match {
+      case FloatType | DoubleType =>
+        val normalized = ctx.freshName("normalized")
+        val boxedType = ctx.boxedType(dt)
+        val handleNaN =
+          s"""
+            final ${ctx.javaType(dt)} $normalized;
+            if ($boxedType.isNaN($input)) {
+              $normalized = $boxedType.NaN;
+            } else {
+              $normalized = $input;
+            }
+          """
+
+        s"""
+          $handleNaN
+          Platform.$putMethod($buffer, $offset, $normalized);
+        """
+      case _ => s"Platform.$putMethod($buffer, $offset, $input);"
+    }
+  }
+
   def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
     val exprEvals = expressions.map(e => e.gen(ctx))
     val exprTypes = expressions.map(_.dataType)
-    createCodeForStruct(ctx, "i", exprEvals, exprTypes)
+
+    val result = ctx.freshName("result")
+    ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();")
+    val bufferHolder = ctx.freshName("bufferHolder")
+    val holderClass = classOf[BufferHolder].getName
+    ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+
+    val code =
+      s"""
+        $bufferHolder.reset();
+        ${writeExpressionsToBuffer(ctx, "i", exprEvals, exprTypes, bufferHolder)}
+        $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
+      """
+    GeneratedExpressionCode(code, "false", result)
   }
 
   protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 8c7220319363088017a696206b5986ab3c5a9c01..c991cd86d28c8ac0070c3017e5893b04ff8de2b3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
-import java.util.Arrays
 
 import org.scalatest.Matchers
 
@@ -43,7 +42,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     row.setInt(2, 2)
 
     val unsafeRow: UnsafeRow = converter.apply(row)
-    assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8))
+    assert(unsafeRow.getSizeInBytes === 8 + (3 * 8))
     assert(unsafeRow.getLong(0) === 0)
     assert(unsafeRow.getLong(1) === 1)
     assert(unsafeRow.getInt(2) === 2)
@@ -62,6 +61,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(unsafeRowCopy.getLong(0) === 0)
     assert(unsafeRowCopy.getLong(1) === 1)
     assert(unsafeRowCopy.getInt(2) === 2)
+
+    // Make sure the converter can be reused, i.e. we correctly reset all states.
+    val unsafeRow2: UnsafeRow = converter.apply(row)
+    assert(unsafeRow2.getSizeInBytes === 8 + (3 * 8))
+    assert(unsafeRow2.getLong(0) === 0)
+    assert(unsafeRow2.getLong(1) === 1)
+    assert(unsafeRow2.getInt(2) === 2)
   }
 
   test("basic conversion with primitive, string and binary types") {
@@ -176,7 +182,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       r
     }
 
-    // todo: we reuse the UnsafeRow in projection, so these tests are meaningless.
     val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
     assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
     assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
@@ -192,7 +197,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
       rowWithNoNullColumns.getDecimal(10, 10, 0))
     assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
       rowWithNoNullColumns.getDecimal(11, 38, 18))
-    // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
 
     for (i <- fieldTypes.indices) {
       // Cann't call setNullAt() on DecimalType
@@ -202,8 +206,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
         setToNullAfterCreation.setNullAt(i)
       }
     }
-    // There are some garbage left in the var-length area
-    assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
 
     setToNullAfterCreation.setNullAt(0)
     setToNullAfterCreation.setBoolean(1, false)
@@ -251,107 +253,274 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
   }
 
+  test("basic conversion with struct type") {
+    val fieldTypes: Array[DataType] = Array(
+      new StructType().add("i", IntegerType),
+      new StructType().add("nest", new StructType().add("l", LongType))
+    )
+
+    val converter = UnsafeProjection.create(fieldTypes)
+
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, InternalRow(1))
+    row.update(1, InternalRow(InternalRow(2L)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields == 2)
+
+    val row1 = unsafeRow.getStruct(0, 1)
+    assert(row1.getSizeInBytes == 8 + 1 * 8)
+    assert(row1.numFields == 1)
+    assert(row1.getInt(0) == 1)
+
+    val row2 = unsafeRow.getStruct(1, 1)
+    assert(row2.numFields() == 1)
+
+    val innerRow = row2.getStruct(0, 1)
+
+    {
+      assert(innerRow.getSizeInBytes == 8 + 1 * 8)
+      assert(innerRow.numFields == 1)
+      assert(innerRow.getLong(0) == 2L)
+    }
+
+    assert(row2.getSizeInBytes == 8 + 1 * 8 + innerRow.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes == 8 + 2 * 8 + row1.getSizeInBytes + row2.getSizeInBytes)
+  }
+
+  private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+
+  private def createMap(keys: Any*)(values: Any*): MapData = {
+    assert(keys.length == values.length)
+    new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
+  }
+
+  private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
+
+  private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
+
+  private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
+    assert(array.numElements == values.length)
+    assert(array.getSizeInBytes == (4 + 4) * values.length)
+    values.zipWithIndex.foreach {
+      case (value, index) => assert(array.getInt(index) == value)
+    }
+  }
+
+  private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values: Seq[Int]): Unit = {
+    assert(keys.length == values.length)
+    assert(map.numElements == keys.length)
+
+    testArrayInt(map.keyArray, keys)
+    testArrayInt(map.valueArray, values)
+
+    assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+  }
+
   test("basic conversion with array type") {
     val fieldTypes: Array[DataType] = Array(
-      ArrayType(LongType),
-      ArrayType(ArrayType(LongType))
+      ArrayType(IntegerType),
+      ArrayType(ArrayType(IntegerType))
     )
     val converter = UnsafeProjection.create(fieldTypes)
 
-    val array1 = new GenericArrayData(Array[Any](1L, 2L))
-    val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L))))
 
     val row = new GenericMutableRow(fieldTypes.length)
-    row.update(0, array1)
-    row.update(1, array2)
+    row.update(0, createArray(1, 2))
+    row.update(1, createArray(createArray(3, 4)))
 
     val unsafeRow: UnsafeRow = converter.apply(row)
     assert(unsafeRow.numFields() == 2)
 
-    val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
-    assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2)
-    assert(unsafeArray1.numElements() == 2)
-    assert(unsafeArray1.getLong(0) == 1L)
-    assert(unsafeArray1.getLong(1) == 2L)
+    val unsafeArray1 = unsafeRow.getArray(0)
+    testArrayInt(unsafeArray1, Seq(1, 2))
 
-    val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData]
-    assert(unsafeArray2.numElements() == 1)
+    val unsafeArray2 = unsafeRow.getArray(1)
+    assert(unsafeArray2.numElements == 1)
 
-    val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData]
-    assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2)
-    assert(nestedArray.numElements() == 2)
-    assert(nestedArray.getLong(0) == 3L)
-    assert(nestedArray.getLong(1) == 4L)
+    val nestedArray = unsafeArray2.getArray(0)
+    testArrayInt(nestedArray, Seq(3, 4))
 
-    assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+    assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
 
-    val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes)
-    val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes)
+    val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
+    val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
     assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
   }
 
   test("basic conversion with map type") {
-    def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+    val fieldTypes: Array[DataType] = Array(
+      MapType(IntegerType, IntegerType),
+      MapType(IntegerType, MapType(IntegerType, IntegerType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
 
-    def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = {
-      val numElements = keys.length
-      assert(map.numElements() == numElements)
+    val map1 = createMap(1, 2)(3, 4)
 
-      val keyArray = map.keys
-      assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements)
-      assert(keyArray.numElements() == numElements)
-      keys.zipWithIndex.foreach { case (key, i) =>
-        assert(keyArray.getInt(i) == key)
-      }
+    val innerMap = createMap(5, 6)(7, 8)
+    val map2 = createMap(9)(innerMap)
 
-      val valueArray = map.values
-      assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements)
-      assert(valueArray.numElements() == numElements)
-      values.zipWithIndex.foreach { case (value, i) =>
-        assert(valueArray.getLong(i) == value)
-      }
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, map1)
+    row.update(1, map2)
 
-      assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields == 2)
+
+    val unsafeMap1 = unsafeRow.getMap(0)
+    testMapInt(unsafeMap1, Seq(1, 2), Seq(3, 4))
+
+    val unsafeMap2 = unsafeRow.getMap(1)
+    assert(unsafeMap2.numElements == 1)
+
+    val keyArray = unsafeMap2.keyArray
+    testArrayInt(keyArray, Seq(9))
+
+    val valueArray = unsafeMap2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val nestedMap = valueArray.getMap(0)
+      testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
+
+      assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
     }
 
+    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
+    val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+    assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+  }
+
+  test("basic conversion with struct and array") {
     val fieldTypes: Array[DataType] = Array(
-      MapType(IntegerType, LongType),
-      MapType(IntegerType, MapType(IntegerType, LongType))
+      new StructType().add("arr", ArrayType(IntegerType)),
+      ArrayType(new StructType().add("l", LongType))
     )
     val converter = UnsafeProjection.create(fieldTypes)
 
-    val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L))
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, InternalRow(createArray(1)))
+    row.update(1, createArray(InternalRow(2L)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields() == 2)
+
+    val field1 = unsafeRow.getStruct(0, 1)
+    assert(field1.numFields == 1)
+
+    val innerArray = field1.getArray(0)
+    testArrayInt(innerArray, Seq(1))
 
-    val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L))
-    val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap))
+    assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+
+    val field2 = unsafeRow.getArray(1)
+    assert(field2.numElements == 1)
+
+    val innerStruct = field2.getStruct(0, 1)
+
+    {
+      assert(innerStruct.numFields == 1)
+      assert(innerStruct.getSizeInBytes == 8 + 8)
+      assert(innerStruct.getLong(0) == 2L)
+    }
+
+    assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+  }
+
+  test("basic conversion with struct and map") {
+    val fieldTypes: Array[DataType] = Array(
+      new StructType().add("map", MapType(IntegerType, IntegerType)),
+      MapType(IntegerType, new StructType().add("l", LongType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
 
     val row = new GenericMutableRow(fieldTypes.length)
-    row.update(0, map1)
-    row.update(1, map2)
+    row.update(0, InternalRow(createMap(1)(2)))
+    row.update(1, createMap(3)(InternalRow(4L)))
 
     val unsafeRow: UnsafeRow = converter.apply(row)
     assert(unsafeRow.numFields() == 2)
 
-    val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData]
-    testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L))
+    val field1 = unsafeRow.getStruct(0, 1)
+    assert(field1.numFields == 1)
 
-    val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
-    assert(unsafeMap2.numElements() == 1)
+    val innerMap = field1.getMap(0)
+    testMapInt(innerMap, Seq(1), Seq(2))
 
-    val keyArray = unsafeMap2.keys
-    assert(keyArray.getSizeInBytes == 4 + 4)
-    assert(keyArray.numElements() == 1)
-    assert(keyArray.getInt(0) == 9)
+    assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
 
-    val valueArray = unsafeMap2.values
-    assert(valueArray.numElements() == 1)
+    val field2 = unsafeRow.getMap(1)
 
-    val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
-    testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
-    assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+    val keyArray = field2.keyArray
+    testArrayInt(keyArray, Seq(3))
 
-    assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
-
-    val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
-    val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
-    assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+    val valueArray = field2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val innerStruct = valueArray.getStruct(0, 1)
+      assert(innerStruct.numFields == 1)
+      assert(innerStruct.getSizeInBytes == 8 + 8)
+      assert(innerStruct.getLong(0) == 4L)
+
+      assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+    }
+
+    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+  }
+
+  test("basic conversion with array and map") {
+    val fieldTypes: Array[DataType] = Array(
+      ArrayType(MapType(IntegerType, IntegerType)),
+      MapType(IntegerType, ArrayType(IntegerType))
+    )
+    val converter = UnsafeProjection.create(fieldTypes)
+
+    val row = new GenericMutableRow(fieldTypes.length)
+    row.update(0, createArray(createMap(1)(2)))
+    row.update(1, createMap(3)(createArray(4)))
+
+    val unsafeRow: UnsafeRow = converter.apply(row)
+    assert(unsafeRow.numFields() == 2)
+
+    val field1 = unsafeRow.getArray(0)
+    assert(field1.numElements == 1)
+
+    val innerMap = field1.getMap(0)
+    testMapInt(innerMap, Seq(1), Seq(2))
+
+    assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+
+    val field2 = unsafeRow.getMap(1)
+    assert(field2.numElements == 1)
+
+    val keyArray = field2.keyArray
+    testArrayInt(keyArray, Seq(3))
+
+    val valueArray = field2.valueArray
+
+    {
+      assert(valueArray.numElements == 1)
+
+      val innerArray = valueArray.getArray(0)
+      testArrayInt(innerArray, Seq(4))
+
+      assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
+    }
+
+    assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+    assert(unsafeRow.getSizeInBytes ==
+      8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
   }
 }
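
Reviewer note (not part of the patch): the wire format that the new UnsafeMapData javadoc and UnsafeReaders.readMap describe is [numElements: 4 bytes][numBytes of key array: 4 bytes][key array content][value array content]. Below is a minimal, self-contained sketch of that layout using plain java.nio.ByteBuffer rather than Spark's Platform accessors (so the byte order is big-endian instead of platform-native; only the layout is the point). All class and variable names here are illustrative, not part of the patch.

import java.nio.ByteBuffer;

// Sketch of the map layout: the key/value array contents are assumed to already be
// in UnsafeArrayData's element format, without their own numElements headers.
public class MapLayoutSketch {
  static byte[] writeMap(int numElements, byte[] keyContent, byte[] valueContent) {
    ByteBuffer buf = ByteBuffer.allocate(4 + 4 + keyContent.length + valueContent.length);
    buf.putInt(numElements);        // first 4 bytes: numElements
    buf.putInt(keyContent.length);  // next 4 bytes: numBytes of the key array
    buf.put(keyContent);            // key array content (no numElements header)
    buf.put(valueContent);          // value array content (no numElements header)
    return buf.array();
  }

  // Mirrors what UnsafeReaders.readMap recovers: numElements, the key array's size,
  // and (by subtraction) the value array's size.
  public static void main(String[] args) {
    byte[] map = writeMap(2, new byte[16], new byte[16]);
    ByteBuffer buf = ByteBuffer.wrap(map);
    int numElements = buf.getInt();
    int keysNumBytes = buf.getInt();
    int valuesNumBytes = map.length - 8 - keysNumBytes;
    System.out.println(numElements + " entries, keys = " + keysNumBytes
        + " bytes, values = " + valuesNumBytes + " bytes");
  }
}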
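Similarly, UnsafeRowWriter.setOffsetAndSize packs a variable-length field's location into its 8-byte fixed-width slot: the offset relative to the row start goes in the upper 32 bits and the size in bytes in the lower 32 bits. Below is a hypothetical round-trip demo; the packing is the same (relativeOffset << 32) | size expression as in the patch, while the unpacking shown is an assumption about how readers decode the word, based on the existing UnsafeRow accessors.

// Round-trip demo of the offset-and-size word written by UnsafeRowWriter.
public class OffsetAndSizeSketch {
  static long pack(long relativeOffset, long size) {
    return (relativeOffset << 32) | size;  // same expression as in setOffsetAndSize
  }

  public static void main(String[] args) {
    long word = pack(24, 10);         // field data starts 24 bytes into the row, 10 bytes long
    int offset = (int) (word >> 32);  // upper 32 bits: relative offset
    int size = (int) word;            // lower 32 bits: size in bytes
    System.out.println("offset = " + offset + ", size = " + size);  // offset = 24, size = 10
  }
}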