From 75438422c2cd90dca53f84879cddecfc2ee0e957 Mon Sep 17 00:00:00 2001
From: Wenchen Fan <cloud0fan@outlook.com>
Date: Mon, 27 Jul 2015 11:28:22 -0700
Subject: [PATCH] [SPARK-9369][SQL] Support IntervalType in UnsafeRow

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7688 from cloud-fan/interval and squashes the following commits:

5b36b17 [Wenchen Fan] fix codegen
a99ed50 [Wenchen Fan] address comment
9e6d319 [Wenchen Fan] Support IntervalType in UnsafeRow
---
 .../sql/catalyst/expressions/UnsafeRow.java   | 23 ++++++++++++++-----
 .../expressions/UnsafeRowWriters.java         | 19 ++++++++++++++-
 .../spark/sql/catalyst/InternalRow.scala      |  4 +++-
 .../catalyst/expressions/BoundAttribute.scala |  1 +
 .../spark/sql/catalyst/expressions/Cast.scala |  2 +-
 .../expressions/codegen/CodeGenerator.scala   |  7 +++---
 .../codegen/GenerateUnsafeProjection.scala    |  6 +++++
 .../expressions/ExpressionEvalHelper.scala    |  2 --
 8 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 0fb33dd5a1..fb084dd13b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -29,6 +29,7 @@ import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.bitset.BitSetMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.types.Interval;
 import org.apache.spark.unsafe.types.UTF8String;
 
 import static org.apache.spark.sql.types.DataTypes.*;
@@ -90,7 +91,8 @@ public final class UnsafeRow extends MutableRow {
     final Set<DataType> _readableFieldTypes = new HashSet<>(
       Arrays.asList(new DataType[]{
         StringType,
-        BinaryType
+        BinaryType,
+        IntervalType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -332,11 +334,6 @@ public final class UnsafeRow extends MutableRow {
     return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal));
   }
 
-  @Override
-  public String getString(int ordinal) {
-    return getUTF8String(ordinal).toString();
-  }
-
   @Override
   public byte[] getBinary(int ordinal) {
     if (isNullAt(ordinal)) {
@@ -358,6 +355,20 @@ public final class UnsafeRow extends MutableRow {
     }
   }
 
+  @Override
+  public Interval getInterval(int ordinal) {
+    if (isNullAt(ordinal)) {
+      return null;
+    } else {
+      final long offsetAndSize = getLong(ordinal);
+      final int offset = (int) (offsetAndSize >> 32);
+      final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+      final long microseconds =
+        PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8);
+      return new Interval(months, microseconds);
+    }
+  }
+
   @Override
   public UnsafeRow getStruct(int ordinal, int numFields) {
     if (isNullAt(ordinal)) {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 87521d1f23..0ba31d3b9b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.types.ByteArray;
+import org.apache.spark.unsafe.types.Interval;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -54,7 +55,7 @@ public class UnsafeRowWriters {
     }
   }
 
-  /** Writer for bianry (byte array) type. */
+  /** Writer for binary (byte array) type. */
   public static class BinaryWriter {
 
     public static int getSize(byte[] input) {
@@ -80,4 +81,20 @@ public class UnsafeRowWriters {
     }
   }
 
+  /** Writer for interval type. */
+  public static class IntervalWriter {
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) {
+      final long offset = target.getBaseOffset() + cursor;
+
+      // Write the months and microseconds fields of Interval to the variable length portion.
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months);
+      PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, ((long) cursor) << 32);
+      return 16;
+    }
+  }
+
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index ad3977281d..9a11de3840 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{Interval, UTF8String}
 
 /**
  * An abstract class for row used internal in Spark SQL, which only contain the columns as
@@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable {
 
   def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT)
 
+  def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType)
+
   // This is only use for test and will throw a null pointer exception if the position is null.
   def getString(ordinal: Int): String = getUTF8String(ordinal).toString
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 6b5c450e3f..41a877f214 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
         case DoubleType => input.getDouble(ordinal)
         case StringType => input.getUTF8String(ordinal)
         case BinaryType => input.getBinary(ordinal)
+        case IntervalType => input.getInterval(ordinal)
         case t: StructType => input.getStruct(ordinal, t.size)
         case dataType => input.get(ordinal, dataType)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index e208262da9..bd8b0177eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType)
   private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
     case StringType =>
       (c, evPrim, evNull) =>
-        s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());"
+        s"$evPrim = Interval.fromString($c.toString());"
   }
 
   private[this] def decimalToTimestampCode(d: String): String =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 2a1e288cb8..2f02c90b1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -79,7 +79,6 @@ class CodeGenContext {
     mutableStates += ((javaType, variableName, initCode))
   }
 
-  final val intervalType: String = classOf[Interval].getName
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
@@ -109,6 +108,7 @@ class CodeGenContext {
       case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)"
       case StringType => s"$row.getUTF8String($ordinal)"
       case BinaryType => s"$row.getBinary($ordinal)"
+      case IntervalType => s"$row.getInterval($ordinal)"
       case t: StructType => s"$row.getStruct($ordinal, ${t.size})"
       case _ => s"($jt)$row.get($ordinal)"
     }
@@ -150,7 +150,7 @@ class CodeGenContext {
     case dt: DecimalType => "Decimal"
     case BinaryType => "byte[]"
     case StringType => "UTF8String"
-    case IntervalType => intervalType
+    case IntervalType => "Interval"
     case _: StructType => "InternalRow"
     case _: ArrayType => s"scala.collection.Seq"
     case _: MapType => s"scala.collection.Map"
@@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
       classOf[InternalRow].getName,
       classOf[UnsafeRow].getName,
       classOf[UTF8String].getName,
-      classOf[Decimal].getName
+      classOf[Decimal].getName,
+      classOf[Interval].getName
     ))
     evaluator.setExtendedClass(classOf[GeneratedClass])
     try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index afd0d9cfa1..9d2161947b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
   private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
   private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
+  private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
 
   /** Returns true iff we support this data type. */
   def canSupport(dataType: DataType): Boolean = dataType match {
     case t: AtomicType if !t.isInstanceOf[DecimalType] => true
+    case _: IntervalType => true
     case NullType => true
     case _ => false
   }
@@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))"
         case BinaryType =>
           s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))"
+        case IntervalType =>
+          s" + (${exprs(i).isNull} ? 0 : 16)"
         case _ => ""
       }
     }.mkString("")
@@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
         case BinaryType =>
           s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
+        case IntervalType =>
+          s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})"
         case NullType => ""
         case _ =>
           throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 8b0f90cf3a..ab0cdc857c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -78,8 +78,6 @@ trait ExpressionEvalHelper {
       generator
     } catch {
       case e: Throwable =>
-        val ctx = new CodeGenContext
-        val evaluated = expression.gen(ctx)
         fail(
           s"""
             |Code generation of $expression failed:
-- 
GitLab