From f27e56aa612538188a8550fe72ee20b8b13304d7 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Mon, 7 Apr 2014 19:28:24 -0700
Subject: [PATCH] Change timestamp cast semantics. When cast to numeric types,
 return the unix time in seconds (instead of millis).

@marmbrus @chenghao-intel

Author: Reynold Xin <rxin@apache.org>

Closes #352 from rxin/timestamp-cast and squashes the following commits:

18aacd3 [Reynold Xin] Fixed precision for double.
2adb235 [Reynold Xin] Change timestamp cast semantics. When cast to numeric types, return the unix time in seconds (instead of millis).
---
 .../spark/sql/catalyst/dsl/package.scala      |  2 +-
 .../spark/sql/catalyst/expressions/Cast.scala | 23 ++++++++++------
 .../ExpressionEvaluationSuite.scala           | 27 ++++++++++++++++---
 3 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 2d62e4cbbc..987befe8e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -104,7 +104,7 @@ package object dsl {
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
     // TODO more implicit class for literal?
     implicit class DslString(val s: String) extends ImplicitOperators {
-      def expr: Expression = Literal(s)
+      override def expr: Expression = Literal(s)
       def attr = analysis.UnresolvedAttribute(s)
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 89226999ca..17118499d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -87,7 +87,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
 
   private def decimalToTimestamp(d: BigDecimal) = {
     val seconds = d.longValue()
-    val bd = (d - seconds) * (1000000000)
+    val bd = (d - seconds) * 1000000000
     val nanos = bd.intValue()
 
     // Convert to millis
@@ -96,18 +96,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
 
     // remaining fractional portion as nanos
     t.setNanos(nanos)
-    
     t
   }
 
-  private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000)
+  // Timestamp to long, converting milliseconds to seconds
+  private def timestampToLong(ts: Timestamp) = ts.getTime / 1000
+
+  private def timestampToDouble(ts: Timestamp) = {
+    // First part is the seconds since the beginning of time, followed by nanosecs.
+    ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000
+  }
 
   def castToLong: Any => Any = child.dataType match {
     case StringType => nullOrCast[String](_, s => try s.toLong catch {
       case _: NumberFormatException => null
     })
     case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
     case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
   }
@@ -117,7 +122,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       case _: NumberFormatException => null
     })
     case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt)
     case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
   }
@@ -127,7 +132,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       case _: NumberFormatException => null
     })
     case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
     case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
   }
@@ -137,7 +142,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       case _: NumberFormatException => null
     })
     case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
-    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
     case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
   }
@@ -147,7 +152,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
       case _: NumberFormatException => null
     })
     case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
-    case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
+    case TimestampType =>
+      // Note that we lose precision here.
+      nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
     case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 31be6c4ef1..888a19d79f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -201,7 +201,7 @@ class ExpressionEvaluationSuite extends FunSuite {
     
     val sts = "1970-01-01 00:00:01.0"
     val ts = Timestamp.valueOf(sts)
-    
+
     checkEvaluation("abdef" cast StringType, "abdef")
     checkEvaluation("abdef" cast DecimalType, null)
     checkEvaluation("abdef" cast TimestampType, null)
@@ -209,7 +209,6 @@ class ExpressionEvaluationSuite extends FunSuite {
 
     checkEvaluation(Literal(1) cast LongType, 1)
     checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1)
-    checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
     checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
 
     checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
@@ -240,12 +239,34 @@ class ExpressionEvaluationSuite extends FunSuite {
     
     intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
   }
-  
+
   test("timestamp") {
     val ts1 = new Timestamp(12)
     val ts2 = new Timestamp(123)
     checkEvaluation(Literal("ab") < Literal("abc"), true)
     checkEvaluation(Literal(ts1) < Literal(ts2), true)
   }
+
+  test("timestamp casting") {
+    val millis = 15 * 1000 + 2
+    val ts = new Timestamp(millis)
+    val ts1 = new Timestamp(15 * 1000)  // a timestamp without the milliseconds part
+    checkEvaluation(Cast(ts, ShortType), 15)
+    checkEvaluation(Cast(ts, IntegerType), 15)
+    checkEvaluation(Cast(ts, LongType), 15)
+    checkEvaluation(Cast(ts, FloatType), 15.002f)
+    checkEvaluation(Cast(ts, DoubleType), 15.002)
+    checkEvaluation(Cast(Cast(ts, ShortType), TimestampType), ts1)
+    checkEvaluation(Cast(Cast(ts, IntegerType), TimestampType), ts1)
+    checkEvaluation(Cast(Cast(ts, LongType), TimestampType), ts1)
+    checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType),
+      millis.toFloat / 1000)
+    checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType),
+      millis.toDouble / 1000)
+    checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+
+    // A test for higher precision than millis
+    checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
+  }
 }
 
-- 
GitLab