From 75438422c2cd90dca53f84879cddecfc2ee0e957 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <cloud0fan@outlook.com> Date: Mon, 27 Jul 2015 11:28:22 -0700 Subject: [PATCH] [SPARK-9369][SQL] Support IntervalType in UnsafeRow Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7688 from cloud-fan/interval and squashes the following commits: 5b36b17 [Wenchen Fan] fix codegen a99ed50 [Wenchen Fan] address comment 9e6d319 [Wenchen Fan] Support IntervalType in UnsafeRow --- .../sql/catalyst/expressions/UnsafeRow.java | 23 ++++++++++++++----- .../expressions/UnsafeRowWriters.java | 19 ++++++++++++++- .../spark/sql/catalyst/InternalRow.scala | 4 +++- .../catalyst/expressions/BoundAttribute.scala | 1 + .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 7 +++--- .../codegen/GenerateUnsafeProjection.scala | 6 +++++ .../expressions/ExpressionEvalHelper.scala | 2 -- 8 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 0fb33dd5a1..fb084dd13b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -29,6 +29,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.types.Interval; import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; @@ -90,7 +91,8 @@ public final class UnsafeRow extends MutableRow { final Set<DataType> _readableFieldTypes = new HashSet<>( Arrays.asList(new DataType[]{ StringType, - BinaryType + BinaryType, + IntervalType })); _readableFieldTypes.addAll(settableFieldTypes); readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); @@ -332,11 +334,6 @@ public final class UnsafeRow extends MutableRow { return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal)); } - @Override - public String getString(int ordinal) { - return getUTF8String(ordinal).toString(); - } - @Override public byte[] getBinary(int ordinal) { if (isNullAt(ordinal)) { @@ -358,6 +355,20 @@ public final class UnsafeRow extends MutableRow { } } + @Override + public Interval getInterval(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long microseconds = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + return new Interval(months, microseconds); + } + } + @Override public UnsafeRow getStruct(int ordinal, int numFields) { if (isNullAt(ordinal)) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 87521d1f23..0ba31d3b9b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; +import org.apache.spark.unsafe.types.Interval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -54,7 +55,7 @@ public class UnsafeRowWriters { } } - /** Writer for bianry (byte array) type. */ + /** Writer for binary (byte array) type. */ public static class BinaryWriter { public static int getSize(byte[] input) { @@ -80,4 +81,20 @@ public class UnsafeRowWriters { } } + /** Writer for interval type. */ + public static class IntervalWriter { + + public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) { + final long offset = target.getBaseOffset() + cursor; + + // Write the months and microseconds fields of Interval to the variable length portion. + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + + // Set the fixed length portion. + target.setLong(ordinal, ((long) cursor) << 32); + return 16; + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index ad3977281d..9a11de3840 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{Interval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable { def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) + // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6b5c450e3f..41a877f214 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case DoubleType => input.getDouble(ordinal) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) + case IntervalType => input.getInterval(ordinal) case t: StructType => input.getStruct(ordinal, t.size) case dataType => input.get(ordinal, dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e208262da9..bd8b0177eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());" + s"$evPrim = Interval.fromString($c.toString());" } private[this] def decimalToTimestampCode(d: String): String = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2a1e288cb8..2f02c90b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -79,7 +79,6 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initCode)) } - final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -109,6 +108,7 @@ class CodeGenContext { case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" + case IntervalType => s"$row.getInterval($ordinal)" case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case _ => s"($jt)$row.get($ordinal)" } @@ -150,7 +150,7 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" - case IntervalType => intervalType + case IntervalType => "Interval" case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" @@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, - classOf[Decimal].getName + classOf[Decimal].getName, + classOf[Interval].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index afd0d9cfa1..9d2161947b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName + private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case t: AtomicType if !t.isInstanceOf[DecimalType] => true + case _: IntervalType => true case NullType => true case _ => false } @@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" case BinaryType => s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" + case IntervalType => + s" + (${exprs(i).isNull} ? 0 : 16)" case _ => "" } }.mkString("") @@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case BinaryType => s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" + case IntervalType => + s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 8b0f90cf3a..ab0cdc857c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -78,8 +78,6 @@ trait ExpressionEvalHelper { generator } catch { case e: Throwable => - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) fail( s""" |Code generation of $expression failed: -- GitLab