From c4da5345a0ef643a7518756caaa18ff3f3ea9acc Mon Sep 17 00:00:00 2001
From: Davies Liu <davies@databricks.com>
Date: Mon, 12 Oct 2015 21:12:59 -0700
Subject: [PATCH] [SPARK-10990] [SPARK-11018] [SQL] improve unrolling of complex types

This PR improves the unrolling and reading of complex types in the columnar cache:

1) Use UnsafeProjection to serialize complex types, so they are not serialized three times (twice of which were for actualSize).
2) Copy the bytes from UnsafeRow/UnsafeArrayData into the ByteBuffer directly, avoiding the intermediate byte[].
3) Use the underlying array of the ByteBuffer to create UTF8String/UnsafeRow/UnsafeArrayData without copying.

Combining these optimizations reduces the unrolling time from 25s to 21s (20% less) and the scanning time from 3.5s to 2.5s (28% less).

```
df = sqlContext.read.parquet(path)
t = time.time()
df.cache()
df.count()
print 'unrolling', time.time() - t

for i in range(10):
    t = time.time()
    print df.select("*")._jdf.queryExecution().toRdd().count()
    print time.time() - t
```

The schema is
```
root
 |-- a: struct (nullable = true)
 |    |-- b: long (nullable = true)
 |    |-- c: string (nullable = true)
 |-- d: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- e: map (nullable = true)
 |    |-- key: long
 |    |-- value: string (valueContainsNull = true)
```

The columnar cache now depends on UnsafeProjection supporting all data types (including UDTs), so this PR also fixes that.

Author: Davies Liu <davies@databricks.com>

Closes #9016 from davies/complex2.
---
 .../catalyst/expressions/UnsafeArrayData.java |  12 ++
 .../sql/catalyst/expressions/UnsafeRow.java   |  12 ++
 .../expressions/codegen/CodeGenerator.scala   |   5 +
 .../codegen/GenerateSafeProjection.scala      |   1 +
 .../codegen/GenerateUnsafeProjection.scala    |  29 ++-
 .../spark/sql/columnar/ColumnAccessor.scala   |   9 +-
 .../spark/sql/columnar/ColumnType.scala       | 187 +++++++++---------
 .../columnar/InMemoryColumnarTableScan.scala  |   6 +-
 .../spark/sql/columnar/ColumnTypeSuite.scala  |  37 ++--
 .../NullableColumnAccessorSuite.scala         |   7 +-
 .../columnar/NullableColumnBuilderSuite.scala |  13 +-
 .../apache/spark/unsafe/types/UTF8String.java |  10 +
 12 files changed, 188 insertions(+), 140 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index fdd9125613..796f8abec9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions;
 
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.nio.ByteBuffer;
 
 import org.apache.spark.sql.types.*;
 import org.apache.spark.unsafe.Platform;
@@ -145,6 +146,8 @@ public class UnsafeArrayData extends ArrayData {
       return getArray(ordinal);
     } else if (dataType instanceof MapType) {
       return getMap(ordinal);
+    } else if (dataType instanceof UserDefinedType) {
+      return get(ordinal, ((UserDefinedType)dataType).sqlType());
     } else {
       throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
     }
@@ -306,6 +309,15 @@ public class UnsafeArrayData extends ArrayData {
     Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
   }
 
+  public void writeTo(ByteBuffer buffer) {
+    assert(buffer.hasArray());
+    byte[] target = buffer.array();
+    int offset = buffer.arrayOffset();
+    int pos = buffer.position();
+    writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+    buffer.position(pos + sizeInBytes);
+  }
+
   @Override
   public UnsafeArrayData copy() {
     UnsafeArrayData arrayCopy = new UnsafeArrayData();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 5af7ed5d6e..36859fbab9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
 import java.io.*;
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
@@ -326,6 +327,8 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
       return getArray(ordinal);
     } else if (dataType instanceof MapType) {
       return getMap(ordinal);
+    } else if (dataType instanceof UserDefinedType) {
+      return get(ordinal, ((UserDefinedType)dataType).sqlType());
     } else {
       throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
     }
@@ -602,6 +605,15 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
     Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
   }
 
