From ec8973d1245d4a99edeb7365d7f4b0063ac31ddf Mon Sep 17 00:00:00 2001 From: Davies Liu <davies@databricks.com> Date: Fri, 17 Jul 2015 01:27:14 -0700 Subject: [PATCH] [SPARK-9022] [SQL] Generated projections for UnsafeRow Added two projections: GenerateUnsafeProjection and FromUnsafeProjection, which can be used to convert UnsafeRow from/to GenericInternalRow. They re-use the buffer during projection, similar to MutableProjection (but without the full interface that MutableProjection has). cc rxin JoshRosen Author: Davies Liu <davies@databricks.com> Closes #7437 from davies/unsafe_proj2 and squashes the following commits: dbf538e [Davies Liu] test with all the expression (only for supported types) dc737b2 [Davies Liu] address comment e424520 [Davies Liu] fix scala style 70e231c [Davies Liu] address comments 729138d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_proj2 5a26373 [Davies Liu] unsafe projections --- .../execution/UnsafeExternalRowSorter.java | 27 ++-- .../spark/sql/catalyst/expressions/Cast.scala | 8 +- .../sql/catalyst/expressions/Projection.scala | 35 +++++ .../expressions/UnsafeRowConverter.scala | 69 +++++----- .../expressions/codegen/CodeGenerator.scala | 15 ++- .../codegen/GenerateProjection.scala | 4 +- .../codegen/GenerateUnsafeProjection.scala | 125 ++++++++++++++++++ .../expressions/decimalFunctions.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 2 +- .../spark/sql/catalyst/expressions/misc.scala | 17 ++- .../expressions/ExpressionEvalHelper.scala | 34 ++++- 11 files changed, 266 insertions(+), 72 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index b94601cf6d..d1d81c87bb 100644 --- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -28,13 +28,11 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter; import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; @@ -52,10 +50,9 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeRowConverter rowConverter; + private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - private byte[] rowConversionBuffer = new byte[1024 * 8]; public static abstract class PrefixComputer { abstract long computePrefix(InternalRow row); @@ -67,7 +64,7 @@ final class UnsafeExternalRowSorter { PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.rowConverter = new UnsafeRowConverter(schema); + this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -94,18 +91,12 @@ final 
class UnsafeExternalRowSorter { @VisibleForTesting void insertRow(InternalRow row) throws IOException { - final int sizeRequirement = rowConverter.getSizeRequirement(row); - if (sizeRequirement > rowConversionBuffer.length) { - rowConversionBuffer = new byte[sizeRequirement]; - } - final int bytesWritten = rowConverter.writeRow( - row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null); - assert (bytesWritten == sizeRequirement); + UnsafeRow unsafeRow = unsafeProjection.apply(row); final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - rowConversionBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeRequirement, + unsafeRow.getBaseObject(), + unsafeRow.getBaseOffset(), + unsafeRow.getSizeInBytes(), prefix ); numRowsInserted++; @@ -186,7 +177,7 @@ final class UnsafeExternalRowSorter { public static boolean supportsSchema(StructType schema) { // TODO: add spilling note to explain why we do this for now: for (StructField field : schema.fields()) { - if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) { + if (!UnsafeColumnWriter.canEmbed(field.dataType())) { return false; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 65ae87fe6d..692b9fddbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -424,20 +424,20 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => - s"${ctx.stringType}.fromBytes($c)") + s"UTF8String.fromBytes($c)") case (DateType, StringType) => defineCodeGen(ctx, ev, c => - s"""${ctx.stringType}.fromString( + s"""UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") case 
(TimestampType, StringType) => defineCodeGen(ctx, ev, c => - s"""${ctx.stringType}.fromString( + s"""UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") case (_, StringType) => - defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") + defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))") case (StringType, IntervalType) => defineCodeGen(ctx, ev, c => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index bf47a6c75b..24b01ea551 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} +import org.apache.spark.sql.types.{StructType, DataType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -73,6 +75,39 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu } } +/** + * A projection that returns UnsafeRow. 
+ */ +abstract class UnsafeProjection extends Projection { + override def apply(row: InternalRow): UnsafeRow +} + +object UnsafeProjection { + def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) + + def create(fields: Seq[DataType]): UnsafeProjection = { + val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + GenerateUnsafeProjection.generate(exprs) + } +} + +/** + * A projection that could turn UnsafeRow into GenericInternalRow + */ +case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => + new BoundReference(idx, dt, true) + } + + @transient private[this] lazy val generatedProj = + GenerateMutableProjection.generate(expressions)() + + override def apply(input: InternalRow): InternalRow = { + generatedProj(input) + } +} + /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 6af5e6200e..885ab091fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -147,77 +147,73 @@ private object UnsafeColumnWriter { case t => ObjectUnsafeColumnWriter } } + + /** + * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool). 
+ */ + def canEmbed(dataType: DataType): Boolean = { + forType(dataType) != ObjectUnsafeColumnWriter + } } // ------------------------------------------------------------------------------------------------ -private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter -private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter -private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter -private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter -private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter -private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter -private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter -private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter -private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter -private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter -private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: def getSize(sourceRow: InternalRow, column: Int): Int = 0 } -private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } -private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } -private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: 
InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } -private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } -private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } -private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } -private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } -private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { +private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 @@ -226,18 +222,21 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { - def getBytes(source: InternalRow, column: Int): Array[Byte] + protected[this] def isString: Boolean + protected[this] def getBytes(source: 
InternalRow, column: Int): Array[Byte] - def getSize(source: InternalRow, column: Int): Int = { + override def getSize(source: InternalRow, column: Int): Int = { val numBytes = getBytes(source, column).length ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - protected[this] def isString: Boolean - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val offset = target.getBaseOffset + cursor val bytes = getBytes(source, column) + write(target, bytes, column, cursor) + } + + def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor val numBytes = bytes.length if ((numBytes & 0x07) > 0) { // zero-out the padding bytes @@ -256,22 +255,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { } } -private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { +private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter { protected[this] def isString: Boolean = true def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[UTF8String](column).getBytes } + // TODO(davies): refactor this + // specialized for codegen + def getSize(value: UTF8String): Int = + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes()) + def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = { + write(target, value.getBytes, column, cursor) + } } -private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { - protected[this] def isString: Boolean = false - def getBytes(source: InternalRow, column: Int): Array[Byte] = { +private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter { + protected[this] override def isString: Boolean = false + override def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[Array[Byte]](column) } + // specialized for codegen + def getSize(value: Array[Byte]): Int = + 
ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length) } -private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { - def getSize(sourceRow: InternalRow, column: Int): Int = 0 +private object ObjectUnsafeColumnWriter extends UnsafeColumnWriter { + override def getSize(sourceRow: InternalRow, column: Int): Int = 0 override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { val obj = source.get(column) val idx = target.getPool.put(obj) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 328d635de8..45dc146488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -24,6 +24,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -68,9 +69,6 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initialValue)) } - val stringType: String = classOf[UTF8String].getName - val decimalType: String = classOf[Decimal].getName - final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -136,9 +134,9 @@ class CodeGenContext { case LongType | TimestampType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE - case dt: DecimalType => decimalType + case dt: DecimalType => "Decimal" case BinaryType => "byte[]" - case StringType => stringType + case StringType => "UTF8String" case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" 
case _: MapType => s"scala.collection.Map" @@ -262,7 +260,12 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) - evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) + evaluator.setDefaultImports(Array( + classOf[InternalRow].getName, + classOf[UnsafeRow].getName, + classOf[UTF8String].getName, + classOf[Decimal].getName + )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { evaluator.cook(code) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3e5ca308dc..8f9fcbf810 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.types._ /** * Java can not access Projection (in package object) */ -abstract class BaseProject extends Projection {} +abstract class BaseProjection extends Projection {} /** * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input @@ -160,7 +160,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${classOf[BaseProject].getName} { + class SpecificProjection extends ${classOf[BaseProjection].getName} { private $exprType[] expressions = null; $mutableStates diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala new file mode 
100644 index 0000000000..a81d545a8e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{NullType, BinaryType, StringType} + + +/** + * Generates a [[Projection]] that returns an [[UnsafeRow]]. + * + * It generates the code for all the expressions, computes the total length for all the columns + * (can be accessed via variables), and then copies the data into a scratch buffer space in the + * form of UnsafeRow (the scratch buffer will grow as needed). + * + * Note: The returned UnsafeRow points to a scratch buffer inside the projection. 
+ */ +object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { + + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): UnsafeProjection = { + val ctx = newCodeGenContext() + val exprs = expressions.map(_.gen(ctx)) + val allExprs = exprs.map(_.code).mkString("\n") + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter" + val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter" + val additionalSize = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case StringType => + s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))" + case BinaryType => + s" + (${exprs(i).isNull} ? 
0 : $binaryWriter.getSize(${exprs(i).primitive}))" + case _ => "" + } + }.mkString("") + + val writers = expressions.zipWithIndex.map { case (e, i) => + val update = e.dataType match { + case dt if ctx.isPrimitiveType(dt) => + s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}" + case StringType => + s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + case BinaryType => + s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") + } + s"""if (${exprs(i).isNull}) { + target.setNullAt($i); + } else { + $update; + }""" + }.mkString("\n ") + + val mutableStates = ctx.mutableStates.map { case (javaType, variableName, initialValue) => + s"private $javaType $variableName = $initialValue;" + }.mkString("\n ") + + val code = s""" + private $exprType[] expressions; + + public Object generate($exprType[] expr) { + this.expressions = expr; + return new SpecificProjection(); + } + + class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + + private UnsafeRow target = new UnsafeRow(); + private byte[] buffer = new byte[64]; + + $mutableStates + + public SpecificProjection() {} + + // Scala.Function1 need this + public Object apply(Object row) { + return apply((InternalRow) row); + } + + public UnsafeRow apply(InternalRow i) { + ${allExprs} + + // additionalSize had '+' in the beginning + int numBytes = $fixedSize $additionalSize; + if (numBytes > buffer.length) { + buffer = new byte[numBytes]; + } + target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, numBytes, null); + int cursor = $fixedSize; + $writers + return target; + } + } + """ + + logDebug(s"code for ${expressions.mkString(",")}:\n$code") + + val c = compile(code) + c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + } +} diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 2fa74b4ffc..b9d4736a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -54,7 +54,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { s""" - ${ev.primitive} = (new ${ctx.decimalType}()).setOrNull($eval, $precision, $scale); + ${ev.primitive} = (new Decimal()).setOrNull($eval, $precision, $scale); ${ev.isNull} = ${ev.primitive} == null; """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a7ad452ef4..84b289c4d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -263,7 +263,7 @@ case class Bin(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, (c) => - s"${ctx.stringType}.fromString(java.lang.Long.toBinaryString($c))") + s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a269ec4a1e..8d8d66ddeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.catalyst.expressions 
-import java.security.MessageDigest -import java.security.NoSuchAlgorithmException +import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult + import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +41,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } } @@ -93,19 +92,19 @@ case class Sha2(left: Expression, right: Expression) try { java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); md.update($eval1); - ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + ${ev.primitive} = UTF8String.fromBytes(md.digest()); } catch (java.security.NoSuchAlgorithmException e) { ${ev.isNull} = true; } } else if ($eval2 == 256 || $eval2 == 0) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha256Hex($eval1)); + UTF8String.fromString($digestUtils.sha256Hex($eval1)); } else if ($eval2 == 384) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha384Hex($eval1)); + UTF8String.fromString($digestUtils.sha384Hex($eval1)); } else if ($eval2 == 512) { ${ev.primitive} = - ${ctx.stringType}.fromString($digestUtils.sha512Hex($eval1)); + UTF8String.fromString($digestUtils.sha512Hex($eval1)); } else { ${ev.isNull} = true; } @@ -129,7 +128,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - 
s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" ) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 43392df4be..c43486b3dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -23,7 +23,7 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} @@ -43,6 +43,9 @@ trait ExpressionEvalHelper { checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) + if (UnsafeColumnWriter.canEmbed(expression.dataType)) { + checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) + } checkEvaluationWithOptimization(expression, catalystValue, inputRow) } @@ -142,6 +145,35 @@ trait ExpressionEvalHelper { } } + protected def checkEvalutionWithUnsafeProjection( + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + val ctx = GenerateUnsafeProjection.newCodeGenContext() + lazy val evaluated = expression.gen(ctx) + + val plan = try { 
+ GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) + } catch { + case e: Throwable => + fail( + s""" + |Code generation of $expression failed: + |${evaluated.code} + |$e + """.stripMargin) + } + + val unsafeRow = plan(inputRow) + // UnsafeRow cannot be compared with GenericInternalRow directly + val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) + val expectedRow = InternalRow(expected) + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } + protected def checkEvaluationWithOptimization( expression: Expression, expected: Any, -- GitLab