diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index b805cfe88be63eb51f0f2bc2d650e2c6691b940a..0b6fa5646970403e6dcf2988da098f52b5daf221 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import java.util.TimeZone + import org.apache.spark.sql.catalyst.analysis._ /** @@ -36,6 +38,8 @@ trait CatalystConf { def warehousePath: String + def sessionLocalTimeZone: String + /** If true, cartesian products between relations will be allowed for all * join types(inner, (left|right|full) outer). * If false, cartesian products will require explicit CROSS JOIN syntax. @@ -68,5 +72,6 @@ case class SimpleCatalystConf( runSQLonFile: Boolean = true, crossJoinEnabled: Boolean = false, cboEnabled: Boolean = false, - warehousePath: String = "/user/hive/warehouse") + warehousePath: String = "/user/hive/warehouse", + sessionLocalTimeZone: String = TimeZone.getDefault().getID) extends CatalystConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cb56e94c0a77a82693db0ad97538fcc7a8cf5e62..8ec330455900b6a2047e0dd801c5e29f03e12f4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -155,6 +155,8 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), + Batch("ResolveTimeZone", Once, + ResolveTimeZone), Batch("Cleanup", fixedPoint, CleanupAliases) ) @@ -223,7 +225,7 @@ class Analyzer( case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) - case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() + case c @ Cast(ne: NamedExpression, _, _) => Alias(c, ne.name)() case e: ExtractValue => Alias(e, toPrettySQL(e))() case e if optGenAliasFunc.isDefined => Alias(child, optGenAliasFunc.get.apply(e))() @@ -2312,6 +2314,18 @@ class Analyzer( } } } + + /** + * Replace [[TimeZoneAwareExpression]] without [[TimeZone]] by its copy with session local + * time zone. 
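+ *
+ * For example (a sketch; the actual id is whatever the session conf holds):
+ * {{{
+ *   // before this rule: Cast(Literal("2016-01-01"), TimestampType, None)
+ *   // after this rule:  Cast(Literal("2016-01-01"), TimestampType, Some(conf.sessionLocalTimeZone))
+ * }}}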
+ */ + object ResolveTimeZone extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index b8dc5f95906b5e0a44c35355092acd6f383b5398..a8fa78d41cb573ab33a13f6c70839d27e25d908b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.StructType - /** * A function defined in the catalog. * @@ -114,7 +114,9 @@ case class CatalogTablePartition( */ def toRow(partitionSchema: StructType): InternalRow = { InternalRow.fromSeq(partitionSchema.map { field => - Cast(Literal(spec(field.name)), field.dataType).eval() + // TODO: use correct timezone for partition values. + Cast(Literal(spec(field.name)), field.dataType, + Option(DateTimeUtils.defaultTimeZone().getID)).eval() }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ad59271e5b02c0c753ef1697611bea0c125e3dd4..a36d3507d92ec7d8ea3e9909f94977077b534813 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -131,7 +131,12 @@ object Cast { private def resolvableNullability(from: Boolean, to: Boolean) = !from || to } -/** Cast the child expression to the target data type. */ +/** + * Cast the child expression to the target data type. + * + * When cast from/to timezone related types, we need timeZoneId, which will be resolved with + * session local timezone by an analyzer [[ResolveTimeZone]]. + */ @ExpressionDescription( usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.", extended = """ @@ -139,7 +144,10 @@ object Cast { > SELECT _FUNC_('10' as int); 10 """) -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { +case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with NullIntolerant { + + def this(child: Expression, dataType: DataType) = this(child, dataType, None) override def toString: String = s"cast($child as ${dataType.simpleString})" @@ -154,6 +162,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + // [[func]] assumes the input is no longer null because eval already does the null check. 
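  // A minimal sketch of the resolved three-argument Cast in action; the zone id
  // below is an arbitrary example, not a default:
  //
  //   val c = Cast(Literal("2015-03-18 12:03:17"), TimestampType)  // timeZoneId = None
  //   val resolved = c.withTimeZone("Asia/Tokyo")                  // what ResolveTimeZone does
  //   resolved.eval()  // parses the string as Asia/Tokyo wall-clock time, yielding epoch microseconds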
@inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) @@ -162,7 +173,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, - t => UTF8String.fromString(DateTimeUtils.timestampToString(t))) + t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -207,7 +218,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // TimestampConverter private[this] def castToTimestamp(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs).orNull) + buildCast[UTF8String](_, utfs => DateTimeUtils.stringToTimestamp(utfs, timeZone).orNull) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => @@ -219,7 +230,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 1000) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d, timeZone) * 1000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -254,7 +265,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case TimestampType => // throw valid precision more than seconds, according to Hive. // Timestamp.nanos is in 0 to 999,999,999, no more than a second. - buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L)) + buildCast[Long](_, t => DateTimeUtils.millisToDays(t / 1000L, timeZone)) } // IntervalConverter @@ -531,8 +542,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" case TimestampType => + val tz = ctx.addReferenceMinorObj(timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));""" + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } @@ -558,8 +570,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } """ case TimestampType => + val tz = ctx.addReferenceMinorObj(timeZone) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);"; + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L, $tz);" case _ => (c, evPrim, evNull) => s"$evNull = true;" } @@ -637,11 +650,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val tz = ctx.addReferenceMinorObj(timeZone) val longOpt = ctx.freshName("longOpt") (c, evPrim, evNull) => s""" scala.Option<Long> $longOpt = - org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c); + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $tz); if ($longOpt.isDefined()) { $evPrim = ((Long) $longOpt.get()).longValue(); } else { @@ 
-653,8 +667,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case _: IntegralType => (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" case DateType => + val tz = ctx.addReferenceMinorObj(timeZone) (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;" + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c, $tz) * 1000;" case DecimalType() => (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" case DoubleType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index ef1ac360daadabd512ccc940a359416ef069bfe8..bad8a7123017bec2ca7104925aaf5e4b9f74ac7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp -import java.text.SimpleDateFormat -import java.util.{Calendar, Locale, TimeZone} +import java.text.DateFormat +import java.util.{Calendar, TimeZone} -import scala.util.Try +import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} @@ -29,6 +29,20 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +/** + * Common base class for time zone aware expressions. + */ +trait TimeZoneAwareExpression extends Expression { + + /** the timezone ID to be used to evaluate value. */ + def timeZoneId: Option[String] + + /** Returns a copy of this expression with the specified timeZoneId. */ + def withTimeZone(timeZoneId: String): TimeZoneAwareExpression + + @transient lazy val timeZone: TimeZone = TimeZone.getTimeZone(timeZoneId.get) +} + /** * Returns the current date at the start of query evaluation. * All calls of current_date within the same query return the same value. @@ -37,14 +51,21 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} */ @ExpressionDescription( usage = "_FUNC_() - Returns the current date at the start of query evaluation.") -case class CurrentDate() extends LeafExpression with CodegenFallback { +case class CurrentDate(timeZoneId: Option[String] = None) + extends LeafExpression with TimeZoneAwareExpression with CodegenFallback { + + def this() = this(None) + override def foldable: Boolean = true override def nullable: Boolean = false override def dataType: DataType = DateType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override def eval(input: InternalRow): Any = { - DateTimeUtils.millisToDays(System.currentTimeMillis()) + DateTimeUtils.millisToDays(System.currentTimeMillis(), timeZone) } override def prettyName: String = "current_date" @@ -78,11 +99,19 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { * * There is no code generation since this expression should be replaced with a literal. 
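 *
 * A sketch of the replacement (values hypothetical):
 * {{{
 *   CurrentBatchTimestamp(1000L, TimestampType).toLiteral
 *   // == Literal(1000000L, TimestampType), i.e. microseconds since the epoch
 * }}}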
*/ -case class CurrentBatchTimestamp(timestampMs: Long, dataType: DataType) - extends LeafExpression with Nondeterministic with CodegenFallback { +case class CurrentBatchTimestamp( + timestampMs: Long, + dataType: DataType, + timeZoneId: Option[String] = None) + extends LeafExpression with TimeZoneAwareExpression with Nondeterministic with CodegenFallback { + + def this(timestampMs: Long, dataType: DataType) = this(timestampMs, dataType, None) override def nullable: Boolean = false + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override def prettyName: String = "current_batch_timestamp" override protected def initializeInternal(partitionIndex: Int): Unit = {} @@ -96,7 +125,7 @@ case class CurrentBatchTimestamp(timestampMs: Long, dataType: DataType) def toLiteral: Literal = dataType match { case _: TimestampType => Literal(DateTimeUtils.fromJavaTimestamp(new Timestamp(timestampMs)), TimestampType) - case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs), DateType) + case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs, timeZone), DateType) } } @@ -172,19 +201,26 @@ case class DateSub(startDate: Expression, days: Expression) > SELECT _FUNC_('2009-07-30 12:58:59'); 12 """) -case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Hour(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(child: Expression) = this(child, None) override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = IntegerType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override protected def nullSafeEval(timestamp: Any): Any = { - DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) + DateTimeUtils.getHours(timestamp.asInstanceOf[Long], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") + defineCodeGen(ctx, ev, c => s"$dtu.getHours($c, $tz)") } } @@ -195,19 +231,26 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu > SELECT _FUNC_('2009-07-30 12:58:59'); 58 """) -case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Minute(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(child: Expression) = this(child, None) override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = IntegerType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override protected def nullSafeEval(timestamp: Any): Any = { - DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) + DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") + defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c, $tz)") } } @@ -218,19 +261,26 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn 
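// The extractors above are zone-sensitive: one instant yields different
// wall-clock fields in different zones, which is why each now threads a
// TimeZone through. A sketch using the new overloads (zones arbitrary):
//
//   import java.util.TimeZone
//   import org.apache.spark.sql.catalyst.util.DateTimeUtils
//
//   val us = 0L  // 1970-01-01 00:00:00 UTC, as microseconds
//   DateTimeUtils.getHours(us, TimeZone.getTimeZone("GMT"))        // 0
//   DateTimeUtils.getHours(us, TimeZone.getTimeZone("GMT+09:00"))  // 9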
> SELECT _FUNC_('2009-07-30 12:58:59'); 59 """) -case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class Second(child: Expression, timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(child: Expression) = this(child, None) override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = IntegerType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override protected def nullSafeEval(timestamp: Any): Any = { - DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) + DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long], timeZone) } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") + defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c, $tz)") } } @@ -401,22 +451,28 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa 2016 """) // scalastyle:on line.size.limit -case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression - with ImplicitCastInputTypes { +case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Option[String] = None) + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(left: Expression, right: Expression) = this(left, right, None) override def dataType: DataType = StringType override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override protected def nullSafeEval(timestamp: Any, format: Any): Any = { - val sdf = new SimpleDateFormat(format.toString, Locale.US) - UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) + val df = DateTimeUtils.newDateFormat(format.toString, timeZone) + UTF8String.fromString(df.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val sdf = classOf[SimpleDateFormat].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val tz = ctx.addReferenceMinorObj(timeZone) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString((new $sdf($format.toString())) + s"""UTF8String.fromString($dtu.newDateFormat($format.toString(), $tz) .format(new java.util.Date($timestamp / 1000)))""" }) } @@ -435,10 +491,20 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx > SELECT _FUNC_('2016-04-08', 'yyyy-MM-dd'); 1460041200 """) -case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { +case class ToUnixTimestamp( + timeExp: Expression, + format: Expression, + timeZoneId: Option[String] = None) + extends UnixTime { + + def this(timeExp: Expression, format: Expression) = this(timeExp, format, None) + override def left: Expression = timeExp override def right: Expression = format + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } @@ -465,10 +531,17 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix > SELECT 
_FUNC_('2016-04-08', 'yyyy-MM-dd'); 1460041200 """) -case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { +case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Option[String] = None) + extends UnixTime { + + def this(timeExp: Expression, format: Expression) = this(timeExp, format, None) + override def left: Expression = timeExp override def right: Expression = format + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } @@ -480,7 +553,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi override def prettyName: String = "unix_timestamp" } -abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { +abstract class UnixTime + extends BinaryExpression with TimeZoneAwareExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, DateType, TimestampType), StringType) @@ -489,8 +563,12 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) + private lazy val formatter: DateFormat = + try { + DateTimeUtils.newDateFormat(constFormat.toString, timeZone) + } catch { + case NonFatal(_) => null + } override def eval(input: InternalRow): Any = { val t = left.eval(input) @@ -499,15 +577,19 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } else { left.dataType match { case DateType => - DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L + DateTimeUtils.daysToMillis(t.asInstanceOf[Int], timeZone) / 1000L case TimestampType => t.asInstanceOf[Long] / 1000000L case StringType if right.foldable => if (constFormat == null || formatter == null) { null } else { - Try(formatter.parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + try { + formatter.parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L + } catch { + case NonFatal(_) => null + } } case StringType => val f = right.eval(input) @@ -515,8 +597,12 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { null } else { val formatString = f.asInstanceOf[UTF8String].toString - Try(new SimpleDateFormat(formatString, Locale.US).parse( - t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null) + try { + DateTimeUtils.newDateFormat(formatString, timeZone).parse( + t.asInstanceOf[UTF8String].toString).getTime / 1000L + } catch { + case NonFatal(_) => null + } } } } @@ -525,11 +611,11 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { left.dataType match { case StringType if right.foldable => - val sdf = classOf[SimpleDateFormat].getName + val df = classOf[DateFormat].getName if (formatter == null) { ExprCode("", "true", ctx.defaultValue(dataType)) } else { - val formatterName = ctx.addReferenceObj("formatter", formatter, sdf) + val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) ev.copy(code = s""" ${eval1.code} @@ -544,12 +630,13 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { }""") } case StringType => - val sdf = 
classOf[SimpleDateFormat].getName + val tz = ctx.addReferenceMinorObj(timeZone) + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = - (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L; + ${ev.value} = $dtu.newDateFormat($format.toString(), $tz) + .parse($string.toString()).getTime() / 1000L; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } catch (java.text.ParseException e) { @@ -567,6 +654,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { ${ev.value} = ${eval1.value} / 1000000L; }""") case DateType => + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) ev.copy(code = s""" @@ -574,7 +662,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $dtu.daysToMillis(${eval1.value}) / 1000L; + ${ev.value} = $dtu.daysToMillis(${eval1.value}, $tz) / 1000L; }""") } } @@ -593,8 +681,10 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { > SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss'); 1970-01-01 00:00:00 """) -case class FromUnixTime(sec: Expression, format: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[String] = None) + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(sec: Expression, format: Expression) = this(sec, format, None) override def left: Expression = sec override def right: Expression = format @@ -610,9 +700,16 @@ case class FromUnixTime(sec: Expression, format: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] - private lazy val formatter: SimpleDateFormat = - Try(new SimpleDateFormat(constFormat.toString, Locale.US)).getOrElse(null) + private lazy val formatter: DateFormat = + try { + DateTimeUtils.newDateFormat(constFormat.toString, timeZone) + } catch { + case NonFatal(_) => null + } override def eval(input: InternalRow): Any = { val time = left.eval(input) @@ -623,30 +720,36 @@ case class FromUnixTime(sec: Expression, format: Expression) if (constFormat == null || formatter == null) { null } else { - Try(UTF8String.fromString(formatter.format( - new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null) + try { + UTF8String.fromString(formatter.format( + new java.util.Date(time.asInstanceOf[Long] * 1000L))) + } catch { + case NonFatal(_) => null + } } } else { val f = format.eval(input) if (f == null) { null } else { - Try( - UTF8String.fromString(new SimpleDateFormat(f.toString, Locale.US). 
- format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) - ).getOrElse(null) + try { + UTF8String.fromString(DateTimeUtils.newDateFormat(f.toString, timeZone) + .format(new java.util.Date(time.asInstanceOf[Long] * 1000L))) + } catch { + case NonFatal(_) => null + } } } } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val sdf = classOf[SimpleDateFormat].getName + val df = classOf[DateFormat].getName if (format.foldable) { if (formatter == null) { ExprCode("", "true", "(UTF8String) null") } else { - val formatterName = ctx.addReferenceObj("formatter", formatter, sdf) + val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) ev.copy(code = s""" ${t.code} @@ -662,14 +765,16 @@ case class FromUnixTime(sec: Expression, format: Expression) }""") } } else { + val tz = ctx.addReferenceMinorObj(timeZone) + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.value} = UTF8String.fromString((new $sdf($f.toString())).format( + ${ev.value} = UTF8String.fromString($dtu.newDateFormat($f.toString(), $tz).format( new java.util.Date($seconds * 1000L))); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; - }""".stripMargin + }""" }) } } @@ -776,8 +881,10 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) /** * Adds an interval to timestamp. */ -case class TimeAdd(start: Expression, interval: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class TimeAdd(start: Expression, interval: Expression, timeZoneId: Option[String] = None) + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(start: Expression, interval: Expression) = this(start, interval, None) override def left: Expression = start override def right: Expression = interval @@ -788,16 +895,20 @@ case class TimeAdd(start: Expression, interval: Expression) override def dataType: DataType = TimestampType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override def nullSafeEval(start: Any, interval: Any): Any = { val itvl = interval.asInstanceOf[CalendarInterval] DateTimeUtils.timestampAddInterval( - start.asInstanceOf[Long], itvl.months, itvl.microseconds) + start.asInstanceOf[Long], itvl.months, itvl.microseconds, timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { - s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" + s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds, $tz)""" }) } } @@ -863,8 +974,10 @@ case class FromUTCTimestamp(left: Expression, right: Expression) /** * Subtracts an interval from timestamp. 
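 *
 * As with [[TimeAdd]], the month part is applied to the local date, so the zone
 * can shift the result across a DST boundary. A sketch of the evaluation:
 * {{{
 *   DateTimeUtils.timestampAddInterval(start, -months, -microseconds, timeZone)
 * }}}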
*/ -case class TimeSub(start: Expression, interval: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class TimeSub(start: Expression, interval: Expression, timeZoneId: Option[String] = None) + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(start: Expression, interval: Expression) = this(start, interval, None) override def left: Expression = start override def right: Expression = interval @@ -875,16 +988,20 @@ case class TimeSub(start: Expression, interval: Expression) override def dataType: DataType = TimestampType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override def nullSafeEval(start: Any, interval: Any): Any = { val itvl = interval.asInstanceOf[CalendarInterval] DateTimeUtils.timestampAddInterval( - start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) + start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds, timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { - s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" + s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds, $tz)""" }) } } @@ -937,8 +1054,10 @@ case class AddMonths(startDate: Expression, numMonths: Expression) 3.94959677 """) // scalastyle:on line.size.limit -case class MonthsBetween(date1: Expression, date2: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class MonthsBetween(date1: Expression, date2: Expression, timeZoneId: Option[String] = None) + extends BinaryExpression with TimeZoneAwareExpression with ImplicitCastInputTypes { + + def this(date1: Expression, date2: Expression) = this(date1, date2, None) override def left: Expression = date1 override def right: Expression = date2 @@ -947,14 +1066,18 @@ case class MonthsBetween(date1: Expression, date2: Expression) override def dataType: DataType = DoubleType + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + override def nullSafeEval(t1: Any, t2: Any): Any = { - DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) + DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long], timeZone) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val tz = ctx.addReferenceMinorObj(timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { - s"""$dtu.monthsBetween($l, $r)""" + s"""$dtu.monthsBetween($l, $r, $tz)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 20b3898f8a6fe207a87798d8ecc3ce980168b3f9..55d37cce99114063c1a90be8bd54b0c2ee2bac63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -94,7 +94,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) CombineLimits, CombineUnions, // Constant folding and strength reduction - NullPropagation, + NullPropagation(conf), FoldablePropagation, OptimizeIn(conf), ConstantFolding, @@ -114,7 +114,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: 
CatalystConf) Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: Batch("Decimal Optimizations", fixedPoint, - DecimalAggregates) :: + DecimalAggregates(conf)) :: Batch("Typed Filter Optimization", fixedPoint, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, @@ -1026,7 +1026,7 @@ case class CheckCartesianProducts(conf: CatalystConf) * This uses the same rules for increasing the precision and scale of the output as * [[org.apache.spark.sql.catalyst.analysis.DecimalPrecision]]. */ -object DecimalAggregates extends Rule[LogicalPlan] { +case class DecimalAggregates(conf: CatalystConf) extends Rule[LogicalPlan] { import Decimal.MAX_LONG_DIGITS /** Maximum number of decimal digits representable precisely in a Double */ @@ -1044,7 +1044,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4)) + DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone)) case _ => we } @@ -1056,7 +1056,7 @@ object DecimalAggregates extends Rule[LogicalPlan] { val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4)) + DecimalType(prec + 4, scale + 4), Option(conf.sessionLocalTimeZone)) case _ => ae } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 949ccdcb458cdc44d307c6bf1db0dda36fea07cd..5bfc0ce9efcdd8bddb6fe13583510e38678a9801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -340,7 +340,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * equivalent [[Literal]] values. This rule is more specific with * Null value propagation from bottom to top of the expression tree. 
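 *
 * The rule is now a case class carrying [[CatalystConf]], so the [[Cast]]s it
 * introduces can be created with the session local time zone (sketch):
 * {{{
 *   NullPropagation(conf).apply(plan)
 * }}}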
*/ -object NullPropagation extends Rule[LogicalPlan] { +case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] { private def nonNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => false case _ => true @@ -348,10 +348,10 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ WindowExpression(Cast(Literal(0L, _), _), _) => - Cast(Literal(0L), e.dataType) + case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) => + Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => - Cast(Literal(0L), e.dataType) + Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone)) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) @@ -518,8 +518,8 @@ case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { */ object SimplifyCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Cast(e, dataType) if e.dataType == dataType => e - case c @ Cast(e, dataType) => (e.dataType, dataType) match { + case Cast(e, dataType, _) if e.dataType == dataType => e + case c @ Cast(e, dataType, _) => (e.dataType, dataType) match { case (ArrayType(from, false), ArrayType(to, true)) if from == to => e case (MapType(fromKey, fromValue, false), MapType(toKey, toValue, true)) if fromKey == toKey && fromValue == toValue => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index f20eb958fe973400d9cdcc7921dbc6b501c91fce..89e1dc9e322e05e8a949591ad25f39c26d63d474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -17,10 +17,15 @@ package org.apache.spark.sql.catalyst.optimizer +import java.util.TimeZone + +import scala.collection.mutable + import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -41,13 +46,18 @@ object ReplaceExpressions extends Rule[LogicalPlan] { */ object ComputeCurrentTime extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() + val currentDates = mutable.Map.empty[String, Literal] val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long] + val currentTime = Literal.create(timestamp, timeExpr.dataType) plan transformAllExpressions { - case CurrentDate() => currentDate + case CurrentDate(Some(timeZoneId)) => + currentDates.getOrElseUpdate(timeZoneId, { + Literal.create( + DateTimeUtils.millisToDays(timestamp / 1000L, TimeZone.getTimeZone(timeZoneId)), + DateType) + }) case CurrentTimestamp() => currentTime } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index a96a3b7af29125632229c73060aa27b4ac02181b..af70efbb0a91b9b30db367ab7d1f9ef45f96d7e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -60,7 +60,7 @@ object DateTimeUtils { final val TimeZoneGMT = TimeZone.getTimeZone("GMT") final val MonthOf31Days = Set(1, 3, 5, 7, 8, 10, 12) - @transient lazy val defaultTimeZone = TimeZone.getDefault + def defaultTimeZone(): TimeZone = TimeZone.getDefault() // Reuse the Calendar object in each thread as it is expensive to create in each method call. private val threadLocalGmtCalendar = new ThreadLocal[Calendar] { @@ -69,20 +69,19 @@ object DateTimeUtils { } } - // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. - private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { - override protected def initialValue: TimeZone = { - Calendar.getInstance.getTimeZone - } - } - // `SimpleDateFormat` is not thread-safe. - val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) } } + def getThreadLocalTimestampFormat(timeZone: TimeZone): DateFormat = { + val sdf = threadLocalTimestampFormat.get() + sdf.setTimeZone(timeZone) + sdf + } + // `SimpleDateFormat` is not thread-safe. private val threadLocalDateFormat = new ThreadLocal[DateFormat] { override def initialValue(): SimpleDateFormat = { @@ -90,28 +89,54 @@ object DateTimeUtils { } } + def getThreadLocalDateFormat(): DateFormat = { + val sdf = threadLocalDateFormat.get() + sdf.setTimeZone(defaultTimeZone()) + sdf + } + + def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = { + val sdf = new SimpleDateFormat(formatString, Locale.US) + sdf.setTimeZone(timeZone) + sdf + } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisUtc: Long): SQLDate = { + millisToDays(millisUtc, defaultTimeZone()) + } + + def millisToDays(millisUtc: Long, timeZone: TimeZone): SQLDate = { // SPARK-6785: use Math.floor so negative number of days (dates before 1970) // will correctly work as input for function toJavaDate(Int) - val millisLocal = millisUtc + threadLocalLocalTimeZone.get().getOffset(millisUtc) + val millisLocal = millisUtc + timeZone.getOffset(millisUtc) Math.floor(millisLocal.toDouble / MILLIS_PER_DAY).toInt } // reverse of millisToDays def daysToMillis(days: SQLDate): Long = { + daysToMillis(days, defaultTimeZone()) + } + + def daysToMillis(days: SQLDate, timeZone: TimeZone): Long = { val millisLocal = days.toLong * MILLIS_PER_DAY - millisLocal - getOffsetFromLocalMillis(millisLocal, threadLocalLocalTimeZone.get()) + millisLocal - getOffsetFromLocalMillis(millisLocal, timeZone) } def dateToString(days: SQLDate): String = - threadLocalDateFormat.get.format(toJavaDate(days)) + getThreadLocalDateFormat.format(toJavaDate(days)) // Converts Timestamp to string according to Hive TimestampWritable convention. def timestampToString(us: SQLTimestamp): String = { + timestampToString(us, defaultTimeZone()) + } + + // Converts Timestamp to string according to Hive TimestampWritable convention. 
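+ // For example (a sketch; zones arbitrary), with the epoch instant 0L:
+ //   timestampToString(0L, TimeZone.getTimeZone("GMT"))       == "1970-01-01 00:00:00"
+ //   timestampToString(0L, TimeZone.getTimeZone("GMT-08:00")) == "1969-12-31 16:00:00"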
+ def timestampToString(us: SQLTimestamp, timeZone: TimeZone): String = { val ts = toJavaTimestamp(us) val timestampString = ts.toString - val formatted = threadLocalTimestampFormat.get.format(ts) + val timestampFormat = getThreadLocalTimestampFormat(timeZone) + val formatted = timestampFormat.format(ts) if (timestampString.length > 19 && timestampString.substring(19) != ".0") { formatted + timestampString.substring(19) @@ -233,10 +258,14 @@ object DateTimeUtils { * `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us]+[h]h:[m]m` */ def stringToTimestamp(s: UTF8String): Option[SQLTimestamp] = { + stringToTimestamp(s, defaultTimeZone()) + } + + def stringToTimestamp(s: UTF8String, timeZone: TimeZone): Option[SQLTimestamp] = { if (s == null) { return None } - var timeZone: Option[Byte] = None + var tz: Option[Byte] = None val segments: Array[Int] = Array[Int](1, 1, 1, 0, 0, 0, 0, 0, 0) var i = 0 var currentSegmentValue = 0 @@ -289,12 +318,12 @@ object DateTimeUtils { segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 - timeZone = Some(43) + tz = Some(43) } else if (b == '-' || b == '+') { segments(i) = currentSegmentValue currentSegmentValue = 0 i += 1 - timeZone = Some(b) + tz = Some(b) } else if (b == '.' && i == 5) { segments(i) = currentSegmentValue currentSegmentValue = 0 @@ -349,11 +378,11 @@ object DateTimeUtils { return None } - val c = if (timeZone.isEmpty) { - Calendar.getInstance() + val c = if (tz.isEmpty) { + Calendar.getInstance(timeZone) } else { Calendar.getInstance( - TimeZone.getTimeZone(f"GMT${timeZone.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) + TimeZone.getTimeZone(f"GMT${tz.get.toChar}${segments(7)}%02d:${segments(8)}%02d")) } c.set(Calendar.MILLISECOND, 0) @@ -452,7 +481,11 @@ object DateTimeUtils { } private def localTimestamp(microsec: SQLTimestamp): SQLTimestamp = { - absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + localTimestamp(microsec, defaultTimeZone()) + } + + private def localTimestamp(microsec: SQLTimestamp, timeZone: TimeZone): SQLTimestamp = { + absoluteMicroSecond(microsec) + timeZone.getOffset(microsec / 1000) * 1000L } /** @@ -462,6 +495,13 @@ object DateTimeUtils { ((localTimestamp(microsec) / MICROS_PER_SECOND / 3600) % 24).toInt } + /** + * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. + */ + def getHours(microsec: SQLTimestamp, timeZone: TimeZone): Int = { + ((localTimestamp(microsec, timeZone) / MICROS_PER_SECOND / 3600) % 24).toInt + } + /** * Returns the minute value of a given timestamp value. The timestamp is expressed in * microseconds. @@ -470,6 +510,14 @@ object DateTimeUtils { ((localTimestamp(microsec) / MICROS_PER_SECOND / 60) % 60).toInt } + /** + * Returns the minute value of a given timestamp value. The timestamp is expressed in + * microseconds. + */ + def getMinutes(microsec: SQLTimestamp, timeZone: TimeZone): Int = { + ((localTimestamp(microsec, timeZone) / MICROS_PER_SECOND / 60) % 60).toInt + } + /** * Returns the second value of a given timestamp value. The timestamp is expressed in * microseconds. @@ -478,6 +526,14 @@ object DateTimeUtils { ((localTimestamp(microsec) / MICROS_PER_SECOND) % 60).toInt } + /** + * Returns the second value of a given timestamp value. The timestamp is expressed in + * microseconds. 
+ */ + def getSeconds(microsec: SQLTimestamp, timeZone: TimeZone): Int = { + ((localTimestamp(microsec, timeZone) / MICROS_PER_SECOND) % 60).toInt + } + private[this] def isLeapYear(year: Int): Boolean = { (year % 4) == 0 && ((year % 100) != 0 || (year % 400) == 0) } @@ -742,9 +798,23 @@ object DateTimeUtils { * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00. */ def timestampAddInterval(start: SQLTimestamp, months: Int, microseconds: Long): SQLTimestamp = { - val days = millisToDays(start / 1000L) + timestampAddInterval(start, months, microseconds, defaultTimeZone()) + } + + /** + * Add timestamp and full interval. + * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00. + */ + def timestampAddInterval( + start: SQLTimestamp, + months: Int, + microseconds: Long, + timeZone: TimeZone): SQLTimestamp = { + val days = millisToDays(start / 1000L, timeZone) val newDays = dateAddMonths(days, months) - daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds + start + + daysToMillis(newDays, timeZone) * 1000L - daysToMillis(days, timeZone) * 1000L + + microseconds } /** @@ -758,10 +828,24 @@ object DateTimeUtils { * 8 digits. */ def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp): Double = { + monthsBetween(time1, time2, defaultTimeZone()) + } + + /** + * Returns number of months between time1 and time2. time1 and time2 are expressed in + * microseconds since 1.1.1970. + * + * If time1 and time2 having the same day of month, or both are the last day of month, + * it returns an integer (time under a day will be ignored). + * + * Otherwise, the difference is calculated based on 31 days per month, and rounding to + * 8 digits. + */ + def monthsBetween(time1: SQLTimestamp, time2: SQLTimestamp, timeZone: TimeZone): Double = { val millis1 = time1 / 1000L val millis2 = time2 / 1000L - val date1 = millisToDays(millis1) - val date2 = millisToDays(millis2) + val date1 = millisToDays(millis1, timeZone) + val date2 = millisToDays(millis2, timeZone) val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1) val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2) @@ -772,8 +856,8 @@ object DateTimeUtils { return (months1 - months2).toDouble } // milliseconds is enough for 8 digits precision on the right side - val timeInDay1 = millis1 - daysToMillis(date1) - val timeInDay2 = millis2 - daysToMillis(date2) + val timeInDay1 = millis1 - daysToMillis(date1, timeZone) + val timeInDay2 = millis2 - daysToMillis(date2, timeZone) val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0 // rounding to 8 digits @@ -896,7 +980,7 @@ object DateTimeUtils { */ def convertTz(ts: SQLTimestamp, fromZone: TimeZone, toZone: TimeZone): SQLTimestamp = { // We always use local timezone to parse or format a timestamp - val localZone = threadLocalLocalTimeZone.get() + val localZone = defaultTimeZone() val utcTs = if (fromZone.getID == localZone.getID) { ts } else { @@ -907,9 +991,9 @@ object DateTimeUtils { if (toZone.getID == localZone.getID) { utcTs } else { - val localTs2 = utcTs + toZone.getOffset(utcTs / 1000L) * 1000L // in toZone + val localTs = utcTs + toZone.getOffset(utcTs / 1000L) * 1000L // in toZone // treat it as local timezone, convert to UTC (we could get the expected human time back) - localTs2 - getOffsetFromLocalMillis(localTs2 / 1000L, localZone) * 1000L + localTs - 
getOffsetFromLocalMillis(localTs / 1000L, localZone) * 1000L } } @@ -934,7 +1018,6 @@ object DateTimeUtils { */ private[util] def resetThreadLocals(): Unit = { threadLocalGmtCalendar.remove() - threadLocalLocalTimeZone.remove() threadLocalTimestampFormat.remove() threadLocalDateFormat.remove() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 817de48de279890aeb59a68052e214ccd780dc80..81a97dc1ff3f2652391dbef38f3504a0841648a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.TimeZone + import org.scalatest.ShouldMatchers import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} @@ -258,7 +260,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val c = testRelation2.output(2) val plan = testRelation2.select('c).orderBy(Floor('a).asc) - val expected = testRelation2.select(c, a).orderBy(Floor(a.cast(DoubleType)).asc).select(c) + val expected = testRelation2.select(c, a) + .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index 2a0205bdc90fe043c21a11c33452a384d807d390..553b1598e7750dd7d86a49fac1d3098b953eaf51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.TimeZone + import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -32,7 +34,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { lazy val unresolved_c = UnresolvedAttribute("c") lazy val gid = 'spark_grouping_id.int.withNullability(false) lazy val hive_gid = 'grouping__id.int.withNullability(false) - lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType) + lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1, ByteType, Option(TimeZone.getDefault().getID)) lazy val nulInt = Literal(null, IntegerType) lazy val nulStr = Literal(null, StringType) lazy val r1 = LocalRelation(a, b, c) @@ -213,7 +215,8 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { val originalPlan = Filter(Grouping(unresolved_a) === 0, GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) - val expected = Project(Seq(a, b), Filter(Cast(grouping_a, IntegerType) === 0, + val expected = Project(Seq(a, b), + Filter(Cast(grouping_a, IntegerType, Option(TimeZone.getDefault().getID)) === 0, Aggregate(Seq(a, b, gid), Seq(a, b, gid), Expand( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index b748595fc4f2d2a35b7e4948e75421387741fb9a..8eccadbdd8afbfa95028adc932ebd9348cd75cde 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import java.util.{Calendar, TimeZone} +import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -32,10 +34,10 @@ import org.apache.spark.unsafe.types.UTF8String */ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { - private def cast(v: Any, targetType: DataType): Cast = { + private def cast(v: Any, targetType: DataType, timeZoneId: Option[String] = None): Cast = { v match { - case lit: Expression => Cast(lit, targetType) - case _ => Cast(Literal(v), targetType) + case lit: Expression => Cast(lit, targetType, timeZoneId) + case _ => Cast(Literal(v), targetType, timeZoneId) } } @@ -45,7 +47,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } private def checkNullCast(from: DataType, to: DataType): Unit = { - checkEvaluation(Cast(Literal.create(null, from), to), null) + checkEvaluation(cast(Literal.create(null, from), to, Option("GMT")), null) } test("null cast") { @@ -107,108 +109,98 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - checkEvaluation(Cast(Literal("123"), TimestampType), null) - - var c = Calendar.getInstance() - c.set(2015, 0, 1, 0, 0, 0) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015"), TimestampType), - new Timestamp(c.getTimeInMillis)) - c = Calendar.getInstance() - c.set(2015, 2, 1, 0, 0, 0) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03"), TimestampType), - new Timestamp(c.getTimeInMillis)) - c = Calendar.getInstance() - c.set(2015, 2, 18, 0, 0, 0) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03-18"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18 "), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18T"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance() - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03-18 12:03:17"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17Z"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18 12:03:17Z"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17-1:0"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17-01:00"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = 
Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17+07:30"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17+7:3"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance() - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - checkEvaluation(Cast(Literal("2015-03-18 12:03:17.123"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 456) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17.456Z"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18 12:03:17.456Z"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-1:0"), TimestampType), - new Timestamp(c.getTimeInMillis)) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-01:00"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+07:30"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+7:3"), TimestampType), - new Timestamp(c.getTimeInMillis)) - - checkEvaluation(Cast(Literal("2015-03-18 123142"), TimestampType), null) - checkEvaluation(Cast(Literal("2015-03-18T123123"), TimestampType), null) - checkEvaluation(Cast(Literal("2015-03-18X"), TimestampType), null) - checkEvaluation(Cast(Literal("2015/03/18"), TimestampType), null) - checkEvaluation(Cast(Literal("2015.03.18"), TimestampType), null) - checkEvaluation(Cast(Literal("20150318"), TimestampType), null) - checkEvaluation(Cast(Literal("2015-031-8"), TimestampType), null) - checkEvaluation(Cast(Literal("2015-03-18T12:03:17-0:70"), TimestampType), null) + for (tz <- ALL_TIMEZONES) { + def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = { + checkEvaluation(cast(Literal(str), TimestampType, Option(tz.getID)), expected) + } + + checkCastStringToTimestamp("123", null) + + var c = Calendar.getInstance(tz) + c.set(2015, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015", new Timestamp(c.getTimeInMillis)) + c = Calendar.getInstance(tz) + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03", new Timestamp(c.getTimeInMillis)) + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03-18", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18 ", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18T", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 12, 3, 17) + 
c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03-18 12:03:17", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18T12:03:17", new Timestamp(c.getTimeInMillis)) + + // If the string value includes a timezone string, it represents the timestamp in + // that timezone, regardless of the timeZoneId parameter. + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03-18T12:03:17Z", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18 12:03:17Z", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03-18T12:03:17-1:0", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18T12:03:17-01:00", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03-18T12:03:17+07:30", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkCastStringToTimestamp("2015-03-18T12:03:17+7:3", new Timestamp(c.getTimeInMillis)) + + // Tests for strings including milliseconds. + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTimestamp("2015-03-18 12:03:17.123", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18T12:03:17.123", new Timestamp(c.getTimeInMillis)) + + // If the string value includes a timezone string, it represents the timestamp in + // that timezone, regardless of the timeZoneId parameter. 
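+ // For example, "2015-03-18T12:03:17.456Z" below is 12:03:17.456 UTC whichever tz this loop is iterating over.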
+ c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 456) + checkCastStringToTimestamp("2015-03-18T12:03:17.456Z", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18 12:03:17.456Z", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTimestamp("2015-03-18T12:03:17.123-1:0", new Timestamp(c.getTimeInMillis)) + checkCastStringToTimestamp("2015-03-18T12:03:17.123-01:00", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTimestamp("2015-03-18T12:03:17.123+07:30", new Timestamp(c.getTimeInMillis)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkCastStringToTimestamp("2015-03-18T12:03:17.123+7:3", new Timestamp(c.getTimeInMillis)) + + checkCastStringToTimestamp("2015-03-18 123142", null) + checkCastStringToTimestamp("2015-03-18T123123", null) + checkCastStringToTimestamp("2015-03-18X", null) + checkCastStringToTimestamp("2015/03/18", null) + checkCastStringToTimestamp("2015.03.18", null) + checkCastStringToTimestamp("20150318", null) + checkCastStringToTimestamp("2015-031-8", null) + checkCastStringToTimestamp("2015-03-18T12:03:17-0:70", null) + } } test("cast from int") { @@ -316,30 +308,43 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val zts = sd + " 00:00:00" val sts = sd + " 00:00:02" val nts = sts + ".1" - val ts = Timestamp.valueOf(nts) - - var c = Calendar.getInstance() - c.set(2015, 2, 8, 2, 30, 0) - checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), - c.getTimeInMillis * 1000) - c = Calendar.getInstance() - c.set(2015, 10, 1, 2, 30, 0) - checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType), - c.getTimeInMillis * 1000) + val ts = withDefaultTimeZone(TimeZoneGMT)(Timestamp.valueOf(nts)) + + for (tz <- ALL_TIMEZONES) { + val timeZoneId = Option(tz.getID) + var c = Calendar.getInstance(TimeZoneGMT) + c.set(2015, 2, 8, 2, 30, 0) + checkEvaluation( + cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId), + TimestampType, timeZoneId), + c.getTimeInMillis * 1000) + c = Calendar.getInstance(TimeZoneGMT) + c.set(2015, 10, 1, 2, 30, 0) + checkEvaluation( + cast(cast(new Timestamp(c.getTimeInMillis), StringType, timeZoneId), + TimestampType, timeZoneId), + c.getTimeInMillis * 1000) + } + + val gmtId = Option("GMT") checkEvaluation(cast("abdef", StringType), "abdef") checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) - checkEvaluation(cast("abdef", TimestampType), null) + checkEvaluation(cast("abdef", TimestampType, gmtId), null) checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) - checkEvaluation(cast(cast(nts, TimestampType), StringType), nts) - checkEvaluation(cast(cast(ts, StringType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts)) + checkEvaluation(cast(cast(nts, TimestampType, gmtId), StringType, gmtId), nts) + checkEvaluation( + cast(cast(ts, StringType, gmtId), TimestampType, gmtId), + DateTimeUtils.fromJavaTimestamp(ts)) // all convert to string type to check - 
checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd) - checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts) + checkEvaluation(cast(cast(cast(nts, TimestampType, gmtId), DateType, gmtId), StringType), sd) + checkEvaluation( + cast(cast(cast(ts, DateType, gmtId), TimestampType, gmtId), StringType, gmtId), + zts) checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef") @@ -351,7 +356,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), 5.toShort) checkEvaluation( - cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), + cast(cast(cast(cast(cast(cast("5", TimestampType, gmtId), ByteType), DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), @@ -466,7 +471,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(d, DecimalType(10, 2)), null) checkEvaluation(cast(d, StringType), "1970-01-01") - checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") + + val gmtId = Option("GMT") + checkEvaluation(cast(cast(d, TimestampType, gmtId), StringType, gmtId), "1970-01-01 00:00:00") } test("cast from timestamp") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 35cea25ba0b7dbb178ad335a71a723c199fabd74..9978f35a0381022ee3d45e35cc3e1ca7f6b209ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -30,16 +32,29 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { import IntegralLiteralTestUtils._ + val TimeZonePST = TimeZone.getTimeZone("PST") + val TimeZoneJST = TimeZone.getTimeZone("JST") + + val gmtId = Option(TimeZoneGMT.getID) + val pstId = Option(TimeZonePST.getID) + val jstId = Option(TimeZoneJST.getID) + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + sdf.setTimeZone(TimeZoneGMT) val sdfDate = new SimpleDateFormat("yyyy-MM-dd", Locale.US) + sdfDate.setTimeZone(TimeZoneGMT) val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) test("datetime function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] - val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis(), TimeZoneGMT) + val cd = CurrentDate(gmtId).eval(EmptyRow).asInstanceOf[Int] + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis(), 
TimeZoneGMT) assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) + + val cdjst = CurrentDate(jstId).eval(EmptyRow).asInstanceOf[Int] + val cdpst = CurrentDate(pstId).eval(EmptyRow).asInstanceOf[Int] + assert(cdpst <= cd && cd <= cdjst) } test("datetime function current_timestamp") { @@ -50,9 +65,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("DayOfYear") { val sdfDay = new SimpleDateFormat("D", Locale.US) + + val c = Calendar.getInstance() (0 to 3).foreach { m => (0 to 5).foreach { i => - val c = Calendar.getInstance() c.set(2000, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), @@ -66,8 +82,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Year") { checkEvaluation(Year(Literal.create(null, DateType)), null) checkEvaluation(Year(Literal(d)), 2015) - checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015) - checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) + checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2015) + checkEvaluation(Year(Cast(Literal(ts), DateType, gmtId)), 2013) val c = Calendar.getInstance() (2000 to 2002).foreach { y => @@ -86,8 +102,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Quarter") { checkEvaluation(Quarter(Literal.create(null, DateType)), null) checkEvaluation(Quarter(Literal(d)), 2) - checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2) - checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4) + checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(Quarter(Cast(Literal(ts), DateType, gmtId)), 4) val c = Calendar.getInstance() (2003 to 2004).foreach { y => @@ -106,13 +122,13 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Month") { checkEvaluation(Month(Literal.create(null, DateType)), null) checkEvaluation(Month(Literal(d)), 4) - checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4) - checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) + checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 4) + checkEvaluation(Month(Cast(Literal(ts), DateType, gmtId)), 11) + val c = Calendar.getInstance() (2003 to 2004).foreach { y => (0 to 3).foreach { m => (0 to 2 * 24).foreach { i => - val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), @@ -127,11 +143,11 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29) checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null) checkEvaluation(DayOfMonth(Literal(d)), 8) - checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8) - checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8) + checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 8) + checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType, gmtId)), 8) + val c = Calendar.getInstance() (1999 to 2000).foreach { y => - val c = Calendar.getInstance() c.set(y, 0, 1, 0, 0, 0) (0 to 365).foreach { d => c.add(Calendar.DATE, 1) @@ -143,72 +159,114 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Seconds") { - assert(Second(Literal.create(null, DateType)).resolved === false) - checkEvaluation(Second(Cast(Literal(d), TimestampType)), 0) - 
checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType)), 15) - checkEvaluation(Second(Literal(ts)), 15) + assert(Second(Literal.create(null, DateType), gmtId).resolved === false) + assert(Second(Cast(Literal(d), TimestampType), None).resolved === true) + checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) + checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15) + checkEvaluation(Second(Literal(ts), gmtId), 15) val c = Calendar.getInstance() - (0 to 60 by 5).foreach { s => - c.set(2015, 18, 3, 3, 5, s) - checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), - c.get(Calendar.SECOND)) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + c.setTimeZone(tz) + (0 to 60 by 5).foreach { s => + c.set(2015, 18, 3, 3, 5, s) + checkEvaluation( + Second(Literal(new Timestamp(c.getTimeInMillis)), timeZoneId), + c.get(Calendar.SECOND)) + } + checkConsistencyBetweenInterpretedAndCodegen( + (child: Expression) => Second(child, timeZoneId), TimestampType) } - checkConsistencyBetweenInterpretedAndCodegen(Second, TimestampType) } test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) - checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) - checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) - checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) + checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 15) + checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType, gmtId)), 45) + checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType, gmtId)), 18) checkConsistencyBetweenInterpretedAndCodegen(WeekOfYear, DateType) } test("DateFormat") { - checkEvaluation(DateFormatClass(Literal.create(null, TimestampType), Literal("y")), null) - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType), - Literal.create(null, StringType)), null) - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType), - Literal("y")), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y")), "2013") + checkEvaluation( + DateFormatClass(Literal.create(null, TimestampType), Literal("y"), gmtId), + null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal.create(null, StringType), gmtId), null) + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal("y"), gmtId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), gmtId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal("H"), gmtId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), gmtId), "13") + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), + Literal("y"), pstId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), pstId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), + Literal("H"), pstId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), pstId), "5") + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), + Literal("y"), jstId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), jstId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), + Literal("H"), jstId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), jstId), "22") } test("Hour") { - assert(Hour(Literal.create(null, 
DateType)).resolved === false) - checkEvaluation(Hour(Cast(Literal(d), TimestampType)), 0) - checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType)), 13) - checkEvaluation(Hour(Literal(ts)), 13) + assert(Hour(Literal.create(null, DateType), gmtId).resolved === false) + assert(Hour(Literal(ts), None).resolved === true) + checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) + checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13) + checkEvaluation(Hour(Literal(ts), gmtId), 13) val c = Calendar.getInstance() - (0 to 24).foreach { h => - (0 to 60 by 15).foreach { m => - (0 to 60 by 15).foreach { s => - c.set(2015, 18, 3, h, m, s) - checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))), - c.get(Calendar.HOUR_OF_DAY)) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + c.setTimeZone(tz) + (0 to 24).foreach { h => + (0 to 60 by 15).foreach { m => + (0 to 60 by 15).foreach { s => + c.set(2015, 18, 3, h, m, s) + checkEvaluation( + Hour(Literal(new Timestamp(c.getTimeInMillis)), timeZoneId), + c.get(Calendar.HOUR_OF_DAY)) + } } } + checkConsistencyBetweenInterpretedAndCodegen( + (child: Expression) => Hour(child, timeZoneId), TimestampType) } - checkConsistencyBetweenInterpretedAndCodegen(Hour, TimestampType) } test("Minute") { - assert(Minute(Literal.create(null, DateType)).resolved === false) - checkEvaluation(Minute(Cast(Literal(d), TimestampType)), 0) - checkEvaluation(Minute(Cast(Literal(sdf.format(d)), TimestampType)), 10) - checkEvaluation(Minute(Literal(ts)), 10) + assert(Minute(Literal.create(null, DateType), gmtId).resolved === false) + assert(Minute(Literal(ts), None).resolved === true) + checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) + checkEvaluation( + Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10) + checkEvaluation(Minute(Literal(ts), gmtId), 10) val c = Calendar.getInstance() - (0 to 60 by 5).foreach { m => - (0 to 60 by 15).foreach { s => - c.set(2015, 18, 3, 3, m, s) - checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))), - c.get(Calendar.MINUTE)) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + c.setTimeZone(tz) + (0 to 60 by 5).foreach { m => + (0 to 60 by 15).foreach { s => + c.set(2015, 18, 3, 3, m, s) + checkEvaluation( + Minute(Literal(new Timestamp(c.getTimeInMillis)), timeZoneId), + c.get(Calendar.MINUTE)) + } } + checkConsistencyBetweenInterpretedAndCodegen( + (child: Expression) => Minute(child, timeZoneId), TimestampType) } - checkConsistencyBetweenInterpretedAndCodegen(Minute, TimestampType) } test("date_add") { @@ -250,46 +308,86 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("time_add") { - checkEvaluation( - TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), - Literal(new CalendarInterval(1, 123000L))), - DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00.123"))) + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf.setTimeZone(tz) - checkEvaluation( - TimeAdd(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), - null) - checkEvaluation( - TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), - Literal.create(null, CalendarIntervalType)), - null) - checkEvaluation( - TimeAdd(Literal.create(null, TimestampType), 
Literal.create(null, CalendarIntervalType)), - null) - checkConsistencyBetweenInterpretedAndCodegen(TimeAdd, TimestampType, CalendarIntervalType) + checkEvaluation( + TimeAdd( + Literal(new Timestamp(sdf.parse("2016-01-29 10:00:00.000").getTime)), + Literal(new CalendarInterval(1, 123000L)), + timeZoneId), + DateTimeUtils.fromJavaTimestamp( + new Timestamp(sdf.parse("2016-02-29 10:00:00.123").getTime))) + + checkEvaluation( + TimeAdd( + Literal.create(null, TimestampType), + Literal(new CalendarInterval(1, 123000L)), + timeZoneId), + null) + checkEvaluation( + TimeAdd( + Literal(new Timestamp(sdf.parse("2016-01-29 10:00:00.000").getTime)), + Literal.create(null, CalendarIntervalType), + timeZoneId), + null) + checkEvaluation( + TimeAdd( + Literal.create(null, TimestampType), + Literal.create(null, CalendarIntervalType), + timeZoneId), + null) + checkConsistencyBetweenInterpretedAndCodegen( + (start: Expression, interval: Expression) => TimeAdd(start, interval, timeZoneId), + TimestampType, CalendarIntervalType) + } } test("time_sub") { - checkEvaluation( - TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")), - Literal(new CalendarInterval(1, 0))), - DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00"))) - checkEvaluation( - TimeSub( - Literal(Timestamp.valueOf("2016-03-30 00:00:01")), - Literal(new CalendarInterval(1, 2000000.toLong))), - DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59"))) + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS", Locale.US) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf.setTimeZone(tz) - checkEvaluation( - TimeSub(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), - null) - checkEvaluation( - TimeSub(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), - Literal.create(null, CalendarIntervalType)), - null) - checkEvaluation( - TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), - null) - checkConsistencyBetweenInterpretedAndCodegen(TimeSub, TimestampType, CalendarIntervalType) + checkEvaluation( + TimeSub( + Literal(new Timestamp(sdf.parse("2016-03-31 10:00:00.000").getTime)), + Literal(new CalendarInterval(1, 0)), + timeZoneId), + DateTimeUtils.fromJavaTimestamp( + new Timestamp(sdf.parse("2016-02-29 10:00:00.000").getTime))) + checkEvaluation( + TimeSub( + Literal(new Timestamp(sdf.parse("2016-03-30 00:00:01.000").getTime)), + Literal(new CalendarInterval(1, 2000000.toLong)), + timeZoneId), + DateTimeUtils.fromJavaTimestamp( + new Timestamp(sdf.parse("2016-02-28 23:59:59.000").getTime))) + + checkEvaluation( + TimeSub( + Literal.create(null, TimestampType), + Literal(new CalendarInterval(1, 123000L)), + timeZoneId), + null) + checkEvaluation( + TimeSub( + Literal(new Timestamp(sdf.parse("2016-01-29 10:00:00.000").getTime)), + Literal.create(null, CalendarIntervalType), + timeZoneId), + null) + checkEvaluation( + TimeSub( + Literal.create(null, TimestampType), + Literal.create(null, CalendarIntervalType), + timeZoneId), + null) + checkConsistencyBetweenInterpretedAndCodegen( + (start: Expression, interval: Expression) => TimeSub(start, interval, timeZoneId), + TimestampType, CalendarIntervalType) + } } test("add_months") { @@ -313,28 +411,44 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("months_between") { - checkEvaluation( - MonthsBetween(Literal(Timestamp.valueOf("1997-02-28 10:30:00")), - Literal(Timestamp.valueOf("1996-10-30 00:00:00"))), - 
3.94959677) - checkEvaluation( - MonthsBetween(Literal(Timestamp.valueOf("2015-01-30 11:52:00")), - Literal(Timestamp.valueOf("2015-01-30 11:50:00"))), - 0.0) - checkEvaluation( - MonthsBetween(Literal(Timestamp.valueOf("2015-01-31 00:00:00")), - Literal(Timestamp.valueOf("2015-03-31 22:00:00"))), - -2.0) - checkEvaluation( - MonthsBetween(Literal(Timestamp.valueOf("2015-03-31 22:00:00")), - Literal(Timestamp.valueOf("2015-02-28 00:00:00"))), - 1.0) - val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) - val tnull = Literal.create(null, TimestampType) - checkEvaluation(MonthsBetween(t, tnull), null) - checkEvaluation(MonthsBetween(tnull, t), null) - checkEvaluation(MonthsBetween(tnull, tnull), null) - checkConsistencyBetweenInterpretedAndCodegen(MonthsBetween, TimestampType, TimestampType) + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf.setTimeZone(tz) + + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("1997-02-28 10:30:00").getTime)), + Literal(new Timestamp(sdf.parse("1996-10-30 00:00:00").getTime)), + timeZoneId), + 3.94959677) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-30 11:52:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-01-30 11:50:00").getTime)), + timeZoneId), + 0.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-01-31 00:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + timeZoneId), + -2.0) + checkEvaluation( + MonthsBetween( + Literal(new Timestamp(sdf.parse("2015-03-31 22:00:00").getTime)), + Literal(new Timestamp(sdf.parse("2015-02-28 00:00:00").getTime)), + timeZoneId), + 1.0) + val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) + val tnull = Literal.create(null, TimestampType) + checkEvaluation(MonthsBetween(t, tnull, timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, t, timeZoneId), null) + checkEvaluation(MonthsBetween(tnull, tnull, timeZoneId), null) + checkConsistencyBetweenInterpretedAndCodegen( + (time1: Expression, time2: Expression) => MonthsBetween(time1, time2, timeZoneId), + TimestampType, TimestampType) + } } test("last_day") { @@ -398,7 +512,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { expected) } val date = Date.valueOf("2015-07-22") - Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach{ fmt => + Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt => testTrunc(date, fmt, Date.valueOf("2015-01-01")) } Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt => @@ -414,19 +528,32 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - checkEvaluation( - FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) - checkEvaluation(FromUnixTime( - Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(1000000))) - checkEvaluation( - FromUnixTime(Literal(-1000L), Literal(fmt2)), sdf2.format(new Timestamp(-1000000))) - checkEvaluation( - FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType)), null) - checkEvaluation( - FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null) - checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null) - 
checkEvaluation( - FromUnixTime(Literal(0L), Literal("not a valid format")), null) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + checkEvaluation( + FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId), + sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) + } } test("unix_timestamp") { @@ -435,34 +562,53 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - val date1 = Date.valueOf("2015-07-24") - checkEvaluation( - UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) - checkEvaluation(UnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) - checkEvaluation( - UnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) - checkEvaluation( - UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) - checkEvaluation( - UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) - checkEvaluation(UnixTimestamp( - Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) - val t1 = UnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - val t2 = UnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - assert(t2 - t1 <= 1) - checkEvaluation( - UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) - checkEvaluation( - UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) - checkEvaluation(UnixTimestamp( - Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) - checkEvaluation( - UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + sdf3.setTimeZone(TimeZoneGMT) + + withDefaultTimeZone(TimeZoneGMT) { + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + val date1 = Date.valueOf("2015-07-24") + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation( + UnixTimestamp( + Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 
DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz) / 1000L) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), + -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), + DateTimeUtils.daysToMillis( + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz) / 1000L) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp( + Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + null) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz) / 1000L) + checkEvaluation( + UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + } + } } test("to_unix_timestamp") { @@ -471,34 +617,51 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - val date1 = Date.valueOf("2015-07-24") - checkEvaluation( - ToUnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) - checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) - checkEvaluation( - ToUnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) - checkEvaluation( - ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) - checkEvaluation( - ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) - checkEvaluation(ToUnixTimestamp( - Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) - val t1 = ToUnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - val t2 = ToUnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - assert(t2 - t1 <= 1) - checkEvaluation( - ToUnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) - checkEvaluation( - ToUnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) - checkEvaluation(ToUnixTimestamp( - Literal(date1), Literal.create(null, StringType)), date1.getTime / 1000L) - checkEvaluation( - ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null) + sdf3.setTimeZone(TimeZoneGMT) + + withDefaultTimeZone(TimeZoneGMT) { + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + val date1 = Date.valueOf("2015-07-24") + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation(ToUnixTimestamp( + Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd 
HH:mm:ss")), + 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz) / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), + -1000L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), + DateTimeUtils.daysToMillis( + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz) / 1000L) + val t1 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = ToUnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation(ToUnixTimestamp( + Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null) + checkEvaluation( + ToUnixTimestamp( + Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + null) + checkEvaluation(ToUnixTimestamp( + Literal(date1), Literal.create(null, StringType), timeZoneId), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz) / 1000L) + checkEvaluation( + ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + } + } } test("datediff") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index a313681eeb8f010386f132e413cf42b39dbb6546..a0d489681fd9f5a1767918688484b8969da4ac61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -33,7 +34,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation, + NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), ConstantFolding, BooleanSimplification, SimplifyBinaryComparison, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 8147d06969bbe9452aefb6dd01a402714cc37266..1b9db06014921b71ad6760bc53461acae4939e1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -34,7 +34,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("Constant Folding", FixedPoint(50), - NullPropagation, + NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), ConstantFolding, BooleanSimplification, PruneFilters) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 87ad81db11b64187930f3e9d1b5a239df22c0721..276b8055b08d07ac043fea61922be025ed5c179b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -32,7 +33,7 @@ class CombiningLimitsSuite extends PlanTest { Batch("Combine Limit", FixedPoint(10), CombineLimits) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation, + NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), ConstantFolding, BooleanSimplification, SimplifyConditionals) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index 711294ed619284ca71d0e38bad873ec8c1f02a27..a491f4433370d07c28f0ab89685f065d248a7fba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ @@ -29,7 +30,7 @@ class DecimalAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Decimal Optimizations", FixedPoint(100), - DecimalAggregates) :: Nil + DecimalAggregates(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil } val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 0877207728b389a5a510eddb793cd655b5d7d872..9daede1a5f9576da60625d573061a0451d37b307 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -34,7 +34,7 @@ class OptimizeInSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantFolding", FixedPoint(10), - NullPropagation, + NullPropagation(SimpleCatalystConf(caseSensitiveAnalysis = true)), ConstantFolding, BooleanSimplification, OptimizeIn(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index a191aa8fee70278e83f504504efa0183f4fd83ac..908b370408280c3513fd2eb9196e5e577ef195ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.plans 
+import java.util.TimeZone + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} +import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, LongType, StringType} class ConstraintPropagationSuite extends SparkFunSuite { @@ -49,6 +51,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { } } + private def castWithTimeZone(expr: Expression, dataType: DataType) = { + Cast(expr, dataType, Option(TimeZone.getDefault().getID)) + } + test("propagating constraints in filters") { val tr = LocalRelation('a.int, 'b.string, 'c.int) @@ -276,14 +282,15 @@ class ConstraintPropagationSuite extends SparkFunSuite { tr.where('a.attr === 'b.attr && 'c.attr + 100 > 'd.attr && IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints, - ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), - Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"), + ExpressionSet(Seq( + castWithTimeZone(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), + castWithTimeZone(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c")), IsNotNull(resolveColumn(tr, "d")), IsNotNull(resolveColumn(tr, "e")), - IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))))) + IsNotNull(castWithTimeZone(castWithTimeZone(resolveColumn(tr, "e"), LongType), LongType))))) } test("infer isnotnull constraints from compound expressions") { @@ -294,22 +301,25 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast( Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, ExpressionSet(Seq( - Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === - Cast(resolveColumn(tr, "c"), LongType), + castWithTimeZone(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === + castWithTimeZone(resolveColumn(tr, "c"), LongType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c")), IsNotNull(resolveColumn(tr, "e")), - IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))))) + IsNotNull( + castWithTimeZone(castWithTimeZone(castWithTimeZone( + resolveColumn(tr, "e"), LongType), LongType), LongType))))) verifyConstraints( tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints, ExpressionSet(Seq( - Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === - Cast(resolveColumn(tr, "c"), LongType), - Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(10, DoubleType) === - Cast(resolveColumn(tr, "e"), DoubleType), + castWithTimeZone(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + + castWithTimeZone(100, LongType) === + castWithTimeZone(resolveColumn(tr, "c"), LongType), + castWithTimeZone(resolveColumn(tr, "d"), DoubleType) / + castWithTimeZone(10, DoubleType) === + castWithTimeZone(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c")), @@ -319,11 +329,12 @@ class ConstraintPropagationSuite extends SparkFunSuite { verifyConstraints( 
tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints, ExpressionSet(Seq( - Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= - Cast(resolveColumn(tr, "c"), LongType), - Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(10, DoubleType) < - Cast(resolveColumn(tr, "e"), DoubleType), + castWithTimeZone(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - + castWithTimeZone(10, LongType) >= + castWithTimeZone(resolveColumn(tr, "c"), LongType), + castWithTimeZone(resolveColumn(tr, "d"), DoubleType) / + castWithTimeZone(10, DoubleType) < + castWithTimeZone(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c")), @@ -333,9 +344,9 @@ class ConstraintPropagationSuite extends SparkFunSuite { verifyConstraints( tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints, ExpressionSet(Seq( - (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - - (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > - Cast(resolveColumn(tr, "e") * 1000, LongType), + (castWithTimeZone(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - + (castWithTimeZone(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > + castWithTimeZone(resolveColumn(tr, "e") * 1000, LongType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), IsNotNull(resolveColumn(tr, "c")), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index e0a9a0c3d5c00fa560994f891cd4d42466692423..9799817494f15838e1a63e3eba6646c5e6f5e75d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { + val TimeZonePST = TimeZone.getTimeZone("PST") + private[this] def getInUTCDays(timestamp: Long): Int = { val tz = TimeZone.getDefault ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt @@ -177,180 +179,155 @@ class DateTimeUtilsSuite extends SparkFunSuite { } test("string to timestamp") { - var c = Calendar.getInstance() - c.set(1969, 11, 31, 16, 0, 0) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === - c.getTimeInMillis * 1000) - c.set(1, 0, 1, 0, 0, 0) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("0001")).get === - c.getTimeInMillis * 1000) - c = Calendar.getInstance() - c.set(2015, 2, 1, 0, 0, 0) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015-03")).get === - c.getTimeInMillis * 1000) - c = Calendar.getInstance() - c.set(2015, 2, 18, 0, 0, 0) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18")).get === - c.getTimeInMillis * 1000) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === - c.getTimeInMillis * 1000) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === - c.getTimeInMillis * 1000) - - c = Calendar.getInstance() - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === - c.getTimeInMillis * 1000) - 
assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === - c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === - c.getTimeInMillis * 1000) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === - c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === - c.getTimeInMillis * 1000) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance() - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 456) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 1000) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === - c.getTimeInMillis * 1000 + 121) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - 
UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === - c.getTimeInMillis * 1000 + 120) - - c = Calendar.getInstance() - c.set(Calendar.HOUR_OF_DAY, 18) - c.set(Calendar.MINUTE, 12) - c.set(Calendar.SECOND, 15) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp( - UTF8String.fromString("18:12:15")).get === - c.getTimeInMillis * 1000) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(Calendar.HOUR_OF_DAY, 18) - c.set(Calendar.MINUTE, 12) - c.set(Calendar.SECOND, 15) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("T18:12:15.12312+7:30")).get === - c.getTimeInMillis * 1000 + 120) - - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) - c.set(Calendar.HOUR_OF_DAY, 18) - c.set(Calendar.MINUTE, 12) - c.set(Calendar.SECOND, 15) - c.set(Calendar.MILLISECOND, 123) - assert(stringToTimestamp( - UTF8String.fromString("18:12:15.12312+7:30")).get === - c.getTimeInMillis * 1000 + 120) - - c = Calendar.getInstance() - c.set(2011, 4, 6, 7, 8, 9) - c.set(Calendar.MILLISECOND, 100) - assert(stringToTimestamp( - UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) - - assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("00238")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("02015-01-18")).isEmpty) - assert(stringToTimestamp(UTF8String.fromString("015-01-18")).isEmpty) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) + for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { + def checkStringToTimestamp(str: String, expected: Option[Long]): Unit = { + assert(stringToTimestamp(UTF8String.fromString(str), tz) === expected) + } - // Truncating the fractional seconds - c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00")) - c.set(2015, 2, 18, 12, 3, 17) - c.set(Calendar.MILLISECOND, 0) - assert(stringToTimestamp( - UTF8String.fromString("2015-03-18T12:03:17.123456789+0:00")).get === - c.getTimeInMillis * 1000 + 123456) + var c = Calendar.getInstance(tz) + c.set(1969, 11, 31, 16, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("1969-12-31 16:00:00", Option(c.getTimeInMillis * 1000)) + c.set(1, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("0001", Option(c.getTimeInMillis * 1000)) + c = Calendar.getInstance(tz) + c.set(2015, 2, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03", Option(c.getTimeInMillis * 1000)) + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18", Option(c.getTimeInMillis * 1000)) + checkStringToTimestamp("2015-03-18 ", Option(c.getTimeInMillis * 1000)) + checkStringToTimestamp("2015-03-18T", Option(c.getTimeInMillis * 1000)) + + 
c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18 12:03:17", Option(c.getTimeInMillis * 1000)) + checkStringToTimestamp("2015-03-18T12:03:17", Option(c.getTimeInMillis * 1000)) + + // If the string value includes a timezone string, it represents the timestamp in + // that timezone, regardless of the tz parameter. + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18T12:03:17-13:53", Option(c.getTimeInMillis * 1000)) + + c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18T12:03:17Z", Option(c.getTimeInMillis * 1000)) + checkStringToTimestamp("2015-03-18 12:03:17Z", Option(c.getTimeInMillis * 1000)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18T12:03:17-1:0", Option(c.getTimeInMillis * 1000)) + checkStringToTimestamp("2015-03-18T12:03:17-01:00", Option(c.getTimeInMillis * 1000)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18T12:03:17+07:30", Option(c.getTimeInMillis * 1000)) + + c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 0) + checkStringToTimestamp("2015-03-18T12:03:17+07:03", Option(c.getTimeInMillis * 1000)) + + // Tests for strings including milliseconds. + c = Calendar.getInstance(tz) + c.set(2015, 2, 18, 12, 3, 17) + c.set(Calendar.MILLISECOND, 123) + checkStringToTimestamp("2015-03-18 12:03:17.123", Option(c.getTimeInMillis * 1000)) + checkStringToTimestamp("2015-03-18T12:03:17.123", Option(c.getTimeInMillis * 1000)) + + // If the string value includes a timezone string, it represents the timestamp in + // that timezone, regardless of the tz parameter. 
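+ // For example, "2015-03-18T12:03:17.456Z" below is 12:03:17.456 UTC regardless of the tz passed to stringToTimestamp.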
+      c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 456)
+      checkStringToTimestamp("2015-03-18T12:03:17.456Z", Option(c.getTimeInMillis * 1000))
+      checkStringToTimestamp("2015-03-18 12:03:17.456Z", Option(c.getTimeInMillis * 1000))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp("2015-03-18T12:03:17.123-1:0", Option(c.getTimeInMillis * 1000))
+      checkStringToTimestamp("2015-03-18T12:03:17.123-01:00", Option(c.getTimeInMillis * 1000))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp("2015-03-18T12:03:17.123+07:30", Option(c.getTimeInMillis * 1000))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp("2015-03-18T12:03:17.123+7:30", Option(c.getTimeInMillis * 1000))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp(
+        "2015-03-18T12:03:17.123121+7:30", Option(c.getTimeInMillis * 1000 + 121))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp(
+        "2015-03-18T12:03:17.12312+7:30", Option(c.getTimeInMillis * 1000 + 120))
+
+      c = Calendar.getInstance(tz)
+      c.set(Calendar.HOUR_OF_DAY, 18)
+      c.set(Calendar.MINUTE, 12)
+      c.set(Calendar.SECOND, 15)
+      c.set(Calendar.MILLISECOND, 0)
+      checkStringToTimestamp("18:12:15", Option(c.getTimeInMillis * 1000))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+      c.set(Calendar.HOUR_OF_DAY, 18)
+      c.set(Calendar.MINUTE, 12)
+      c.set(Calendar.SECOND, 15)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp("T18:12:15.12312+7:30", Option(c.getTimeInMillis * 1000 + 120))
+
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
+      c.set(Calendar.HOUR_OF_DAY, 18)
+      c.set(Calendar.MINUTE, 12)
+      c.set(Calendar.SECOND, 15)
+      c.set(Calendar.MILLISECOND, 123)
+      checkStringToTimestamp("18:12:15.12312+7:30", Option(c.getTimeInMillis * 1000 + 120))
+
+      c = Calendar.getInstance(tz)
+      c.set(2011, 4, 6, 7, 8, 9)
+      c.set(Calendar.MILLISECOND, 100)
+      checkStringToTimestamp("2011-05-06 07:08:09.1000", Option(c.getTimeInMillis * 1000))
+
+      checkStringToTimestamp("238", None)
+      checkStringToTimestamp("00238", None)
+      checkStringToTimestamp("2015-03-18 123142", None)
+      checkStringToTimestamp("2015-03-18T123123", None)
+      checkStringToTimestamp("2015-03-18X", None)
+      checkStringToTimestamp("2015/03/18", None)
+      checkStringToTimestamp("2015.03.18", None)
+      checkStringToTimestamp("20150318", None)
+      checkStringToTimestamp("2015-031-8", None)
+      checkStringToTimestamp("02015-01-18", None)
+      checkStringToTimestamp("015-01-18", None)
+      checkStringToTimestamp("2015-03-18T12:03.17-20:0", None)
+      checkStringToTimestamp("2015-03-18T12:03.17-0:70", None)
+      checkStringToTimestamp("2015-03-18T12:03.17-1:0:0", None)
+
+      // Truncating the fractional seconds
+      c = Calendar.getInstance(TimeZone.getTimeZone("GMT+00:00"))
+      c.set(2015, 2, 18, 12, 3, 17)
+      c.set(Calendar.MILLISECOND, 0)
+      checkStringToTimestamp(
+        "2015-03-18T12:03:17.123456789+0:00", Option(c.getTimeInMillis * 1000 + 123456))
+    }
   }

   test("SPARK-15379: special invalid date string") {
@@ -373,27 +350,35 @@ class DateTimeUtilsSuite extends SparkFunSuite {
   }

   test("hours") {
-    val c = Calendar.getInstance()
+    val c = Calendar.getInstance(TimeZonePST)
     c.set(2015, 2, 18, 13, 2, 11)
-    assert(getHours(c.getTimeInMillis * 1000) === 13)
+    assert(getHours(c.getTimeInMillis * 1000, TimeZonePST) === 13)
+    assert(getHours(c.getTimeInMillis * 1000, TimeZoneGMT) === 20)
     c.set(2015, 12, 8, 2, 7, 9)
-    assert(getHours(c.getTimeInMillis * 1000) === 2)
+    assert(getHours(c.getTimeInMillis * 1000, TimeZonePST) === 2)
+    assert(getHours(c.getTimeInMillis * 1000, TimeZoneGMT) === 10)
   }

   test("minutes") {
-    val c = Calendar.getInstance()
+    val c = Calendar.getInstance(TimeZonePST)
     c.set(2015, 2, 18, 13, 2, 11)
-    assert(getMinutes(c.getTimeInMillis * 1000) === 2)
+    assert(getMinutes(c.getTimeInMillis * 1000, TimeZonePST) === 2)
+    assert(getMinutes(c.getTimeInMillis * 1000, TimeZoneGMT) === 2)
+    assert(getMinutes(c.getTimeInMillis * 1000, TimeZone.getTimeZone("Australia/North")) === 32)
     c.set(2015, 2, 8, 2, 7, 9)
-    assert(getMinutes(c.getTimeInMillis * 1000) === 7)
+    assert(getMinutes(c.getTimeInMillis * 1000, TimeZonePST) === 7)
+    assert(getMinutes(c.getTimeInMillis * 1000, TimeZoneGMT) === 7)
+    assert(getMinutes(c.getTimeInMillis * 1000, TimeZone.getTimeZone("Australia/North")) === 37)
   }

   test("seconds") {
-    val c = Calendar.getInstance()
+    val c = Calendar.getInstance(TimeZonePST)
     c.set(2015, 2, 18, 13, 2, 11)
-    assert(getSeconds(c.getTimeInMillis * 1000) === 11)
+    assert(getSeconds(c.getTimeInMillis * 1000, TimeZonePST) === 11)
+    assert(getSeconds(c.getTimeInMillis * 1000, TimeZoneGMT) === 11)
     c.set(2015, 2, 8, 2, 7, 9)
-    assert(getSeconds(c.getTimeInMillis * 1000) === 9)
+    assert(getSeconds(c.getTimeInMillis * 1000, TimeZonePST) === 9)
+    assert(getSeconds(c.getTimeInMillis * 1000, TimeZoneGMT) === 9)
   }

   test("hours / minutes / seconds") {
@@ -467,6 +452,21 @@ class DateTimeUtilsSuite extends SparkFunSuite {
     c2.set(Calendar.MILLISECOND, 123)
     val ts2 = c2.getTimeInMillis * 1000L
     assert(timestampAddInterval(ts1, 36, 123000) === ts2)
+
+    val c3 = Calendar.getInstance(TimeZonePST)
+    c3.set(1997, 1, 27, 16, 0, 0)
+    c3.set(Calendar.MILLISECOND, 0)
+    val ts3 = c3.getTimeInMillis * 1000L
+    val c4 = Calendar.getInstance(TimeZonePST)
+    c4.set(2000, 1, 27, 16, 0, 0)
+    c4.set(Calendar.MILLISECOND, 123)
+    val ts4 = c4.getTimeInMillis * 1000L
+    val c5 = Calendar.getInstance(TimeZoneGMT)
+    c5.set(2000, 1, 29, 0, 0, 0)
+    c5.set(Calendar.MILLISECOND, 123)
+    val ts5 = c5.getTimeInMillis * 1000L
+    assert(timestampAddInterval(ts3, 36, 123000, TimeZonePST) === ts4)
+    assert(timestampAddInterval(ts3, 36, 123000, TimeZoneGMT) === ts5)
   }

   test("monthsBetween") {
@@ -481,6 +481,17 @@ class DateTimeUtilsSuite extends SparkFunSuite {
     assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36)
     c2.set(1996, 2, 31, 0, 0, 0)
     assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11)
+
+    val c3 = Calendar.getInstance(TimeZonePST)
+    c3.set(2000, 1, 28, 16, 0, 0)
+    val c4 = Calendar.getInstance(TimeZonePST)
+    c4.set(1997, 1, 28, 16, 0, 0)
+    assert(
+      monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZonePST)
+      === 36.0)
+    assert(
+      monthsBetween(c3.getTimeInMillis * 1000L, c4.getTimeInMillis * 1000L, TimeZoneGMT)
+      === 35.90322581)
   }

   test("from UTC timestamp") {
@@ -537,6 +548,21 @@ class DateTimeUtilsSuite extends SparkFunSuite {
   }

   test("daysToMillis and millisToDays") {
+    val c = Calendar.getInstance(TimeZonePST)
+
+    c.set(2015, 11, 31, 16, 0, 0)
+    assert(millisToDays(c.getTimeInMillis, TimeZonePST) === 16800)
+    assert(millisToDays(c.getTimeInMillis, TimeZoneGMT) === 16801)
+
+    c.set(2015, 11, 31, 0, 0, 0)
+    c.set(Calendar.MILLISECOND, 0)
+    assert(daysToMillis(16800, TimeZonePST) === c.getTimeInMillis)
+
+    c.setTimeZone(TimeZoneGMT)
+    c.set(2015, 11, 31, 0, 0, 0)
+    c.set(Calendar.MILLISECOND, 0)
+    assert(daysToMillis(16800, TimeZoneGMT) === c.getTimeInMillis)
+
     // Some days are skipped entirely in some timezones; skip them here.
     val skipped_days = Map[String, Int](
       "Kwajalein" -> 8632,
       "Pacific/Enderbury" -> 8632,
       "Pacific/Kiritimati" -> 9130,
@@ -547,13 +573,11 @@ class DateTimeUtilsSuite extends SparkFunSuite {
       "Pacific/Kwajalein" -> 8632,
       "MIT" -> 15338)
     for (tz <- DateTimeTestUtils.ALL_TIMEZONES) {
-      DateTimeTestUtils.withDefaultTimeZone(tz) {
-        val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue)
-        (-20000 to 20000).foreach { d =>
-          if (d != skipped) {
-            assert(millisToDays(daysToMillis(d)) === d,
-              s"Round trip of ${d} did not work in tz ${tz}")
-          }
+      val skipped = skipped_days.getOrElse(tz.getID, Int.MinValue)
+      (-20000 to 20000).foreach { d =>
+        if (d != skipped) {
+          assert(millisToDays(daysToMillis(d, tz), tz) === d,
+            s"Round trip of ${d} did not work in tz ${tz}")
         }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 60182befd7586aabb11fe6a0cc3780d7246d03b8..38029552d13bdc045f1ba46869521f636e4d9ea5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -174,7 +174,7 @@ class Column(val expr: Expression) extends Logging {
     // NamedExpression under this Cast.
     case c: Cast =>
       c.transformUp {
-        case Cast(ne: NamedExpression, to) => UnresolvedAlias(Cast(ne, to))
+        case c @ Cast(_: NamedExpression, _, _) => UnresolvedAlias(c)
       } match {
         case ne: NamedExpression => ne
         case other => Alias(expr, usePrettyExpression(expr).sql)()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5ee173f72e66dc3f5ef03d6129ffaaae017d5bc8..391c34f1285ed84e7a9227f3d6e68a2e4591532f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql

 import java.io.CharArrayWriter
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone

 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
@@ -43,7 +45,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
-import org.apache.spark.sql.catalyst.util.usePrettyExpression
+import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateViewCommand, ExplainCommand, GlobalTempView, LocalTempView}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -250,6 +252,8 @@ class Dataset[T] private[sql](
     val hasMoreData = takeResult.length > numRows
     val data = takeResult.take(numRows)

+    lazy val timeZone = TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone)
+
     // For array values, replace Seq and Array with square brackets
     // For cells that are beyond `truncate` characters, replace it with the
     // first `truncate-3` and "..."
@@ -260,6 +264,10 @@ class Dataset[T] private[sql](
       case binary: Array[Byte] => binary.map("%02X".format(_)).mkString("[", " ", "]")
       case array: Array[_] => array.mkString("[", ", ", "]")
       case seq: Seq[_] => seq.mkString("[", ", ", "]")
+      case d: Date =>
+        DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+      case ts: Timestamp =>
+        DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(ts), timeZone)
       case _ => cell.toString
     }
     if (truncate > 0 && str.length > truncate) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 0384c0f2360285a23348ddab0ea66037aa5ab474..d5a8566d078f9cfcce04acd15398f9554c976b8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -369,7 +369,7 @@ class SQLBuilder private (
       case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
       case a @ Cast(BitwiseAnd(
           ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
-          Literal(1, IntegerType)), ByteType) if ar == gid =>
+          Literal(1, IntegerType)), ByteType, _) if ar == gid =>
         // for converting an expression to its original SQL format grouping(col)
         val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
         groupByExprs.lift(idx).map(Grouping).getOrElse(a)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 1b7fedca8484c1a3f1bb59fc6764ec5b11d744a8..b8ac070e3a95907cce622a20846e0225f8af0642 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.internal.SQLConf
@@ -104,7 +105,9 @@ case class OptimizeMetadataOnlyQuery(
       val partAttrs = getPartitionAttrs(relation.catalogTable.partitionColumnNames, relation)
       val partitionData = catalog.listPartitions(relation.catalogTable.identifier).map { p =>
         InternalRow.fromSeq(partAttrs.map { attr =>
-          Cast(Literal(p.spec(attr.name)), attr.dataType).eval()
+          // TODO: use correct timezone for partition values.
+          Cast(Literal(p.spec(attr.name)), attr.dataType,
+            Option(DateTimeUtils.defaultTimeZone().getID)).eval()
         })
       }
       LocalRelation(partAttrs, partitionData)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index dcd9003ec66f5619aa056753fc984eeba29862c1..9d046c0766aa5bff3c9968a32648fdedba4fe280 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -18,7 +18,8 @@ package org.apache.spark.sql.execution

 import java.nio.charset.StandardCharsets
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
+import java.util.TimeZone

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
@@ -139,22 +140,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType,
       BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType)

-    /** Implementation following Hive's TimestampWritable.toString */
-    def formatTimestamp(timestamp: Timestamp): String = {
-      val timestampString = timestamp.toString
-      if (timestampString.length() > 19) {
-        if (timestampString.length() == 21) {
-          if (timestampString.substring(19).compareTo(".0") == 0) {
-            return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp)
-          }
-        }
-        return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp) +
-          timestampString.substring(19)
-      }
-
-      return DateTimeUtils.threadLocalTimestampFormat.get().format(timestamp)
-    }
-
     def formatDecimal(d: java.math.BigDecimal): String = {
       if (d.compareTo(java.math.BigDecimal.ZERO) == 0) {
         java.math.BigDecimal.ZERO.toPlainString
@@ -195,8 +180,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
           toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType))
         }.toSeq.sorted.mkString("{", ",", "}")
       case (null, _) => "NULL"
-      case (d: Int, DateType) => new java.util.Date(DateTimeUtils.daysToMillis(d)).toString
-      case (t: Timestamp, TimestampType) => formatTimestamp(t)
+      case (d: Date, DateType) =>
+        DateTimeUtils.dateToString(DateTimeUtils.fromJavaDate(d))
+      case (t: Timestamp, TimestampType) =>
+        DateTimeUtils.timestampToString(DateTimeUtils.fromJavaTimestamp(t),
+          TimeZone.getTimeZone(sparkSession.sessionState.conf.sessionLocalTimeZone))
       case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8)
       case (decimal: java.math.BigDecimal, DecimalType()) => formatDecimal(decimal)
       case (other, tpe) if primitiveTypes.contains(tpe) => other.toString
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 16c5193eda8dff1b5e4612ae5377e9d015cd3224..be13cbc51a9d3f2f636e9e48d57f414aa3043b6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{QueryExecution, SQLExecution, UnsafeKVExternalSorter}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -311,10 +312,11 @@ object FileFormatWriter extends Logging {
     /** Expressions that given a partition key build a string like: col1=val/col2=val/... */
     private def partitionStringExpression: Seq[Expression] = {
       description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
+        // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
           StringType,
-          Seq(Cast(c, StringType)),
+          Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))),
           Seq(StringType))
         val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
         val partitionName = Literal(c.name + "=") :: str :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index fe9c6578b1e010e5d9b1782ce0217592f4044efc..75f87a5503b8c38236500d4202ab0ffe9f5a9f8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -30,6 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -135,9 +136,11 @@ abstract class PartitioningAwareFileIndex(
         // we need to cast into the data type that user specified.
         def castPartitionValuesToUserSchema(row: InternalRow) = {
           InternalRow((0 until row.numFields).map { i =>
+            // TODO: use correct timezone for partition values.
             Cast(
               Literal.create(row.getUTF8String(i), StringType),
-              userProvidedSchema.fields(i).dataType).eval()
+              userProvidedSchema.fields(i).dataType,
+              Option(DateTimeUtils.defaultTimeZone().getID)).eval()
           }: _*)
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6ab6fa61dc200d7bb51054e44fb325b9fbf9c6e5..bd7cec391796456178f9861fb52996306afaa009 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -58,7 +58,7 @@ class IncrementalExecution(
    */
   override lazy val optimizedPlan: LogicalPlan = {
     sparkSession.sessionState.optimizer.execute(withCachedData) transformAllExpressions {
-      case ts @ CurrentBatchTimestamp(timestamp, _) =>
+      case ts @ CurrentBatchTimestamp(timestamp, _, _) =>
         logInfo(s"Current batch timestamp = $timestamp")
         ts.toLiteral
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index a35950e2dc17f3fe5e92fb2bcbf91e4a59688664..ea3719421b8a01818a7701a72fe378b6a9f15029 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -502,7 +502,7 @@ class StreamExecution(
             ct.dataType)
         case cd: CurrentDate =>
           CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
-            cd.dataType)
+            cd.dataType, cd.timeZoneId)
       }

     reportTimeTaken("queryPlanning") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d0c86ffc27d07e3d153f448e01eac3fdbbbd2d54..5ba4192512a598da8c131cbd91c45ff59be69a82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.internal

-import java.util.{NoSuchElementException, Properties}
+import java.util.{NoSuchElementException, Properties, TimeZone}
 import java.util.concurrent.TimeUnit

 import scala.collection.JavaConverters._
@@ -660,6 +660,12 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)

+  val SESSION_LOCAL_TIMEZONE =
+    SQLConfigBuilder("spark.sql.session.timeZone")
+      .doc("""The ID of the session local timezone, e.g. "GMT", "America/Los_Angeles", etc.""")
+      .stringConf
+      .createWithDefault(TimeZone.getDefault().getID())
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -858,6 +864,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

   override def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)

+  override def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
+
   def ndvMaxError: Double = getConf(NDV_MAX_ERROR)

   override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cb7b97906a7d7ffea7931eba93f9c799cb7b51c7..6a190b98ea983fda4b4f71bb61d8a232b2a841bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Project, Union}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{FilterExec, QueryExecution}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
@@ -869,6 +870,30 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(testData.select($"*").filter($"key" < 0).showString(1) === expectedAnswer)
   }

+  test("SPARK-18350 show with session local timezone") {
+    val d = Date.valueOf("2016-12-01")
+    val ts = Timestamp.valueOf("2016-12-01 00:00:00")
+    val df = Seq((d, ts)).toDF("d", "ts")
+    val expectedAnswer = """+----------+-------------------+
+                           ||d         |ts                 |
+                           |+----------+-------------------+
+                           ||2016-12-01|2016-12-01 00:00:00|
+                           |+----------+-------------------+
+                           |""".stripMargin
+    assert(df.showString(1, truncate = 0) === expectedAnswer)
+
+    withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") {
+
+      val expectedAnswer = """+----------+-------------------+
+                             ||d         |ts                 |
+                             |+----------+-------------------+
+                             ||2016-12-01|2016-12-01 08:00:00|
+                             |+----------+-------------------+
+                             |""".stripMargin
+      assert(df.showString(1, truncate = 0) === expectedAnswer)
+    }
+  }
+
   test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
     val rowRDD = sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
     val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index d217e9b4feb6d2ba84f6f67e8e0a46c5968d1f16..f78660f7c14b6fba997b2ba94157e026bff07d5f 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -41,6 +41,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
   private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
   private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
   private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
+  private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone

   def testCases: Seq[(String, File)] = {
     hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
   }
@@ -62,6 +63,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
     // Ensures that cross joins are enabled so that we can test them
     TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
+    // Fix the session local timezone to America/Los_Angeles for the timezone-sensitive
+    // tests (timestamp_*).
+    TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles")
     RuleExecutor.resetTime()
   }
@@ -74,6 +78,7 @@
     TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
     TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
     TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
+    TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone)
     // For debugging dump some statistics about how much time was spent in various optimizer rules
     logWarning(RuleExecutor.dumpTimeSpent())
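
The user-visible effect of the new `spark.sql.session.timeZone` configuration is that timestamp rendering (e.g. in `Dataset.show`) follows the session local timezone instead of the JVM default. A minimal sketch of that behavior, mirroring the DataFrameSuite test above; the `SessionTimeZoneDemo` object and the `local[1]` master are illustrative assumptions, not part of this patch, and the output comments assume the JVM default timezone is America/Los_Angeles:

import java.sql.Timestamp

import org.apache.spark.sql.SparkSession

// Hypothetical driver program; only spark.sql.session.timeZone comes from this patch.
object SessionTimeZoneDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("session-timezone-demo")
      .getOrCreate()
    import spark.implicits._

    // The Timestamp is parsed using the JVM default timezone's wall clock.
    val df = Seq(Timestamp.valueOf("2016-12-01 00:00:00")).toDF("ts")

    // Rendering follows the session local timezone.
    spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
    df.show(false) // 2016-12-01 00:00:00

    spark.conf.set("spark.sql.session.timeZone", "GMT")
    df.show(false) // 2016-12-01 08:00:00: the same instant, shifted for display

    spark.stop()
  }
}

Only the display timezone changes between the two show calls; the underlying microseconds-since-epoch value stored in the column is unchanged.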