+  public void writeTo(ByteBuffer buffer) {
+    assert (buffer.hasArray());
+    byte[] target = buffer.array();
+    int offset = buffer.arrayOffset();
+    int pos = buffer.position();
+    writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+    buffer.position(pos + sizeInBytes);
+  }
+
   @Override
   public void writeExternal(ObjectOutput out) throws IOException {
     byte[] bytes = getBytes();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a0fe5bd77e..7544d27e3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -129,6 +129,7 @@ class CodeGenContext {
       case _: ArrayType => s"$input.getArray($ordinal)"
       case _: MapType => s"$input.getMap($ordinal)"
       case NullType => "null"
+      case udt: UserDefinedType[_] => getValue(input, udt.sqlType, ordinal)
       case _ => s"($jt)$input.get($ordinal, null)"
     }
   }
@@ -143,6 +144,7 @@ class CodeGenContext {
       case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
       // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
       case StringType => s"$row.update($ordinal, $value.clone())"
+      case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
       case _ => s"$row.update($ordinal, $value)"
     }
   }
@@ -177,6 +179,7 @@ class CodeGenContext {
     case _: MapType => "MapData"
     case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName
     case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName
+    case udt: UserDefinedType[_] => javaType(udt.sqlType)
     case ObjectType(cls) if cls.isArray => s"${javaType(ObjectType(cls.getComponentType))}[]"
     case ObjectType(cls) => cls.getName
     case _ => "Object"
@@ -222,6 +225,7 @@ class CodeGenContext {
     case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
     case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
+    case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
     case other => s"$c1.equals($c2)"
   }
 
@@ -255,6 +259,7 @@ class CodeGenContext {
       addNewFunction(compareFunc, funcCode)
       s"this.$compareFunc($c1, $c2)"
     case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
+    case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
     case _ =>
       throw new IllegalArgumentException("cannot generate compare code for un-comparable type")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 9873630937..ee50587ed0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -124,6 +124,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
     // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
     case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
+    case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
     case _ => GeneratedExpressionCode("", "false", input)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 3e0e81733f..1b957a508d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -39,6 +39,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
     case t: ArrayType if canSupport(t.elementType) => true
     case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
+    case dt: OpenHashSetUDT => false  // it's not a standard UDT
+    case udt: UserDefinedType[_] => canSupport(udt.sqlType)
     case _ => false
   }
 
@@ -77,7 +79,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
 
     val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
-      case ((input, dt), index) =>
+      case ((input, dataType), index) =>
+        val dt = dataType match {
+          case udt: UserDefinedType[_] => udt.sqlType
+          case other => other
+        }
         val tmpCursor = ctx.freshName("tmpCursor")
 
         val setNull = dt match {
@@ -167,15 +173,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     val index = ctx.freshName("index")
     val element = ctx.freshName("element")
-    val jt = ctx.javaType(elementType)
+    val et = elementType match {
+      case udt: UserDefinedType[_] => udt.sqlType
+      case other => other
+    }
+
+    val jt = ctx.javaType(et)
 
-    val fixedElementSize = elementType match {
+    val fixedElementSize = et match {
       case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
-      case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+      case _ if ctx.isPrimitiveType(jt) => et.defaultSize
       case _ => 0
     }
 
-    val writeElement = elementType match {
+    val writeElement = et match {
       case t: StructType =>
         s"""
           $arrayWriter.setOffset($index);
@@ -194,13 +205,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
         """
 
-      case _ if ctx.isPrimitiveType(elementType) =>
+      case _ if ctx.isPrimitiveType(et) =>
         // Should we do word align?
-        val dataSize = elementType.defaultSize
+        val dataSize = et.defaultSize
 
         s"""
           $arrayWriter.setOffset($index);
-          ${writePrimitiveType(ctx, element, elementType,
+          ${writePrimitiveType(ctx, element, et,
             s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
           $bufferHolder.cursor += $dataSize;
         """
@@ -237,7 +248,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
         if ($input.isNullAt($index)) {
           $arrayWriter.setNullAt($index);
         } else {
-          final $jt $element = ${ctx.getValue(input, elementType, index)};
+          final $jt $element = ${ctx.getValue(input, et, index)};
           $writeElement
         }
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 62478667eb..42ec4d3433 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.columnar
 
 import java.nio.{ByteBuffer, ByteOrder}
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.MutableRow
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow}
 import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
 import org.apache.spark.sql.types._
 
@@ -109,15 +108,15 @@ private[sql] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalTy
   with NullableColumnAccessor
 
 private[sql] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
-  extends BasicColumnAccessor[InternalRow](buffer, STRUCT(dataType))
+  extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType))
   with NullableColumnAccessor
 
 private[sql] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
-  extends BasicColumnAccessor[ArrayData](buffer, ARRAY(dataType))
+  extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType))
   with NullableColumnAccessor
 
 private[sql] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
