Skip to content
Snippets Groups Projects
Commit 75438422 authored by Wenchen Fan's avatar Wenchen Fan Committed by Reynold Xin
Browse files

[SPARK-9369][SQL] Support IntervalType in UnsafeRow

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7688 from cloud-fan/interval and squashes the following commits:

5b36b17 [Wenchen Fan] fix codegen
a99ed50 [Wenchen Fan] address comment
9e6d319 [Wenchen Fan] Support IntervalType in UnsafeRow
parent dd9ae794
No related branches found
No related tags found
No related merge requests found
Showing with 50 additions and 14 deletions
...@@ -29,6 +29,7 @@ import org.apache.spark.unsafe.PlatformDependent; ...@@ -29,6 +29,7 @@ import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.Interval;
import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.UTF8String;
import static org.apache.spark.sql.types.DataTypes.*; import static org.apache.spark.sql.types.DataTypes.*;
...@@ -90,7 +91,8 @@ public final class UnsafeRow extends MutableRow { ...@@ -90,7 +91,8 @@ public final class UnsafeRow extends MutableRow {
final Set<DataType> _readableFieldTypes = new HashSet<>( final Set<DataType> _readableFieldTypes = new HashSet<>(
Arrays.asList(new DataType[]{ Arrays.asList(new DataType[]{
StringType, StringType,
BinaryType BinaryType,
IntervalType
})); }));
_readableFieldTypes.addAll(settableFieldTypes); _readableFieldTypes.addAll(settableFieldTypes);
readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
...@@ -332,11 +334,6 @@ public final class UnsafeRow extends MutableRow { ...@@ -332,11 +334,6 @@ public final class UnsafeRow extends MutableRow {
return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal)); return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal));
} }
@Override
public String getString(int ordinal) {
return getUTF8String(ordinal).toString();
}
@Override @Override
public byte[] getBinary(int ordinal) { public byte[] getBinary(int ordinal) {
if (isNullAt(ordinal)) { if (isNullAt(ordinal)) {
...@@ -358,6 +355,20 @@ public final class UnsafeRow extends MutableRow { ...@@ -358,6 +355,20 @@ public final class UnsafeRow extends MutableRow {
} }
} }
@Override
public Interval getInterval(int ordinal) {
  // Null bit set for this field: no interval value to decode.
  if (isNullAt(ordinal)) {
    return null;
  }
  // The fixed-length slot packs the variable-length region offset in the
  // upper 32 bits (the size portion is unused for intervals, which are a
  // fixed 16 bytes).
  final int relativeOffset = (int) (getLong(ordinal) >> 32);
  final long fieldAddress = baseOffset + relativeOffset;
  // Layout written by IntervalWriter: months stored in the first 8-byte
  // word (narrowed back to int here), microseconds in the second.
  final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, fieldAddress);
  final long micros = PlatformDependent.UNSAFE.getLong(baseObject, fieldAddress + 8);
  return new Interval(months, micros);
}
@Override @Override
public UnsafeRow getStruct(int ordinal, int numFields) { public UnsafeRow getStruct(int ordinal, int numFields) {
if (isNullAt(ordinal)) { if (isNullAt(ordinal)) {
......
...@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions; ...@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.ByteArray; import org.apache.spark.unsafe.types.ByteArray;
import org.apache.spark.unsafe.types.Interval;
import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.types.UTF8String;
/** /**
...@@ -54,7 +55,7 @@ public class UnsafeRowWriters { ...@@ -54,7 +55,7 @@ public class UnsafeRowWriters {
} }
} }
/** Writer for bianry (byte array) type. */ /** Writer for binary (byte array) type. */
public static class BinaryWriter { public static class BinaryWriter {
public static int getSize(byte[] input) { public static int getSize(byte[] input) {
...@@ -80,4 +81,20 @@ public class UnsafeRowWriters { ...@@ -80,4 +81,20 @@ public class UnsafeRowWriters {
} }
} }
/** Writer for interval type. */
public static class IntervalWriter {

  /**
   * Appends {@code input} to the row's variable-length region at {@code cursor}
   * and records its offset in the fixed-length slot for {@code ordinal}.
   * Returns the number of bytes consumed (always 16).
   */
  public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) {
    final Object base = target.getBaseObject();
    final long writeAddress = target.getBaseOffset() + cursor;
    // Store both interval fields as 8-byte words; months is widened to a
    // long so the reader can decode it with a single getLong.
    PlatformDependent.UNSAFE.putLong(base, writeAddress, input.months);
    PlatformDependent.UNSAFE.putLong(base, writeAddress + 8, input.microseconds);
    // Fixed-length slot: offset in the upper 32 bits, size portion left zero.
    target.setLong(ordinal, ((long) cursor) << 32);
    return 16;
  }
}
} }
...@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst ...@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Row import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.{Interval, UTF8String}
/** /**
* An abstract class for row used internal in Spark SQL, which only contain the columns as * An abstract class for row used internal in Spark SQL, which only contain the columns as
...@@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable { ...@@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable {
def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType)
// This is only use for test and will throw a null pointer exception if the position is null. // This is only use for test and will throw a null pointer exception if the position is null.
def getString(ordinal: Int): String = getUTF8String(ordinal).toString def getString(ordinal: Int): String = getUTF8String(ordinal).toString
......
...@@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ...@@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
case DoubleType => input.getDouble(ordinal) case DoubleType => input.getDouble(ordinal)
case StringType => input.getUTF8String(ordinal) case StringType => input.getUTF8String(ordinal)
case BinaryType => input.getBinary(ordinal) case BinaryType => input.getBinary(ordinal)
case IntervalType => input.getInterval(ordinal)
case t: StructType => input.getStruct(ordinal, t.size) case t: StructType => input.getStruct(ordinal, t.size)
case dataType => input.get(ordinal, dataType) case dataType => input.get(ordinal, dataType)
} }
......
...@@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType) ...@@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType)
private[this] def castToIntervalCode(from: DataType): CastFunction = from match { private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
case StringType => case StringType =>
(c, evPrim, evNull) => (c, evPrim, evNull) =>
s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());" s"$evPrim = Interval.fromString($c.toString());"
} }
private[this] def decimalToTimestampCode(d: String): String = private[this] def decimalToTimestampCode(d: String): String =
......
...@@ -79,7 +79,6 @@ class CodeGenContext { ...@@ -79,7 +79,6 @@ class CodeGenContext {
mutableStates += ((javaType, variableName, initCode)) mutableStates += ((javaType, variableName, initCode))
} }
final val intervalType: String = classOf[Interval].getName
final val JAVA_BOOLEAN = "boolean" final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte" final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short" final val JAVA_SHORT = "short"
...@@ -109,6 +108,7 @@ class CodeGenContext { ...@@ -109,6 +108,7 @@ class CodeGenContext {
case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
case StringType => s"$row.getUTF8String($ordinal)" case StringType => s"$row.getUTF8String($ordinal)"
case BinaryType => s"$row.getBinary($ordinal)" case BinaryType => s"$row.getBinary($ordinal)"
case IntervalType => s"$row.getInterval($ordinal)"
case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
case _ => s"($jt)$row.get($ordinal)" case _ => s"($jt)$row.get($ordinal)"
} }
...@@ -150,7 +150,7 @@ class CodeGenContext { ...@@ -150,7 +150,7 @@ class CodeGenContext {
case dt: DecimalType => "Decimal" case dt: DecimalType => "Decimal"
case BinaryType => "byte[]" case BinaryType => "byte[]"
case StringType => "UTF8String" case StringType => "UTF8String"
case IntervalType => intervalType case IntervalType => "Interval"
case _: StructType => "InternalRow" case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq" case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map" case _: MapType => s"scala.collection.Map"
...@@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin ...@@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[InternalRow].getName, classOf[InternalRow].getName,
classOf[UnsafeRow].getName, classOf[UnsafeRow].getName,
classOf[UTF8String].getName, classOf[UTF8String].getName,
classOf[Decimal].getName classOf[Decimal].getName,
classOf[Interval].getName
)) ))
evaluator.setExtendedClass(classOf[GeneratedClass]) evaluator.setExtendedClass(classOf[GeneratedClass])
try { try {
......
...@@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ...@@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
/** Returns true iff we support this data type. */ /** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match { def canSupport(dataType: DataType): Boolean = dataType match {
case t: AtomicType if !t.isInstanceOf[DecimalType] => true case t: AtomicType if !t.isInstanceOf[DecimalType] => true
case _: IntervalType => true
case NullType => true case NullType => true
case _ => false case _ => false
} }
...@@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ...@@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
case BinaryType => case BinaryType =>
s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
case IntervalType =>
s" + (${exprs(i).isNull} ? 0 : 16)"
case _ => "" case _ => ""
} }
}.mkString("") }.mkString("")
...@@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ...@@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case BinaryType => case BinaryType =>
s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case IntervalType =>
s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
case NullType => "" case NullType => ""
case _ => case _ =>
throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
......
...@@ -78,8 +78,6 @@ trait ExpressionEvalHelper { ...@@ -78,8 +78,6 @@ trait ExpressionEvalHelper {
generator generator
} catch { } catch {
case e: Throwable => case e: Throwable =>
val ctx = new CodeGenContext
val evaluated = expression.gen(ctx)
fail( fail(
s""" s"""
|Code generation of $expression failed: |Code generation of $expression failed:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment