diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index cf5322125bd7284544ae7b859c3eeda8cee8fcb3..5dd661ee6b339b0f73217cefd8728ce0c432ad9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -148,6 +148,28 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside a long + */ + def is64BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_LONG_DIGITS + case _ => false + } + } + + /** + * Returns if dt is a DecimalType that doesn't fit inside a long + */ + def isByteArrayDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision > Decimal.MAX_LONG_DIGITS + case _ => false + } + } + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a0bf8734b6545ef05ef1fdb66969712240d57d0a..a5bc506a65ac2ad77166bcfa18a24ed5243610f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,9 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -102,18 +105,36 @@ public abstract class ColumnVector { DataType dt = data.dataType(); Object[] list = new Object[length]; - if (dt instanceof ByteType) { + if (dt instanceof BooleanType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getBoolean(offset + i); + } + } + } else if (dt instanceof ByteType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getByte(offset + i); } } + } else if (dt instanceof ShortType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getShort(offset + i); + } + } } else if (dt instanceof IntegerType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = data.getInt(offset + i); } } + } else if (dt instanceof FloatType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = data.getFloat(offset + i); + } + } } else if (dt instanceof DoubleType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { @@ -126,12 +147,25 @@ public abstract class ColumnVector { list[i] = data.getLong(offset + i); } } + } else if (dt instanceof DecimalType) { + DecimalType decType = (DecimalType)dt; + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getDecimal(i, decType.precision(), decType.scale()); + } + } } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); } } + } else if (dt instanceof CalendarIntervalType) { + for (int i = 0; i < length; i++) { + if (!data.getIsNull(offset + i)) { + list[i] = getInterval(i); + } + } } else { throw new NotImplementedException("Type " + dt); } @@ -170,7 +204,14 @@ public abstract class ColumnVector { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -181,17 +222,22 @@ public abstract class ColumnVector { @Override public byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = data.getByteArray(offset + ordinal); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + int month = data.getChildColumn(0).getInt(offset + ordinal); + long microseconds = data.getChildColumn(1).getLong(offset + ordinal); + return new CalendarInterval(month, microseconds); } @Override public InternalRow getStruct(int ordinal, int numFields) { - throw new NotImplementedException(); + return data.getStruct(offset + ordinal); } @Override @@ -279,6 +325,21 @@ public abstract class ColumnVector { */ public abstract boolean getIsNull(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Returns the value for rowId. + */ + public abstract boolean getBoolean(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -299,6 +360,26 @@ public abstract class ColumnVector { */ public abstract byte getByte(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract short getShort(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -351,6 +432,33 @@ public abstract class ColumnVector { */ public abstract long getLong(int rowId); + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * src should contain `count` doubles written as ieee format. + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Returns the value for rowId. + */ + public abstract float getFloat(int rowId); + /** * Sets the value at rowId to `value`. */ @@ -369,7 +477,7 @@ public abstract class ColumnVector { /** * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formated doubles. + * The data in src must be ieee formatted doubles. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -469,6 +577,20 @@ public abstract class ColumnVector { return result; } + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendByte(byte v) { reserve(elementsAppended + 1); putByte(elementsAppended, v); @@ -491,6 +613,28 @@ public abstract class ColumnVector { return result; } + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + public final int appendInt(int v) { reserve(elementsAppended + 1); putInt(elementsAppended, v); @@ -535,6 +679,20 @@ public abstract class ColumnVector { return result; } + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + public final int appendDouble(double v) { reserve(elementsAppended + 1); putDouble(elementsAppended, v); @@ -661,7 +819,8 @@ public abstract class ColumnVector { this.capacity = capacity; this.type = type; - if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType) { + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { @@ -682,6 +841,13 @@ public abstract class ColumnVector { } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. + this.childColumns = new ColumnVector[2]; + this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else { this.childColumns = null; this.resultArray = null; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 6c651a759d250ec32375c84a0d21d7598e57e88f..453bc15e1350318b15dcae52d18f370cc02fd230 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -16,12 +16,15 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Iterator; import java.util.List; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.commons.lang.NotImplementedException; @@ -59,19 +62,44 @@ public class ColumnVectorUtils { private static void appendValue(ColumnVector dst, DataType t, Object o) { if (o == null) { - dst.appendNull(); + if (t instanceof CalendarIntervalType) { + dst.appendStruct(true); + } else { + dst.appendNull(); + } } else { - if (t == DataTypes.ByteType) { - dst.appendByte(((Byte)o).byteValue()); + if (t == DataTypes.BooleanType) { + dst.appendBoolean(((Boolean)o).booleanValue()); + } else if (t == DataTypes.ByteType) { + dst.appendByte(((Byte) o).byteValue()); + } else if (t == DataTypes.ShortType) { + dst.appendShort(((Short)o).shortValue()); } else if (t == DataTypes.IntegerType) { dst.appendInt(((Integer)o).intValue()); } else if (t == DataTypes.LongType) { dst.appendLong(((Long)o).longValue()); + } else if (t == DataTypes.FloatType) { + dst.appendFloat(((Float)o).floatValue()); } else if (t == DataTypes.DoubleType) { dst.appendDouble(((Double)o).doubleValue()); } else if (t == DataTypes.StringType) { byte[] b =((String)o).getBytes(); dst.appendByteArray(b, 0, b.length); + } else if (t instanceof DecimalType) { + DecimalType dt = (DecimalType)t; + Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale()); + if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) { + dst.appendLong(d.toUnscaledLong()); + } else { + final BigInteger integer = d.toJavaBigDecimal().unscaledValue(); + byte[] bytes = integer.toByteArray(); + dst.appendByteArray(bytes, 0, bytes.length); + } + } else if (t instanceof CalendarIntervalType) { + CalendarInterval c = (CalendarInterval)o; + dst.appendStruct(false); + dst.getChildColumn(0).appendInt(c.months); + dst.getChildColumn(1).appendLong(c.microseconds); } else { throw new NotImplementedException("Type " + t); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 5a575811fa8963d49cfc3436da976bf5a5d85685..dbad5e070f1fea2f4bb4983a4671a196b7689174 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; @@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -150,44 +153,40 @@ public final class ColumnarBatch { } @Override - public final boolean isNullAt(int ordinal) { - return columns[ordinal].getIsNull(rowId); - } + public final boolean isNullAt(int ordinal) { return columns[ordinal].getIsNull(rowId); } @Override - public final boolean getBoolean(int ordinal) { - throw new NotImplementedException(); - } + public final boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } @Override public final byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } @Override - public final short getShort(int ordinal) { - throw new NotImplementedException(); - } + public final short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } @Override - public final int getInt(int ordinal) { - return columns[ordinal].getInt(rowId); - } + public final int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } @Override public final long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } @Override - public final float getFloat(int ordinal) { - throw new NotImplementedException(); - } + public final float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } @Override - public final double getDouble(int ordinal) { - return columns[ordinal].getDouble(rowId); - } + public final double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - throw new NotImplementedException(); + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + // TODO: best perf? + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } } @Override @@ -198,12 +197,17 @@ public final class ColumnarBatch { @Override public final byte[] getBinary(int ordinal) { - throw new NotImplementedException(); + ColumnVector.Array array = columns[ordinal].getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; } @Override public final CalendarInterval getInterval(int ordinal) { - throw new NotImplementedException(); + final int months = columns[ordinal].getChildColumn(0).getInt(rowId); + final long microseconds = columns[ordinal].getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, microseconds); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 335124fd5a6033a2d0cd2c1f2be9ab4851600440..22c5e5fc81a4a33611b3649d9794a22831022c95 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,11 +19,15 @@ package org.apache.spark.sql.execution.vectorized; import java.nio.ByteOrder; import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.BooleanType; import org.apache.spark.sql.types.ByteType; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; import org.apache.spark.sql.types.IntegerType; import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.ShortType; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; @@ -121,6 +125,26 @@ public final class OffHeapColumnVector extends ColumnVector { return Platform.getByte(null, nulls + rowId) == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + Platform.putByte(null, data + rowId, (byte)((value) ? 1 : 0)); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + Platform.putByte(null, data + rowId + i, v); + } + } + + @Override + public final boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + // // APIs dealing with Bytes // @@ -148,6 +172,34 @@ public final class OffHeapColumnVector extends ColumnVector { return Platform.getByte(null, data + rowId); } + // + // APIs dealing with shorts + // + + @Override + public final void putShort(int rowId, short value) { + Platform.putShort(null, data + 2 * rowId, value); + } + + @Override + public final void putShorts(int rowId, int count, short value) { + long offset = data + 2 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putShort(null, offset, value); + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + Platform.copyMemory(src, Platform.SHORT_ARRAY_OFFSET + srcIndex * 2, + null, data + 2 * rowId, count * 2); + } + + @Override + public final short getShort(int rowId) { + return Platform.getShort(null, data + 2 * rowId); + } + // // APIs dealing with ints // @@ -216,6 +268,41 @@ public final class OffHeapColumnVector extends ColumnVector { return Platform.getLong(null, data + 8 * rowId); } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { + Platform.putFloat(null, data + rowId * 4, value); + } + + @Override + public final void putFloats(int rowId, int count, float value) { + long offset = data + 4 * rowId; + for (int i = 0; i < count; ++i, offset += 4) { + Platform.putFloat(null, offset, value); + } + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + Platform.copyMemory(src, Platform.FLOAT_ARRAY_OFFSET + srcIndex * 4, + null, data + 4 * rowId, count * 4); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { + return Platform.getFloat(null, data + rowId * 4); + } + + // // APIs dealing with doubles // @@ -241,7 +328,7 @@ public final class OffHeapColumnVector extends ColumnVector { @Override public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - Platform.copyMemory(src, Platform.DOUBLE_ARRAY_OFFSET + srcIndex, + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, null, data + rowId * 8, count * 8); } @@ -300,11 +387,14 @@ public final class OffHeapColumnVector extends ColumnVector { Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); this.offsetData = Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof ByteType) { + } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); - } else if (type instanceof IntegerType) { + } else if (type instanceof ShortType) { + this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + } else if (type instanceof IntegerType || type instanceof FloatType) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); - } else if (type instanceof LongType || type instanceof DoubleType) { + } else if (type instanceof LongType || type instanceof DoubleType || + DecimalType.is64BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 8197fa11cd4c8632a709690c7f4cd60eb1bde7b5..32356334c031f2ffba7eebbf21e66d55f46b4f3f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -35,8 +35,10 @@ public final class OnHeapColumnVector extends ColumnVector { // Array for each type. Only 1 is populated for any type. private byte[] byteData; + private short[] shortData; private int[] intData; private long[] longData; + private float[] floatData; private double[] doubleData; // Only set if type is Array. @@ -104,6 +106,30 @@ public final class OnHeapColumnVector extends ColumnVector { return nulls[rowId] == 1; } + // + // APIs dealing with Booleans + // + + @Override + public final void putBoolean(int rowId, boolean value) { + byteData[rowId] = (byte)((value) ? 1 : 0); + } + + @Override + public final void putBooleans(int rowId, int count, boolean value) { + byte v = (byte)((value) ? 1 : 0); + for (int i = 0; i < count; ++i) { + byteData[i + rowId] = v; + } + } + + @Override + public final boolean getBoolean(int rowId) { + return byteData[rowId] == 1; + } + + // + // // APIs dealing with Bytes // @@ -130,6 +156,33 @@ public final class OnHeapColumnVector extends ColumnVector { return byteData[rowId]; } + // + // APIs dealing with Shorts + // + + @Override + public final void putShort(int rowId, short value) { + shortData[rowId] = value; + } + + @Override + public final void putShorts(int rowId, int count, short value) { + for (int i = 0; i < count; ++i) { + shortData[i + rowId] = value; + } + } + + @Override + public final void putShorts(int rowId, int count, short[] src, int srcIndex) { + System.arraycopy(src, srcIndex, shortData, rowId, count); + } + + @Override + public final short getShort(int rowId) { + return shortData[rowId]; + } + + // // APIs dealing with Ints // @@ -202,6 +255,31 @@ public final class OnHeapColumnVector extends ColumnVector { return longData[rowId]; } + // + // APIs dealing with floats + // + + @Override + public final void putFloat(int rowId, float value) { floatData[rowId] = value; } + + @Override + public final void putFloats(int rowId, int count, float value) { + Arrays.fill(floatData, rowId, rowId + count, value); + } + + @Override + public final void putFloats(int rowId, int count, float[] src, int srcIndex) { + System.arraycopy(src, srcIndex, floatData, rowId, count); + } + + @Override + public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + floatData, Platform.DOUBLE_ARRAY_OFFSET + rowId * 4, count * 4); + } + + @Override + public final float getFloat(int rowId) { return floatData[rowId]; } // // APIs dealing with doubles @@ -277,7 +355,7 @@ public final class OnHeapColumnVector extends ColumnVector { // Spilt this function out since it is the slow path. private final void reserveInternal(int newCapacity) { - if (this.resultArray != null) { + if (this.resultArray != null || DecimalType.isByteArrayDecimalType(type)) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { @@ -286,18 +364,30 @@ public final class OnHeapColumnVector extends ColumnVector { } arrayLengths = newLengths; arrayOffsets = newOffsets; + } else if (type instanceof BooleanType) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; } else if (type instanceof ByteType) { byte[] newData = new byte[newCapacity]; if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); byteData = newData; + } else if (type instanceof ShortType) { + short[] newData = new short[newCapacity]; + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + shortData = newData; } else if (type instanceof IntegerType) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; - } else if (type instanceof LongType) { + } else if (type instanceof LongType || DecimalType.is64BitDecimalType(type)) { long[] newData = new long[newCapacity]; if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); longData = newData; + } else if (type instanceof FloatType) { + float[] newData = new float[newCapacity]; + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + floatData = newData; } else if (type instanceof DoubleType) { double[] newData = new double[newCapacity]; if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 67cc08b6fc8ba73c277a159f29568cd106afaae5..445f311107e337862f1be5081f815f98529973c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { test("Null Apis") { @@ -571,7 +572,6 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } - private def doubleEquals(d1: Double, d2: Double): Boolean = { if (d1.isNaN && d2.isNaN) { true @@ -585,13 +585,23 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed) if (!r1.isNullAt(v._2)) { v._1.dataType match { + case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed) case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed) + case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed) case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed) case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed) + case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)), + "Seed = " + seed) case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)), "Seed = " + seed) + case t: DecimalType => + val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal + val d2 = r2.getDecimal(v._2) + assert(d1.compare(d2) == 0, "Seed = " + seed) case StringType => assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed) + case CalendarIntervalType => + assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval]) case ArrayType(childType, n) => val a1 = r1.getArray(v._2).array val a2 = r2.getList(v._2).toArray @@ -605,6 +615,27 @@ class ColumnarBatchSuite extends SparkFunSuite { i += 1 } } + case FloatType => { + var i = 0 + while (i < a1.length) { + assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), + "Seed = " + seed) + i += 1 + } + } + + case t: DecimalType => + var i = 0 + while (i < a1.length) { + assert((a1(i) == null) == (a2(i) == null), "Seed = " + seed) + if (a1(i) != null) { + val d1 = a1(i).asInstanceOf[Decimal].toBigDecimal + val d2 = a2(i).asInstanceOf[java.math.BigDecimal] + assert(d1.compare(d2) == 0, "Seed = " + seed) + } + i += 1 + } + case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => @@ -644,10 +675,13 @@ class ColumnarBatchSuite extends SparkFunSuite { * results. */ def testRandomRows(flatSchema: Boolean, numFields: Int) { - // TODO: add remaining types. Figure out why StringType doesn't work on jenkins. - val types = Array(ByteType, IntegerType, LongType, DoubleType) + // TODO: Figure out why StringType doesn't work on jenkins. + val types = Array( + BooleanType, ByteType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10), + CalendarIntervalType) val seed = System.nanoTime() - val NUM_ROWS = 500 + val NUM_ROWS = 200 val NUM_ITERS = 1000 val random = new Random(seed) var i = 0 @@ -682,7 +716,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } test("Random flat schema") { - testRandomRows(true, 10) + testRandomRows(true, 15) } test("Random nested schema") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index b29bf6a464b30864dc588e388f93561e2b8dfa62..18761bfd222a2292964aae5adc6ac2c396795a8e 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -27,10 +27,14 @@ public final class Platform { public static final int BYTE_ARRAY_OFFSET; + public static final int SHORT_ARRAY_OFFSET; + public static final int INT_ARRAY_OFFSET; public static final int LONG_ARRAY_OFFSET; + public static final int FLOAT_ARRAY_OFFSET; + public static final int DOUBLE_ARRAY_OFFSET; public static int getInt(Object object, long offset) { @@ -168,13 +172,17 @@ public final class Platform { if (_UNSAFE != null) { BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class); + SHORT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(short[].class); INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class); LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class); + FLOAT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(float[].class); DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class); } else { BYTE_ARRAY_OFFSET = 0; + SHORT_ARRAY_OFFSET = 0; INT_ARRAY_OFFSET = 0; LONG_ARRAY_OFFSET = 0; + FLOAT_ARRAY_OFFSET = 0; DOUBLE_ARRAY_OFFSET = 0; } }