diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7e5ebfc93286f5ca945e6b7e00d9ed972d4d484f..434b6ffee37fac4112262c22074e07401bea6002 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2301,6 +2301,7 @@ object EliminateEventTimeWatermark extends Rule[LogicalPlan] {
 object TimeWindowing extends Rule[LogicalPlan] {
   import org.apache.spark.sql.catalyst.dsl.expressions._
 
+  private final val WINDOW_COL_NAME = "window"
   private final val WINDOW_START = "start"
   private final val WINDOW_END = "end"
 
@@ -2336,49 +2337,76 @@ object TimeWindowing extends Rule[LogicalPlan] {
     case p: LogicalPlan if p.children.size == 1 =>
       val child = p.children.head
       val windowExpressions =
-        p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct.
+        p.expressions.flatMap(_.collect { case t: TimeWindow => t }).toSet
 
       // Only support a single window expression for now
       if (windowExpressions.size == 1 &&
           windowExpressions.head.timeColumn.resolved &&
           windowExpressions.head.checkInputDataTypes().isSuccess) {
+
         val window = windowExpressions.head
+
         val metadata = window.timeColumn match {
           case a: Attribute => a.metadata
           case _ => Metadata.empty
         }
-        val windowAttr =
-          AttributeReference("window", window.dataType, metadata = metadata)()
-
-        val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
-        val windows = Seq.tabulate(maxNumOverlapping + 1) { i =>
-          val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) /
-            window.slideDuration)
-          val windowStart = (windowId + i - maxNumOverlapping) *
-            window.slideDuration + window.startTime
+
+        def getWindow(i: Int, overlappingWindows: Int): Expression = {
+          val division = (PreciseTimestampConversion(
+            window.timeColumn, TimestampType, LongType) - window.startTime) / window.slideDuration
+          val ceil = Ceil(division)
+          // if the division is equal to the ceiling, our record is the start of a window
+          val windowId = CaseWhen(Seq((ceil === division, ceil + 1)), Some(ceil))
+          val windowStart = (windowId + i - overlappingWindows) *
+            window.slideDuration + window.startTime
           val windowEnd = windowStart + window.windowDuration
 
           CreateNamedStruct(
-            Literal(WINDOW_START) :: windowStart ::
-            Literal(WINDOW_END) :: windowEnd :: Nil)
+            Literal(WINDOW_START) ::
+              PreciseTimestampConversion(windowStart, LongType, TimestampType) ::
+              Literal(WINDOW_END) ::
+              PreciseTimestampConversion(windowEnd, LongType, TimestampType) ::
+              Nil)
         }
 
-        val projections = windows.map(_ +: p.children.head.output)
+        val windowAttr = AttributeReference(
+          WINDOW_COL_NAME, window.dataType, metadata = metadata)()
+
+        if (window.windowDuration == window.slideDuration) {
+          val windowStruct = Alias(getWindow(0, 1), WINDOW_COL_NAME)(
+            exprId = windowAttr.exprId)
+
+          val replacedPlan = p transformExpressions {
+            case t: TimeWindow => windowAttr
+          }
+
+          // For backwards compatibility we add a filter to filter out nulls
+          val filterExpr = IsNotNull(window.timeColumn)
 
-        val filterExpr =
-          window.timeColumn >= windowAttr.getField(WINDOW_START) &&
-          window.timeColumn < windowAttr.getField(WINDOW_END)
+          replacedPlan.withNewChildren(
+            Filter(filterExpr,
+              Project(windowStruct +: child.output, child)) :: Nil)
+        } else {
+          val overlappingWindows =
+            math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+          val windows =
+            Seq.tabulate(overlappingWindows)(i => getWindow(i, overlappingWindows))
+
+          val projections = windows.map(_ +: child.output)
+
+          val filterExpr =
+            window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+              window.timeColumn < windowAttr.getField(WINDOW_END)
 
-        val expandedPlan =
-          Filter(filterExpr,
+          val substitutedPlan = Filter(filterExpr,
             Expand(projections, windowAttr +: child.output, child))
 
-        val substitutedPlan = p transformExpressions {
-          case t: TimeWindow => windowAttr
-        }
+          val renamedPlan = p transformExpressions {
+            case t: TimeWindow => windowAttr
+          }
 
-        substitutedPlan.withNewChildren(expandedPlan :: Nil)
+          renamedPlan.withNewChildren(substitutedPlan :: Nil)
+        }
       } else if (windowExpressions.size > 1) {
         p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
           "of rows, therefore they are currently not supported.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 7ff61ee479452af9d56d4f0a6ba07e6167b5b717..9a9f579b37f58afd1edbcbb4b71be1a54cbaee2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -152,12 +152,15 @@ object TimeWindow {
 }
 
 /**
- * Expression used internally to convert the TimestampType to Long without losing
+ * Expression used internally to convert the TimestampType to Long and back without losing
  * precision, i.e. in microseconds. Used in time windowing.
  */
-case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
-  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
-  override def dataType: DataType = LongType
+case class PreciseTimestampConversion(
+    child: Expression,
+    fromType: DataType,
+    toType: DataType) extends UnaryExpression with ExpectsInputTypes {
+  override def inputTypes: Seq[AbstractDataType] = Seq(fromType)
+  override def dataType: DataType = toType
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
     ev.copy(code = eval.code +
@@ -165,4 +168,5 @@ case class PreciseTimestamp(child: Expression) extends UnaryExpression with Expe
       |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
     """.stripMargin)
   }
+  override def nullSafeEval(input: Any): Any = input
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
index 22d5c47a6fb51cb4e6c1ab569f22a319291d6d43..6fe356877c268e1d8b32d7ea21c2c91f9598f9ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala
@@ -17,10 +17,9 @@
 
 package org.apache.spark.sql
 
-import java.util.TimeZone
-
 import org.scalatest.BeforeAndAfterEach
 
+import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StringType
@@ -29,11 +28,27 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
 
   import testImplicits._
 
+  test("simple tumbling window with record at window start") {
+    val df = Seq(
+      ("2016-03-27 19:39:30", 1, "a")).toDF("time", "value", "id")
+
+    checkAnswer(
+      df.groupBy(window($"time", "10 seconds"))
+        .agg(count("*").as("counts"))
+        .orderBy($"window.start".asc)
+        .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"),
+      Seq(
+        Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1)
+      )
+    )
+  }
+
   test("tumbling window groupBy statement") {
     val df = Seq(
       ("2016-03-27 19:39:34", 1, "a"),
       ("2016-03-27 19:39:56", 2, "a"),
       ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+
     checkAnswer(
       df.groupBy(window($"time", "10 seconds"))
         .agg(count("*").as("counts"))
@@ -59,14 +74,18 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
 
   test("tumbling window with multi-column projection") {
     val df = Seq(
-      ("2016-03-27 19:39:34", 1, "a"),
-      ("2016-03-27 19:39:56", 2, "a"),
-      ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+        ("2016-03-27 19:39:34", 1, "a"),
+        ("2016-03-27 19:39:56", 2, "a"),
+        ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+      .select(window($"time", "10 seconds"), $"value")
+      .orderBy($"window.start".asc)
+      .select($"window.start".cast("string"), $"window.end".cast("string"), $"value")
+
+    val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
+    assert(expands.isEmpty, "Tumbling windows shouldn't require expand")
 
     checkAnswer(
-      df.select(window($"time", "10 seconds"), $"value")
-        .orderBy($"window.start".asc)
-        .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"),
+      df,
       Seq(
         Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4),
         Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1),
@@ -104,13 +123,17 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B
 
   test("sliding window projection") {
     val df = Seq(
-      ("2016-03-27 19:39:34", 1, "a"),
-      ("2016-03-27 19:39:56", 2, "a"),
-      ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+        ("2016-03-27 19:39:34", 1, "a"),
+        ("2016-03-27 19:39:56", 2, "a"),
+        ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id")
+      .select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value")
+      .orderBy($"window.start".asc, $"value".desc).select("value")
+
+    val expands = df.queryExecution.optimizedPlan.find(_.isInstanceOf[Expand])
+    assert(expands.nonEmpty, "Sliding windows require expand")
 
     checkAnswer(
-      df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value")
-        .orderBy($"window.start".asc, $"value".desc).select("value"),
+      df,
       // 2016-03-27 19:39:27 UTC -> 4 bins
       // 2016-03-27 19:39:34 UTC -> 3 bins
       // 2016-03-27 19:39:56 UTC -> 3 bins