Commit 5d1feda2 authored by Cheng Hao's avatar Cheng Hao Committed by Reynold Xin

[SPARK-1360] Add Timestamp Support for SQL

This PR includes:
1) Add new data type Timestamp
2) Add more data type casting based on Hive's rules
3) Fix bug of missing data types in both parsers (HiveQl & SQLParser).

Author: Cheng Hao <hao.cheng@intel.com>

Closes #275 from chenghao-intel/timestamp and squashes the following commits:

df709e5 [Cheng Hao] Move orc_ends_with_nulls to blacklist
24b04b0 [Cheng Hao] Put 3 cases into the black lists(describe_pretty,describe_syntax,lateral_view_outer)
fc512c2 [Cheng Hao] remove the unnecessary data type equality check in data casting
d0d1919 [Cheng Hao] Add more data type for scala reflection
3259808 [Cheng Hao] Add the new Golden files
3823b97 [Cheng Hao] Update the UnitTest cases & add timestamp type for HiveQL
54a0489 [Cheng Hao] fix bug mapping to 0 (which is supposed to be null) when NumberFormatException occurs
9cb505c [Cheng Hao] Fix issues according to PR comments
e529168 [Cheng Hao] Fix bug of converting from String
6fc8100 [Cheng Hao] Update Unit Test & CodeStyle
8a1d4d6 [Cheng Hao] Add DataType for SqlParser
ce4385e [Cheng Hao] Add TimestampType Support
parent fbebaedf
Showing with 344 additions and 100 deletions
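Before the per-file hunks, a minimal sketch (not part of the diff) of what the patch enables end to end, using only constructs this diff adds and the timestamp string format its tests use:

import java.sql.Timestamp
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.catalyst.types.{StringType, TimestampType}

// A string can now be cast to a timestamp and back; the cast parses via
// Timestamp.valueOf and evaluates to null on malformed input instead of failing.
val sts = "1970-01-01 00:00:01.0"
val toTs = Cast(Literal(sts), TimestampType)
val back = Cast(toTs, StringType)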
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -54,14 +56,15 @@ object ScalaReflection {
     val TypeRef(_, _, Seq(keyType, valueType)) = t
     MapType(schemaFor(keyType), schemaFor(valueType))
   case t if t <:< typeOf[String] => StringType
+  case t if t <:< typeOf[Timestamp] => TimestampType
+  case t if t <:< typeOf[BigDecimal] => DecimalType
   case t if t <:< definitions.IntTpe => IntegerType
   case t if t <:< definitions.LongTpe => LongType
-  case t if t <:< definitions.FloatTpe => FloatType
   case t if t <:< definitions.DoubleTpe => DoubleType
+  case t if t <:< definitions.FloatTpe => FloatType
   case t if t <:< definitions.ShortTpe => ShortType
   case t if t <:< definitions.ByteTpe => ByteType
   case t if t <:< definitions.BooleanTpe => BooleanType
-  case t if t <:< typeOf[BigDecimal] => DecimalType
 }
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
......
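With the two cases added above, product types holding java.sql.Timestamp or BigDecimal fields now reflect to Catalyst schemas. A hypothetical illustration (the Event class is not in the patch):

import java.sql.Timestamp

// schemaFor would now map createdAt to TimestampType and price to DecimalType,
// where previously neither field matched any case in the reflection match.
case class Event(id: Int, price: BigDecimal, createdAt: Timestamp)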
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst
import java.sql.Timestamp
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -72,6 +74,7 @@ package object dsl {
def like(other: Expression) = Like(expr, other)
def rlike(other: Expression) = RLike(expr, other)
def cast(to: DataType) = Cast(expr, to)
def asc = SortOrder(expr, Ascending)
def desc = SortOrder(expr, Descending)
@@ -84,15 +87,22 @@ package object dsl {
def expr = e
}
implicit def booleanToLiteral(b: Boolean) = Literal(b)
implicit def byteToLiteral(b: Byte) = Literal(b)
implicit def shortToLiteral(s: Short) = Literal(s)
implicit def intToLiteral(i: Int) = Literal(i)
implicit def longToLiteral(l: Long) = Literal(l)
implicit def floatToLiteral(f: Float) = Literal(f)
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
def expr: Expression = Literal(s)
def attr = analysis.UnresolvedAttribute(s)
@@ -103,11 +113,38 @@ package object dsl {
     def expr = attr
     def attr = analysis.UnresolvedAttribute(s)

-    /** Creates a new typed attributes of type int */
+    /** Creates a new AttributeReference of type boolean */
+    def boolean = AttributeReference(s, BooleanType, nullable = false)()
+
+    /** Creates a new AttributeReference of type byte */
+    def byte = AttributeReference(s, ByteType, nullable = false)()
+
+    /** Creates a new AttributeReference of type short */
+    def short = AttributeReference(s, ShortType, nullable = false)()
+
+    /** Creates a new AttributeReference of type int */
     def int = AttributeReference(s, IntegerType, nullable = false)()

-    /** Creates a new typed attributes of type string */
+    /** Creates a new AttributeReference of type long */
+    def long = AttributeReference(s, LongType, nullable = false)()
+
+    /** Creates a new AttributeReference of type float */
+    def float = AttributeReference(s, FloatType, nullable = false)()
+
+    /** Creates a new AttributeReference of type double */
+    def double = AttributeReference(s, DoubleType, nullable = false)()
+
+    /** Creates a new AttributeReference of type string */
     def string = AttributeReference(s, StringType, nullable = false)()
+
+    /** Creates a new AttributeReference of type decimal */
+    def decimal = AttributeReference(s, DecimalType, nullable = false)()
+
+    /** Creates a new AttributeReference of type timestamp */
+    def timestamp = AttributeReference(s, TimestampType, nullable = false)()
+
+    /** Creates a new AttributeReference of type binary */
+    def binary = AttributeReference(s, BinaryType, nullable = false)()
   }
implicit class DslAttribute(a: AttributeReference) {
......
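Taken together, the new literal conversions and attribute builders keep test plans terse. A small sketch (attribute name hypothetical), assuming the package object is imported as usual:

import java.sql.Timestamp
import org.apache.spark.sql.catalyst.dsl._

// 'eventTime.timestamp builds an AttributeReference of TimestampType, and
// timestampToLiteral lifts the raw Timestamp on the right into a Literal.
val pred = 'eventTime.timestamp < new Timestamp(86400000L)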
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.types._
/** Cast the child expression to the target data type. */
@@ -26,52 +28,169 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
   override def toString = s"CAST($child, $dataType)"

   type EvaluatedType = Any

+  def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) {
+    null
+  } else {
+    func(a.asInstanceOf[T])
+  }

-  lazy val castingFunction: Any => Any = (child.dataType, dataType) match {
-    case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]])
-    case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes
-    case (_, StringType) => a: Any => a.toString
-    case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt)
-    case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble)
-    case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat)
-    case (StringType, LongType) => a: Any => castOrNull(a, _.toLong)
-    case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort)
-    case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte)
-    case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_))
-    case (BooleanType, ByteType) => {
-      case null => null
-      case true => 1.toByte
-      case false => 0.toByte
-    }
-    case (dt, IntegerType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a)
-    case (dt, DoubleType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)
-    case (dt, FloatType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a)
-    case (dt, LongType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a)
-    case (dt, ShortType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort
-    case (dt, ByteType) =>
-      a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte
-    case (dt, DecimalType) =>
-      a: Any =>
-        BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a))

+  // UDFToString
+  def castToString: Any => Any = child.dataType match {
+    case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8"))
+    case _ => nullOrCast[Any](_, _.toString)
+  }
+
+  // BinaryConverter
+  def castToBinary: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, _.getBytes("UTF-8"))
+  }

-  @inline
-  protected def castOrNull[A](a: Any, f: String => A) =
-    try f(a.asInstanceOf[String]) catch {
-      case _: java.lang.NumberFormatException => null
-    }

+  // UDFToBoolean
+  def castToBoolean: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, _.length() != 0)
+    case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)})
+    case LongType => nullOrCast[Long](_, _ != 0)
+    case IntegerType => nullOrCast[Int](_, _ != 0)
+    case ShortType => nullOrCast[Short](_, _ != 0)
+    case ByteType => nullOrCast[Byte](_, _ != 0)
+    case DecimalType => nullOrCast[BigDecimal](_, _ != 0)
+    case DoubleType => nullOrCast[Double](_, _ != 0)
+    case FloatType => nullOrCast[Float](_, _ != 0)
+  }
+
+  // TimestampConverter
+  def castToTimestamp: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => {
+      // Throw away extra if more than 9 decimal places
+      val periodIdx = s.indexOf(".");
+      var n = s
+      if (periodIdx != -1) {
+        if (n.length() - periodIdx > 9) {
+          n = n.substring(0, periodIdx + 10)
+        }
+      }
+      try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null}
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000))
+    case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000))
+    case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000))
+    case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000))
+    case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000))
+    // TimestampWritable.decimalToTimestamp
+    case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d))
+    // TimestampWritable.doubleToTimestamp
+    case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d))
+    // TimestampWritable.floatToTimestamp
+    case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f))
+  }
+
+  private def decimalToTimestamp(d: BigDecimal) = {
+    val seconds = d.longValue()
+    val bd = (d - seconds) * (1000000000)
+    val nanos = bd.intValue()
+    // Convert to millis
+    val millis = seconds * 1000
+    val t = new Timestamp(millis)
+    // remaining fractional portion as nanos
+    t.setNanos(nanos)
+    t
+  }
+
+  private def timestampToDouble(t: Timestamp) = (t.getSeconds() + t.getNanos().toDouble / 1000)
+
+  def castToLong: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try s.toLong catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toLong)
+    case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
+  }
+
+  def castToInt: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try s.toInt catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toInt)
+    case DecimalType => nullOrCast[BigDecimal](_, _.toInt)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
+  }
+
+  def castToShort: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try s.toShort catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toShort)
+    case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
+  }
+
+  def castToByte: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try s.toByte catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toByte)
+    case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
+  }
+
+  def castToDecimal: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0))
+    case TimestampType => nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
+    case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+  }
+
+  def castToDouble: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try s.toDouble catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
+    case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
+  }
+
+  def castToFloat: Any => Any = child.dataType match {
+    case StringType => nullOrCast[String](_, s => try s.toFloat catch {
+      case _: NumberFormatException => null
+    })
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
+    case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
+    case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
+  }
+
+  def cast: Any => Any = dataType match {
+    case StringType => castToString
+    case BinaryType => castToBinary
+    case DecimalType => castToDecimal
+    case TimestampType => castToTimestamp
+    case BooleanType => castToBoolean
+    case ByteType => castToByte
+    case ShortType => castToShort
+    case IntegerType => castToInt
+    case FloatType => castToFloat
+    case LongType => castToLong
+    case DoubleType => castToDouble
+  }

   override def apply(input: Row): Any = {
     val evaluated = child.apply(input)
     if (evaluated == null) {
       null
     } else {
-      castingFunction(evaluated)
+      cast(evaluated)
     }
   }
 }
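To make the Hive-derived fractional-second rule concrete, a worked restatement of decimalToTimestamp (standalone, not part of the diff):

import java.sql.Timestamp

// For d = 1.5: seconds = 1, and the remaining 0.5 s becomes 500000000 nanos.
val d = BigDecimal("1.5")
val seconds = d.longValue()                          // 1
val nanos = ((d - seconds) * 1000000000).intValue()  // 500000000
val t = new Timestamp(seconds * 1000)                // 1970-01-01 00:00:01.0
t.setNanos(nanos)                                    // -> 1970-01-01 00:00:01.5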
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType}
+import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType}
abstract class Expression extends TreeNode[Expression] {
self: Product =>
@@ -86,6 +86,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
/**
 * Evaluation helper function for two Numeric child expressions. Both expressions are expected
 * to be of the same data type, which is also the return type.
 * If either expression evaluates to null, the result is null.
 */
@inline
protected final def n2(
i: Row,
@@ -115,6 +120,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
/**
 * Evaluation helper function for two Fractional child expressions. Both expressions are
 * expected to be of the same data type, which is also the return type.
 * If either expression evaluates to null, the result is null.
 */
@inline
protected final def f2(
i: Row,
@@ -143,6 +153,11 @@ abstract class Expression extends TreeNode[Expression] {
}
}
/**
 * Evaluation helper function for two Integral child expressions. Both expressions are
 * expected to be of the same data type, which is also the return type.
 * If either expression evaluates to null, the result is null.
 */
@inline
protected final def i2(
i: Row,
@@ -170,6 +185,43 @@ abstract class Expression extends TreeNode[Expression] {
}
}
}
/**
 * Evaluation helper function for two Comparable child expressions. Both expressions are
 * expected to be of the same data type; the comparison returns an Integer:
 * Negative value: 1st argument less than 2nd argument
 * Zero: 1st argument equals 2nd argument
 * Positive value: 1st argument greater than 2nd argument
 *
 * If either expression evaluates to null, the result is null.
 */
@inline
protected final def c2(
i: Row,
e1: Expression,
e2: Expression,
f: ((Ordering[Any], Any, Any) => Any)): Any = {
if (e1.dataType != e2.dataType) {
throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
}
val evalE1 = e1.apply(i)
if(evalE1 == null) {
null
} else {
val evalE2 = e2.apply(i)
if (evalE2 == null) {
null
} else {
e1.dataType match {
case i: NativeType =>
f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean](
i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
case other => sys.error(s"Type $other does not support ordered operations")
}
}
}
}
}
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
......
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.types._
object Literal {
@@ -29,6 +31,9 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
case d: BigDecimal => Literal(d, DecimalType)
case t: Timestamp => Literal(t, TimestampType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
}
}
......
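With the three new cases, Literal.apply now infers the Catalyst type for these JVM values directly; a quick sketch:

import java.sql.Timestamp
import org.apache.spark.sql.catalyst.expressions.Literal

Literal(new Timestamp(12345L))   // Literal(..., TimestampType)
Literal(BigDecimal("12.65"))     // Literal(..., DecimalType)
Literal(Array[Byte](1, 2, 3))    // Literal(..., BinaryType)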
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}
import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType}
object InterpretedPredicate {
def apply(expression: Expression): (Row => Boolean) = {
@@ -123,70 +124,22 @@ case class Equals(left: Expression, right: Expression) extends BinaryComparison
 case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = "<"
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] < r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.lt(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.lt(_, _))
 }

 case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = "<="
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] <= r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.lteq(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.lteq(_, _))
 }

 case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = ">"
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] > r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.gt(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.gt(_, _))
 }

 case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   def symbol = ">="
-  override def apply(input: Row): Any = {
-    if (left.dataType == StringType && right.dataType == StringType) {
-      val l = left.apply(input)
-      val r = right.apply(input)
-      if(l == null || r == null) {
-        null
-      } else {
-        l.asInstanceOf[String] >= r.asInstanceOf[String]
-      }
-    } else {
-      n2(input, left, right, _.gteq(_, _))
-    }
-  }
+  override def apply(input: Row): Any = c2(input, left, right, _.gteq(_, _))
 }
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
......
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.types
import java.sql.Timestamp
import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -51,6 +53,16 @@ case object BooleanType extends NativeType {
val ordering = implicitly[Ordering[JvmType]]
}
case object TimestampType extends NativeType {
type JvmType = Timestamp
@transient lazy val tag = typeTag[JvmType]
val ordering = new Ordering[JvmType] {
def compare(x: Timestamp, y: Timestamp) = x.compareTo(y)
}
}
abstract class NumericType extends NativeType {
// Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for
// implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a
......
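Because TimestampType is a NativeType with its own Ordering, the generic c2 helper added to Expression handles timestamps with no special-casing; a small sketch:

import java.sql.Timestamp
import org.apache.spark.sql.catalyst.types.TimestampType

val x = new Timestamp(12)
val y = new Timestamp(123)
// The ordering delegates to Timestamp.compareTo, which is nanosecond-aware.
assert(TimestampType.ordering.lt(x, y))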
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.Timestamp
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.types._
@@ -191,5 +193,56 @@ class ExpressionEvaluationSuite extends FunSuite {
evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**")))
}
}
test("data type casting") {
val sts = "1970-01-01 00:00:01.0"
val ts = Timestamp.valueOf(sts)
checkEvaluation("abdef" cast StringType, "abdef")
checkEvaluation("abdef" cast DecimalType, null)
checkEvaluation("abdef" cast TimestampType, null)
checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
checkEvaluation(Literal(1) cast LongType, 1)
checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1)
checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble)
checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts)
checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts)
checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef")
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5)
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0)
checkEvaluation("23" cast DoubleType, 23)
checkEvaluation("23" cast IntegerType, 23)
checkEvaluation("23" cast FloatType, 23)
checkEvaluation("23" cast DecimalType, 23)
checkEvaluation("23" cast ByteType, 23)
checkEvaluation("23" cast ShortType, 23)
checkEvaluation("2012-12-11" cast DoubleType, null)
checkEvaluation(Literal(123) cast IntegerType, 123)
intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
}
test("timestamp") {
val ts1 = new Timestamp(12)
val ts2 = new Timestamp(123)
checkEvaluation(Literal("ab") < Literal("abc"), true)
checkEvaluation(Literal(ts1) < Literal(ts2), true)
}
}
@@ -17,6 +17,8 @@
package org.apache.spark.sql
import java.sql.Timestamp
import org.scalatest.FunSuite
import org.apache.spark.sql.test.TestSQLContext._
@@ -31,6 +33,7 @@ case class ReflectData(
byteField: Byte,
booleanField: Boolean,
decimalField: BigDecimal,
timestampField: Timestamp,
seqInt: Seq[Int])
case class ReflectBinary(data: Array[Byte])
@@ -38,7 +41,7 @@ case class ReflectBinary(data: Array[Byte])
 class ScalaReflectionRelationSuite extends FunSuite {
   test("query case class RDD") {
     val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
-      BigDecimal(1), Seq(1,2,3))
+      BigDecimal(1), new Timestamp(12345), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerAsTable("reflectData")
......
@@ -300,14 +300,17 @@ object HiveQl {
   }

   protected def nodeToDataType(node: Node): DataType = node match {
-    case Token("TOK_BIGINT", Nil) => IntegerType
+    case Token("TOK_DECIMAL", Nil) => DecimalType
+    case Token("TOK_BIGINT", Nil) => LongType
     case Token("TOK_INT", Nil) => IntegerType
-    case Token("TOK_TINYINT", Nil) => IntegerType
-    case Token("TOK_SMALLINT", Nil) => IntegerType
+    case Token("TOK_TINYINT", Nil) => ByteType
+    case Token("TOK_SMALLINT", Nil) => ShortType
     case Token("TOK_BOOLEAN", Nil) => BooleanType
     case Token("TOK_STRING", Nil) => StringType
     case Token("TOK_FLOAT", Nil) => FloatType
-    case Token("TOK_DOUBLE", Nil) => FloatType
+    case Token("TOK_DOUBLE", Nil) => DoubleType
+    case Token("TOK_TIMESTAMP", Nil) => TimestampType
+    case Token("TOK_BINARY", Nil) => BinaryType
     case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
     case Token("TOK_STRUCT",
       Token("TOK_TABCOLLIST", fields) :: Nil) =>
@@ -829,6 +832,8 @@ object HiveQl {
       Cast(nodeToExpr(arg), BooleanType)
     case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
       Cast(nodeToExpr(arg), DecimalType)
+    case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), TimestampType)

     /* Arithmetic */
     case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
......
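In practice, the two HiveQl changes mean a Hive DDL type token and an explicit cast both resolve to the new type; a sketch with a hypothetical column name:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.types.TimestampType

// CAST(created_at AS TIMESTAMP) in HiveQL now parses to:
val e = Cast(UnresolvedAttribute("created_at"), TimestampType)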
db2_insert1
db2_insert2