From 1d9733271595596683a6d956a7433fa601df1cc1 Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Wed, 21 Oct 2015 19:20:31 -0700
Subject: [PATCH] [SPARK-11243][SQL] output UnsafeRow from columnar cache

This PR changes InMemoryTableScan to output UnsafeRow, and optimizes unrolling
and scanning by copying the bytes of var-length types between UnsafeRow and
ByteBuffer directly, without creating wrapper objects. When scanning the
decimals in the TPC-DS store_sales table, it is 80% faster (each decimal is
copied as a long without creating Decimal objects).

Author: Davies Liu <davies@databricks.com>

Closes #9203 from davies/unsafe_cache.
---
 .../sql/catalyst/expressions/UnsafeRow.java        |  31 +++++-
 .../codegen/UnsafeArrayWriter.java                 |  78 ++++++++++----
 .../expressions/codegen/UnsafeRowWriter.java       | 102 ++++++++++++------
 .../codegen/GenerateUnsafeProjection.scala         |  57 +---------
 .../spark/sql/columnar/ColumnType.scala            |  81 ++++++++++++--
 .../sql/columnar/GenerateColumnAccessor.scala      |  68 ++++++++++--
 .../columnar/InMemoryColumnarTableScan.scala       |   6 +-
 7 files changed, 291 insertions(+), 132 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 366615f6fe..850838af9b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -402,7 +402,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     if (isNullAt(ordinal)) return null;
     final long offsetAndSize = getLong(ordinal);
     final int offset = (int) (offsetAndSize >> 32);
-    final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+    final int size = (int) offsetAndSize;
     return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
   }
 
@@ -413,7 +413,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final byte[] bytes = new byte[size];
       Platform.copyMemory(
         baseObject,
@@ -446,7 +446,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final UnsafeRow row = new UnsafeRow();
       row.pointTo(baseObject, baseOffset + offset, numFields, size);
       return row;
@@ -460,7 +460,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final UnsafeArrayData array = new UnsafeArrayData();
       array.pointTo(baseObject, baseOffset + offset, size);
       return array;
@@ -474,7 +474,7 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     } else {
       final long offsetAndSize = getLong(ordinal);
       final int offset = (int) (offsetAndSize >> 32);
-      final int size = (int) (offsetAndSize & ((1L << 32) - 1));
+      final int size = (int) offsetAndSize;
       final UnsafeMapData map = new UnsafeMapData();
       map.pointTo(baseObject, baseOffset + offset, size);
       return map;
@@ -618,6 +618,27 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     buffer.position(pos + sizeInBytes);
   }
 
+  /**
+   * Writes the bytes of a var-length field into a ByteBuffer.
+   *
+   * Note: only works with HeapByteBuffer
+   */
+  public void writeFieldTo(int ordinal, ByteBuffer buffer) {
+    final long offsetAndSize = getLong(ordinal);
+    final int offset = (int) (offsetAndSize >> 32);
+    final int size = (int) offsetAndSize;
+
+    buffer.putInt(size);
+    int pos = buffer.position();
+    buffer.position(pos + size);
+    Platform.copyMemory(
+      baseObject,
+      baseOffset + offset,
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + pos,
+      size);
+  }
+
   @Override
   public void writeExternal(ObjectOutput out) throws IOException {
     byte[] bytes = getBytes();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 7f2a1cb07a..7dd932d198 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.codegen;
 
-import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.types.CalendarInterval;
@@ -64,29 +63,72 @@ public class UnsafeArrayWriter {
     Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
   }
 
-  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
-    // make sure Decimal object has the same scale as DecimalType
-    if (input.changePrecision(precision, scale)) {
-      Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
-      setOffset(ordinal);
-      holder.cursor += 8;
-    } else {
-      setNullAt(ordinal);
+  public void write(int ordinal, boolean value) {
+    Platform.putBoolean(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 1;
+  }
+
+  public void write(int ordinal, byte value) {
+    Platform.putByte(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 1;
+  }
+
+  public void write(int ordinal, short value) {
+    Platform.putShort(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 2;
+  }
+
+  public void write(int ordinal, int value) {
+    Platform.putInt(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 4;
+  }
+
+  public void write(int ordinal, long value) {
+    Platform.putLong(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 8;
+  }
+
+  public void write(int ordinal, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
+    Platform.putFloat(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 4;
+  }
+
+  public void write(int ordinal, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
     }
+    Platform.putDouble(holder.buffer, holder.cursor, value);
+    setOffset(ordinal);
+    holder.cursor += 8;
   }
 
   public void write(int ordinal, Decimal input, int precision, int scale) {
     // make sure Decimal object has the same scale as DecimalType
     if (input.changePrecision(precision, scale)) {
-      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
-      assert bytes.length <= 16;
-      holder.grow(bytes.length);
-
-      // Write the bytes to the variable length portion.
-      Platform.copyMemory(
-        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
-      setOffset(ordinal);
-      holder.cursor += bytes.length;
+      if (precision <= Decimal.MAX_LONG_DIGITS()) {
+        Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+        setOffset(ordinal);
+        holder.cursor += 8;
+      } else {
+        final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+        assert bytes.length <= 16;
+        holder.grow(bytes.length);
+
+        // Write the bytes to the variable length portion.
+        Platform.copyMemory(
+          bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+        setOffset(ordinal);
+        holder.cursor += bytes.length;
+      }
     } else {
       setNullAt(ordinal);
     }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index e1f5a05d1d..adbe262187 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -58,6 +58,10 @@ public class UnsafeRowWriter {
     }
   }
 
+  public boolean isNullAt(int ordinal) {
+    return BitSetMethods.isSet(holder.buffer, startingOffset, ordinal);
+  }
+
   public void setNullAt(int ordinal) {
     BitSetMethods.set(holder.buffer, startingOffset, ordinal);
     Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
@@ -95,41 +99,75 @@ public class UnsafeRowWriter {
     }
   }
 
-  public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
-    // make sure Decimal object has the same scale as DecimalType
-    if (input.changePrecision(precision, scale)) {
-      Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
-    } else {
-      setNullAt(ordinal);
-    }
+  public void write(int ordinal, boolean value) {
+    Platform.putBoolean(holder.buffer, getFieldOffset(ordinal), value);
   }
 
-  public void write(int ordinal, Decimal input, int precision, int scale) {
-    // grow the global buffer before writing data.
-    holder.grow(16);
+  public void write(int ordinal, byte value) {
+    Platform.putByte(holder.buffer, getFieldOffset(ordinal), value);
+  }
 
-    // zero-out the bytes
-    Platform.putLong(holder.buffer, holder.cursor, 0L);
-    Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+  public void write(int ordinal, short value) {
+    Platform.putShort(holder.buffer, getFieldOffset(ordinal), value);
+  }
 
-    // Make sure Decimal object has the same scale as DecimalType.
-    // Note that we may pass in null Decimal object to set null for it.
-    if (input == null || !input.changePrecision(precision, scale)) {
-      BitSetMethods.set(holder.buffer, startingOffset, ordinal);
-      // keep the offset for future update
-      setOffsetAndSize(ordinal, 0L);
-    } else {
-      final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
-      assert bytes.length <= 16;
+  public void write(int ordinal, int value) {
+    Platform.putInt(holder.buffer, getFieldOffset(ordinal), value);
+  }
 
-      // Write the bytes to the variable length portion.
-      Platform.copyMemory(
-        bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
-      setOffsetAndSize(ordinal, bytes.length);
+  public void write(int ordinal, long value) {
+    Platform.putLong(holder.buffer, getFieldOffset(ordinal), value);
+  }
+
+  public void write(int ordinal, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
     }
+    Platform.putFloat(holder.buffer, getFieldOffset(ordinal), value);
+  }
 
-    // move the cursor forward.
-    holder.cursor += 16;
+  public void write(int ordinal, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
+    Platform.putDouble(holder.buffer, getFieldOffset(ordinal), value);
+  }
+
+  public void write(int ordinal, Decimal input, int precision, int scale) {
+    if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      // make sure Decimal object has the same scale as DecimalType
+      if (input.changePrecision(precision, scale)) {
+        Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+      } else {
+        setNullAt(ordinal);
+      }
+    } else {
+      // grow the global buffer before writing data.
+      holder.grow(16);
+
+      // zero-out the bytes
+      Platform.putLong(holder.buffer, holder.cursor, 0L);
+      Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+      // Make sure Decimal object has the same scale as DecimalType.
+      // Note that we may pass in null Decimal object to set null for it.
+      if (input == null || !input.changePrecision(precision, scale)) {
+        BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+        // keep the offset for future update
+        setOffsetAndSize(ordinal, 0L);
+      } else {
+        final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+        assert bytes.length <= 16;
+
+        // Write the bytes to the variable length portion.
+        Platform.copyMemory(
+          bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+        setOffsetAndSize(ordinal, bytes.length);
+      }
+
+      // move the cursor forward.
+      holder.cursor += 16;
+    }
   }
 
   public void write(int ordinal, UTF8String input) {
@@ -151,7 +189,10 @@ public class UnsafeRowWriter {
   }
 
   public void write(int ordinal, byte[] input) {
-    final int numBytes = input.length;
+    write(ordinal, input, 0, input.length);
+  }
+
+  public void write(int ordinal, byte[] input, int offset, int numBytes) {
     final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
 
     // grow the global buffer before writing data.
@@ -160,7 +201,8 @@ public class UnsafeRowWriter {
     zeroOutPaddingBytes(numBytes);
 
     // Write the bytes to the variable length portion.
-    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+    Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET + offset,
+      holder.buffer, holder.cursor, numBytes);
 
     setOffsetAndSize(ordinal, numBytes);
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index dbe92d6a83..2136f82ba4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -89,7 +89,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val setNull = dt match {
       case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
         // Can't call setNullAt() for DecimalType with precision larger than 18.
- s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});" + s"$rowWriter.write($index, (Decimal) null, ${t.precision}, ${t.scale});" case _ => s"$rowWriter.setNullAt($index);" } @@ -124,17 +124,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case _ if ctx.isPrimitiveType(dt) => - val fieldOffset = ctx.freshName("fieldOffset") s""" - final long $fieldOffset = $rowWriter.getFieldOffset($index); - Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L); - ${writePrimitiveType(ctx, input.value, dt, s"$bufferHolder.buffer", fieldOffset)} + $rowWriter.write($index, ${input.value}); """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s"$rowWriter.writeCompactDecimal($index, ${input.value}, " + - s"${t.precision}, ${t.scale});" - case t: DecimalType => s"$rowWriter.write($index, ${input.value}, ${t.precision}, ${t.scale});" @@ -204,20 +197,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)} """ - case _ if ctx.isPrimitiveType(et) => - // Should we do word align? - val dataSize = et.defaultSize - - s""" - $arrayWriter.setOffset($index); - ${writePrimitiveType(ctx, element, et, - s"$bufferHolder.buffer", s"$bufferHolder.cursor")} - $bufferHolder.cursor += $dataSize; - """ - - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});" - case t: DecimalType => s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});" @@ -296,38 +275,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - private def writePrimitiveType( - ctx: CodeGenContext, - input: String, - dt: DataType, - buffer: String, - offset: String) = { - assert(ctx.isPrimitiveType(dt)) - - val putMethod = s"put${ctx.primitiveTypeName(dt)}" - - dt match { - case FloatType | DoubleType => - val normalized = ctx.freshName("normalized") - val boxedType = ctx.boxedType(dt) - val handleNaN = - s""" - final ${ctx.javaType(dt)} $normalized; - if ($boxedType.isNaN($input)) { - $normalized = $boxedType.NaN; - } else { - $normalized = $input; - } - """ - - s""" - $handleNaN - Platform.$putMethod($buffer, $offset, $normalized); - """ - case _ => s"Platform.$putMethod($buffer, $offset, $input);" - } - } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { val exprEvals = expressions.map(e => e.gen(ctx)) val exprTypes = expressions.map(_.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 72fa299aa9..68e509eb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -32,6 +32,13 @@ import org.apache.spark.unsafe.types.UTF8String /** * A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order. * + * Note: There is not much difference between ByteBuffer.getByte/getShort and + * Unsafe.getByte/getShort, so we do not have helper methods for them. + * + * The unrolling (building columnar cache) is already slow, putLong/putDouble will not help much, + * so we do not have helper methods for them. 
+ *
+ *
  * WARNNING: This only works with HeapByteBuffer
  */
 object ByteBufferHelper {
@@ -351,7 +358,38 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 2) {
   }
 }
 
-private[sql] object STRING extends NativeColumnType(StringType, 8) {
+/**
+ * A fast path to copy var-length bytes between ByteBuffer and UnsafeRow without creating wrapper
+ * objects.
+ */
+private[sql] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] {
+
+  // copy the bytes from ByteBuffer to UnsafeRow
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    if (row.isInstanceOf[MutableUnsafeRow]) {
+      val numBytes = buffer.getInt
+      val cursor = buffer.position()
+      buffer.position(cursor + numBytes)
+      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
+        buffer.arrayOffset() + cursor, numBytes)
+    } else {
+      setField(row, ordinal, extract(buffer))
+    }
+  }
+
+  // copy the bytes from UnsafeRow to ByteBuffer
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    if (row.isInstanceOf[UnsafeRow]) {
+      row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
+    } else {
+      super.append(row, ordinal, buffer)
+    }
+  }
+}
+
+private[sql] object STRING
+  extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] {
+
   override def actualSize(row: InternalRow, ordinal: Int): Int = {
     row.getUTF8String(ordinal).numBytes() + 4
   }
@@ -363,16 +401,17 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
 
   override def extract(buffer: ByteBuffer): UTF8String = {
     val length = buffer.getInt()
-    assert(buffer.hasArray)
-    val base = buffer.array()
-    val offset = buffer.arrayOffset()
     val cursor = buffer.position()
     buffer.position(cursor + length)
-    UTF8String.fromBytes(base, offset + cursor, length)
+    UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length)
   }
 
   override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
-    row.update(ordinal, value.clone())
+    if (row.isInstanceOf[MutableUnsafeRow]) {
+      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
+    } else {
+      row.update(ordinal, value.clone())
+    }
   }
 
   override def getField(row: InternalRow, ordinal: Int): UTF8String = {
@@ -393,10 +432,28 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
     Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
   }
 
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    if (row.isInstanceOf[MutableUnsafeRow]) {
+      // copy it as Long
+      row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
+    } else {
+      setField(row, ordinal, extract(buffer))
+    }
+  }
+
   override def append(v: Decimal, buffer: ByteBuffer): Unit = {
     buffer.putLong(v.toUnscaledLong)
   }
 
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    if (row.isInstanceOf[UnsafeRow]) {
+      // copy it as Long
+      buffer.putLong(row.getLong(ordinal))
+    } else {
+      append(getField(row, ordinal), buffer)
+    }
+  }
+
   override def getField(row: InternalRow, ordinal: Int): Decimal = {
     row.getDecimal(ordinal, precision, scale)
   }
@@ -417,7 +474,7 @@ private[sql] object COMPACT_DECIMAL {
 }
 
 private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int)
-  extends ColumnType[JvmType] {
+  extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] {
 
   def serialize(value: JvmType): Array[Byte]
   def deserialize(bytes: Array[Byte]): JvmType
@@ -488,7 +545,8 @@ private[sql] object LARGE_DECIMAL {
   }
 }
 
-private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] {
+private[sql] case class STRUCT(dataType: StructType)
+  extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] {
 
   private val numOfFields: Int = dataType.fields.size
 
@@ -528,7 +586,8 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
   override def clone(v: UnsafeRow): UnsafeRow = v.copy()
 }
 
-private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] {
+private[sql] case class ARRAY(dataType: ArrayType)
+  extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] {
 
   override def defaultSize: Int = 16
 
@@ -566,7 +625,8 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
   override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
 }
 
-private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] {
+private[sql] case class MAP(dataType: MapType)
+  extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] {
 
   override def defaultSize: Int = 32
 
@@ -590,7 +650,6 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
 
   override def extract(buffer: ByteBuffer): UnsafeMapData = {
     val numBytes = buffer.getInt
-    assert(buffer.hasArray)
     val cursor = buffer.position()
     buffer.position(cursor + numBytes)
     val map = new UnsafeMapData
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
index e04bcda580..d0f5bfa1cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/GenerateColumnAccessor.scala
@@ -20,17 +20,43 @@ package org.apache.spark.sql.columnar
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
+import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator}
 import org.apache.spark.sql.types._
 
 /**
- * An Iterator to walk throught the InternalRows from a CachedBatch
+ * An Iterator to walk through the InternalRows from a CachedBatch
 */
 abstract class ColumnarIterator extends Iterator[InternalRow] {
-  def initialize(input: Iterator[CachedBatch], mutableRow: MutableRow, columnTypes: Array[DataType],
+  def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType],
     columnIndexes: Array[Int]): Unit
 }
 
+/**
+ * A helper class to update the fields of an UnsafeRow, used by ColumnAccessor.
+ *
+ * WARNING: These setters MUST be called in increasing order of ordinals.
+ */
+class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) {
+
+  override def isNullAt(i: Int): Boolean = writer.isNullAt(i)
+  override def setNullAt(i: Int): Unit = writer.setNullAt(i)
+
+  override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v)
+  override def setByte(i: Int, v: Byte): Unit = writer.write(i, v)
+  override def setShort(i: Int, v: Short): Unit = writer.write(i, v)
+  override def setInt(i: Int, v: Int): Unit = writer.write(i, v)
+  override def setLong(i: Int, v: Long): Unit = writer.write(i, v)
+  override def setFloat(i: Int, v: Float): Unit = writer.write(i, v)
+  override def setDouble(i: Int, v: Double): Unit = writer.write(i, v)
+
+  // the writer will be used directly to avoid creating wrapper objects
+  override def setDecimal(i: Int, v: Decimal, precision: Int): Unit =
+    throw new UnsupportedOperationException
+  override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException
+
+  // all other methods inherited from GenericMutableRow are not needed
+}
+
 /**
  * Generates bytecode for an [[ColumnarIterator]] for columnar cache.
 */
@@ -41,6 +67,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
 
   protected def create(columnTypes: Seq[DataType]): ColumnarIterator = {
     val ctx = newCodeGenContext()
+    val numFields = columnTypes.size
     val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) =>
       val accessorName = ctx.freshName("accessor")
       val accessorCls = dt match {
@@ -74,13 +101,27 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera
       }
 
       val extract = s"$accessorName.extractTo(mutableRow, $index);"
-
-      (createCode, extract)
+      val patch = dt match {
+        case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS =>
+          // For large Decimal, it should have 16 bytes for future update even if it's null now.
+ s""" + if (mutableRow.isNullAt($index)) { + rowWriter.write($index, (Decimal) null, $p, $s); + } + """ + case other => "" + } + (createCode, extract + patch) }.unzip val code = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; + import scala.collection.Iterator; + import org.apache.spark.sql.types.DataType; + import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.columnar.MutableUnsafeRow; public SpecificColumnarIterator generate($exprType[] expr) { return new SpecificColumnarIterator(); @@ -90,13 +131,17 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] buffers = null; + private UnsafeRow unsafeRow = new UnsafeRow(); + private BufferHolder bufferHolder = new BufferHolder(); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private MutableUnsafeRow mutableRow = null; private int currentRow = 0; private int numRowsInBatch = 0; private scala.collection.Iterator input = null; private MutableRow mutableRow = null; - private ${classOf[DataType].getName}[] columnTypes = null; + private DataType[] columnTypes = null; private int[] columnIndexes = null; ${declareMutableStates(ctx)} @@ -104,12 +149,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera public SpecificColumnarIterator() { this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; + this.mutableRow = new MutableUnsafeRow(rowWriter); ${initMutableStates(ctx)} } - public void initialize(scala.collection.Iterator input, MutableRow mutableRow, - ${classOf[DataType].getName}[] columnTypes, int[] columnIndexes) { + public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { this.input = input; this.mutableRow = mutableRow; this.columnTypes = columnTypes; @@ -136,9 +181,12 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera } public InternalRow next() { - ${extractors.mkString("\n")} currentRow += 1; - return mutableRow; + bufferHolder.reset(); + rowWriter.initialize(bufferHolder, $numFields); + ${extractors.mkString("\n")} + unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 9f76a61a15..b4607b12fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -209,6 +209,8 @@ private[sql] case class InMemoryColumnarTableScan( override def output: Seq[Attribute] = attributes + override def outputsUnsafeRows: Boolean = true + private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression @@ -317,14 +319,12 @@ private[sql] case class InMemoryColumnarTableScan( cachedBatchIterator } - val nextRow = new SpecificMutableRow(requestedColumnDataTypes) val columnTypes = requestedColumnDataTypes.map { case udt: UserDefinedType[_] => udt.sqlType case other => other }.toArray val columnarIterator = GenerateColumnAccessor.generate(columnTypes) - columnarIterator.initialize(cachedBatchesToScan, nextRow, columnTypes, - 
-        requestedColumnIndices.toArray)
+      columnarIterator.initialize(cachedBatchesToScan, columnTypes, requestedColumnIndices.toArray)
       if (enableAccumulators && columnarIterator.hasNext) {
         readPartitions += 1
       }
-- 
GitLab
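
A recurring change in the UnsafeRow hunks above replaces the masked extraction of the size from offsetAndSize with a plain cast. The standalone sketch below (illustration only, not part of the patch; the class and values are hypothetical) shows why the two forms are equivalent in Java: narrowing a long to an int keeps exactly the low 32 bits, which is where UnsafeRow stores the size, with the offset packed into the high 32 bits.

    public class OffsetAndSizeDemo {
      // Pack a field's offset into the high 32 bits and its size into the low 32 bits,
      // following the convention the UnsafeRow accessors and writeFieldTo rely on.
      static long pack(int offset, int size) {
        return ((long) offset << 32) | (size & 0xFFFFFFFFL);
      }

      public static void main(String[] args) {
        long offsetAndSize = pack(64, 13);
        int offset = (int) (offsetAndSize >> 32);
        // The narrowing cast discards the high 32 bits, so the explicit mask is redundant.
        int sizeWithMask = (int) (offsetAndSize & ((1L << 32) - 1));
        int sizeWithCast = (int) offsetAndSize;
        System.out.println(offset + " " + sizeWithMask + " " + sizeWithCast); // prints: 64 13 13
      }
    }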