-  extends BasicColumnAccessor[MapData](buffer, MAP(dataType))
+  extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType))
   with NullableColumnAccessor
 
 private[sql] object ColumnAccessor {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 3563eacb3a..2bc2c96b61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.columnar
 
 import java.math.{BigDecimal, BigInteger}
-import java.nio.{ByteOrder, ByteBuffer}
+import java.nio.ByteBuffer
 
 import scala.reflect.runtime.universe.TypeTag
 
@@ -92,7 +92,7 @@ private[sql] sealed abstract class ColumnType[JvmType] {
    * boxing/unboxing costs whenever possible.
    */
   def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
-    to.update(toOrdinal, from.get(fromOrdinal, dataType))
+    setField(to, toOrdinal, getField(from, fromOrdinal))
   }
 
   /**
@@ -147,6 +147,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) {
 
   override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)
 
+
   override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
     to.setInt(toOrdinal, from.getInt(fromOrdinal))
   }
@@ -324,15 +325,18 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
   }
 
   override def append(v: UTF8String, buffer: ByteBuffer): Unit = {
-    val stringBytes = v.getBytes
-    buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
+    buffer.putInt(v.numBytes())
+    v.writeTo(buffer)
   }
 
   override def extract(buffer: ByteBuffer): UTF8String = {
     val length = buffer.getInt()
-    val stringBytes = new Array[Byte](length)
-    buffer.get(stringBytes, 0, length)
-    UTF8String.fromBytes(stringBytes)
+    assert(buffer.hasArray)
+    val base = buffer.array()
+    val offset = buffer.arrayOffset()
+    val cursor = buffer.position()
+    buffer.position(cursor + length)
+    UTF8String.fromBytes(base, offset + cursor, length)
   }
 
   override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = {
@@ -386,11 +390,6 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize:
   def serialize(value: JvmType): Array[Byte]
   def deserialize(bytes: Array[Byte]): JvmType
 
-  override def actualSize(row: InternalRow, ordinal: Int): Int = {
-    // TODO: grow the buffer in append(), so serialize() will not be called twice
-    serialize(getField(row, ordinal)).length + 4
-  }
-
   override def append(v: JvmType, buffer: ByteBuffer): Unit = {
     val bytes = serialize(v)
     buffer.putInt(bytes.length).put(bytes, 0, bytes.length)
@@ -416,6 +415,10 @@ private[sql] object BINARY extends ByteArrayColumnType[Array[Byte]](16) {
     row.getBinary(ordinal)
   }
 
+  override def actualSize(row: InternalRow, ordinal: Int): Int = {
+    row.getBinary(ordinal).length + 4
+  }
+
   def serialize(value: Array[Byte]): Array[Byte] = value
   def deserialize(bytes: Array[Byte]): Array[Byte] = bytes
 }
@@ -433,6 +436,10 @@ private[sql] case class LARGE_DECIMAL(precision: Int, scale: Int)
     row.setDecimal(ordinal, value, precision)
   }
 
+  override def actualSize(row: InternalRow, ordinal: Int): Int = {
+    4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1
+  }
+
   override def serialize(value: Decimal): Array[Byte] = {
     value.toJavaBigDecimal.unscaledValue().toByteArray
   }
@@ -449,124 +456,118 @@ private[sql] object LARGE_DECIMAL {
   }
 }
 
-private[sql] case class STRUCT(dataType: StructType)
-  extends ByteArrayColumnType[InternalRow](20) {
+private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRow] {
 
-  private val projection: UnsafeProjection =
-    UnsafeProjection.create(dataType)
   private val numOfFields: Int = dataType.fields.size
 
-  override def setField(row: MutableRow, ordinal: Int, value: InternalRow): Unit = {
+  override def defaultSize: Int = 20
+
+  override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = {
     row.update(ordinal, value)
   }
 
-  override def getField(row: InternalRow, ordinal: Int): InternalRow = {
-    row.getStruct(ordinal, numOfFields)
+  override def getField(row: InternalRow, ordinal: Int): UnsafeRow = {
+    row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow]
   }
 
-  override def serialize(value: InternalRow): Array[Byte] = {
-    val unsafeRow = if (value.isInstanceOf[UnsafeRow]) {
-      value.asInstanceOf[UnsafeRow]
-    } else {
-      projection(value)
-    }
-    unsafeRow.getBytes
+  override def actualSize(row: InternalRow, ordinal: Int): Int = {
+    4 + getField(row, ordinal).getSizeInBytes
   }
 
-  override def deserialize(bytes: Array[Byte]): InternalRow = {
+  override def append(value: UnsafeRow, buffer: ByteBuffer): Unit = {
+    buffer.putInt(value.getSizeInBytes)
+    value.writeTo(buffer)
+  }
+
+  override def extract(buffer: ByteBuffer): UnsafeRow = {
+    val sizeInBytes = buffer.getInt()
+    assert(buffer.hasArray)
+    val base = buffer.array()
+    val offset = buffer.arrayOffset()
+    val cursor = buffer.position()
+    buffer.position(cursor + sizeInBytes)
     val unsafeRow = new UnsafeRow
-    unsafeRow.pointTo(bytes, numOfFields, bytes.length)
+    unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
     unsafeRow
   }
 
-  override def clone(v: InternalRow): InternalRow = v.copy()
+  override def clone(v: UnsafeRow): UnsafeRow = v.copy()
 }
 
-private[sql] case class ARRAY(dataType: ArrayType)
-  extends ByteArrayColumnType[ArrayData](16) {
+private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArrayData] {
 
-  private lazy val projection = UnsafeProjection.create(Array[DataType](dataType))
-  private val mutableRow = new GenericMutableRow(new Array[Any](1))
+  override def defaultSize: Int = 16
 
-  override def setField(row: MutableRow, ordinal: Int, value: ArrayData): Unit = {
+  override def setField(row: MutableRow, ordinal: Int, value: UnsafeArrayData): Unit = {
     row.update(ordinal, value)
   }
 
-  override def getField(row: InternalRow, ordinal: Int): ArrayData = {
-    row.getArray(ordinal)
+  override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = {
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
  }
 
-  override def serialize(value: ArrayData): Array[Byte] = {
-    val unsafeArray = if (value.isInstanceOf[UnsafeArrayData]) {
-      value.asInstanceOf[UnsafeArrayData]
-    } else {
-      mutableRow(0) = value
-      projection(mutableRow).getArray(0)
-    }
-    val outputBuffer =
-      ByteBuffer.allocate(4 + unsafeArray.getSizeInBytes).order(ByteOrder.nativeOrder())
-    outputBuffer.putInt(unsafeArray.numElements())
-    val underlying = outputBuffer.array()
-    unsafeArray.writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 4)
-    underlying
+  override def actualSize(row: InternalRow, ordinal: Int): Int = {
+    val unsafeArray = getField(row, ordinal)
+    4 + 4 + unsafeArray.getSizeInBytes
   }
 
-  override def deserialize(bytes: Array[Byte]): ArrayData = {
-    val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
-    val numElements = buffer.getInt
-    val array = new UnsafeArrayData
-    array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 4, numElements, bytes.length - 4)
-    array
+  override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
+    buffer.putInt(4 + value.getSizeInBytes)
+    buffer.putInt(value.numElements())
+    value.writeTo(buffer)
   }
 
-  override def clone(v: ArrayData): ArrayData = v.copy()
+  override def extract(buffer: ByteBuffer): UnsafeArrayData = {
+    val numBytes = buffer.getInt
+    assert(buffer.hasArray)
+    val cursor = buffer.position()
+    buffer.position(cursor + numBytes)
+    UnsafeReaders.readArray(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numBytes)
+  }
+
+  override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
 }
 
-private[sql] case class MAP(dataType: MapType) extends ByteArrayColumnType[MapData](32) {
+private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData] {
 
-  private lazy val projection: UnsafeProjection = UnsafeProjection.create(Array[DataType](dataType))
-  private val mutableRow = new GenericMutableRow(new Array[Any](1))
+  override def defaultSize: Int = 32
 
-  override def setField(row: MutableRow, ordinal: Int, value: MapData): Unit = {
+  override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = {
     row.update(ordinal, value)
   }
 
-  override def getField(row: InternalRow, ordinal: Int): MapData = {
-    row.getMap(ordinal)
+  override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = {
+    row.getMap(ordinal).asInstanceOf[UnsafeMapData]
   }
 
-  override def serialize(value: MapData): Array[Byte] = {
-    val unsafeMap = if (value.isInstanceOf[UnsafeMapData]) {
-      value.asInstanceOf[UnsafeMapData]
-    } else {
-      mutableRow(0) = value
-      projection(mutableRow).getMap(0)
-    }
+  override def actualSize(row: InternalRow, ordinal: Int): Int = {
+    val unsafeMap = getField(row, ordinal)
+    12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+  }
+
+  override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
+    buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
+    buffer.putInt(value.numElements())
+    buffer.putInt(value.keyArray().getSizeInBytes)
+    value.keyArray().writeTo(buffer)
+    value.valueArray().writeTo(buffer)
+  }
+
+  override def extract(buffer: ByteBuffer): UnsafeMapData = {
+    val numBytes = buffer.getInt
+    assert(buffer.hasArray)
+    val cursor = buffer.position()
+    buffer.position(cursor + numBytes)
+    UnsafeReaders.readMap(
+      buffer.array(),
+      Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+      numBytes)
+  }
 
-    val outputBuffer =
-      ByteBuffer.allocate(8 + unsafeMap.getSizeInBytes).order(ByteOrder.nativeOrder())
-    outputBuffer.putInt(unsafeMap.numElements())
-    val keyBytes = unsafeMap.keyArray().getSizeInBytes
-    outputBuffer.putInt(keyBytes)
-    val underlying = outputBuffer.array()
-    unsafeMap.keyArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8)
-    unsafeMap.valueArray().writeToMemory(underlying, Platform.BYTE_ARRAY_OFFSET + 8 + keyBytes)
-    underlying
-  }
-
-  override def deserialize(bytes: Array[Byte]): MapData = {
-    val buffer = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
-    val numElements = buffer.getInt
-    val keyArraySize = buffer.getInt
-    val keyArray = new UnsafeArrayData
-    val valueArray = new UnsafeArrayData
-    keyArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8, numElements, keyArraySize)
-    valueArray.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + 8 + keyArraySize, numElements,
-      bytes.length - 8 - keyArraySize)
-    new UnsafeMapData(keyArray, valueArray)
-  }
-
-  override def clone(v: MapData): MapData = v.copy()
+  override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
 }
 
 private[sql] object ColumnType {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index d7e145f9c2..d967814f62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
+import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.{Accumulable, Accumulator, Accumulators}
 
@@ -38,7 +38,9 @@ private[sql] object InMemoryRelation {
       storageLevel: StorageLevel,
       child: SparkPlan,
       tableName: Option[String]): InMemoryRelation =
-    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
+    new InMemoryRelation(child.output, useCompression, batchSize, storageLevel,
+      if (child.outputsUnsafeRows) child else ConvertToUnsafe(child),
+      tableName)()
 }
 
 private[sql] case class CachedBatch(buffers: Array[Array[Byte]], stats: InternalRow)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index ceb8ad97bb..0e6e1bcf72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.columnar
 
-import java.nio.ByteBuffer
+import java.nio.{ByteOrder, ByteBuffer}
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
 import org.apache.spark.sql.columnar.ColumnarTestUtils._
 import org.apache.spark.sql.types._
 import org.apache.spark.{Logging, SparkFunSuite}
@@ -55,7 +55,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
       assertResult(expected, s"Wrong actualSize for $columnType") {
         val row = new GenericMutableRow(1)
         row.update(0, CatalystTypeConverters.convertToCatalyst(value))
-        columnType.actualSize(row, 0)
+        val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
+        columnType.actualSize(proj(row), 0)
       }
     }
 
@@ -99,35 +100,27 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
 
   def testColumnType[JvmType](columnType: ColumnType[JvmType]): Unit = {
 
-    val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
-    val seq = (0 until 4).map(_ => makeRandomValue(columnType))
+    val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE).order(ByteOrder.nativeOrder())
+    val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
     val converter = CatalystTypeConverters.createToScalaConverter(columnType.dataType)
+    val seq = (0 until 4).map(_ => proj(makeRandomRow(columnType)).copy())
 
     test(s"$columnType append/extract") {
       buffer.rewind()
-      seq.foreach(columnType.append(_, buffer))
+      seq.foreach(columnType.append(_, 0, buffer))
 
       buffer.rewind()
-      seq.foreach { expected =>
-        logInfo("buffer = " + buffer + ", expected = " + expected)
-        val extracted = columnType.extract(buffer)
-        assert(
-          converter(expected) === converter(extracted),
-          "Extracted value didn't equal to the original one. " +
-            hexDump(expected) + " != " + hexDump(extracted) +
-            ", buffer = " + dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
+      seq.foreach { row =>
+        logInfo("buffer = " + buffer + ", expected = " + row)
+        val expected = converter(row.get(0, columnType.dataType))
+        val extracted = converter(columnType.extract(buffer))
+        assert(expected === extracted,
+          s"Extracted value didn't equal to the original one. $expected != $extracted, buffer =" +
+            dumpBuffer(buffer.duplicate().rewind().asInstanceOf[ByteBuffer]))
       }
     }
   }
 
-  private def hexDump(value: Any): String = {
-    if (value == null) {
-      ""
-    } else {
-      value.toString.map(ch => Integer.toHexString(ch & 0xffff)).mkString(" ")
-    }
-  }
-
   private def dumpBuffer(buff: ByteBuffer): Any = {
     val sb = new StringBuilder()
     while (buff.hasRemaining) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 78cebbf3cc..aa1605fee8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
 import org.apache.spark.sql.types._
 
 class TestNullableColumnAccessor[JvmType](
@@ -64,10 +64,11 @@ class NullableColumnAccessorSuite extends SparkFunSuite {
     test(s"Nullable $typeName column accessor: access null values") {
       val builder = TestNullableColumnBuilder(columnType)
       val randomRow = makeRandomRow(columnType)
+      val proj = UnsafeProjection.create(Array[DataType](columnType.dataType))
 
       (0 until 4).foreach { _ =>
-        builder.appendFrom(randomRow, 0)
-        builder.appendFrom(nullRow, 0)
+        builder.appendFrom(proj(randomRow), 0)
+        builder.appendFrom(proj(nullRow), 0)
       }
 
       val accessor = TestNullableColumnAccessor(builder.build(), columnType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index fba08e626d..9140457783 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow}
 import org.apache.spark.sql.types._
 
 class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType])
@@ -51,6 +51,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
     columnType: ColumnType[JvmType]): Unit = {
 
     val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
+    val dataType = columnType.dataType
+    val proj = UnsafeProjection.create(Array[DataType](dataType))
+    val converter = CatalystTypeConverters.createToScalaConverter(dataType)
 
     test(s"$typeName column builder: empty column") {
       val columnBuilder = TestNullableColumnBuilder(columnType)
@@ -65,7 +68,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
       val randomRow = makeRandomRow(columnType)
 
       (0 until 4).foreach { _ =>
-        columnBuilder.appendFrom(randomRow, 0)
+        columnBuilder.appendFrom(proj(randomRow), 0)
       }
 
       val buffer = columnBuilder.build()
@@ -77,12 +80,10 @@ class NullableColumnBuilderSuite extends SparkFunSuite {
       val columnBuilder = TestNullableColumnBuilder(columnType)
       val randomRow = makeRandomRow(columnType)
       val nullRow = makeNullRow(1)
-      val dataType = columnType.dataType
-      val converter = CatalystTypeConverters.createToScalaConverter(dataType)
 
       (0 until 4).foreach { _ =>
-        columnBuilder.appendFrom(randomRow, 0)
-        columnBuilder.appendFrom(nullRow, 0)
+        columnBuilder.appendFrom(proj(randomRow), 0)
+        columnBuilder.appendFrom(proj(nullRow), 0)
       }
 
       val buffer = columnBuilder.build()
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 216aeea60d..b7aecb5102 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -19,6 +19,7 @@ package org.apache.spark.unsafe.types;
 
 import javax.annotation.Nonnull;
 import java.io.*;
+import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.Arrays;
 import java.util.Map;
@@ -137,6 +138,15 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable
     Platform.copyMemory(base, offset, target, targetOffset, numBytes);
   }
 
+  public void writeTo(ByteBuffer buffer) {
+    assert(buffer.hasArray());
+    byte[] target = buffer.array();
+    int offset = buffer.arrayOffset();
+    int pos = buffer.position();
+    writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+    buffer.position(pos + numBytes);
+  }
+
   /**
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point
-- 
GitLab