diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java index cf112f2e02a9555deebac4c5ea50c86364f42354..e2e7ab1d2609f3ff82eca9ca5527bebfd6efaa24 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/KeyedStateTimeout.java @@ -19,9 +19,7 @@ package org.apache.spark.sql.streaming; import org.apache.spark.annotation.Experimental; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.plans.logical.NoTimeout$; -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout; -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; +import org.apache.spark.sql.catalyst.plans.logical.*; /** * Represents the type of timeouts possible for the Dataset operations @@ -34,9 +32,23 @@ import org.apache.spark.sql.catalyst.plans.logical.ProcessingTimeTimeout$; @InterfaceStability.Evolving public class KeyedStateTimeout { - /** Timeout based on processing time. */ + /** + * Timeout based on processing time. The duration of timeout can be set for each group in + * `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutDuration()`. See documentation + * on `KeyedState` for more details. + */ public static KeyedStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } - /** No timeout */ + /** + * Timeout based on event-time. The event-time timestamp for timeout can be set for each + * group in `map/flatMapGroupsWithState` by calling `KeyedState.setTimeoutTimestamp()`. + * In addition, you have to define the watermark in the query using `Dataset.withWatermark`. + * When the watermark advances beyond the set timestamp of a group and the group has not + * received any data, then the group times out. See documentation on + * `KeyedState` for more details. + */ + public static KeyedStateTimeout EventTimeTimeout() { return EventTimeTimeout$.MODULE$; } + + /** No timeout. */ public static KeyedStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index a9ff61e0e88023c00182a74d279958956e063ad0..7da7f55aa5d7f6d42408c3b9becaf66a2125f1e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -147,49 +147,69 @@ object UnsupportedOperationChecker { throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + "streaming DataFrames/Datasets") - // mapGroupsWithState: Allowed only when no aggregation + Update output mode - case m: FlatMapGroupsWithState if m.isStreaming && m.isMapGroupsWithState => - if (collectStreamingAggregates(plan).isEmpty) { - if (outputMode != InternalOutputModes.Update) { - throwError("mapGroupsWithState is not supported with " + - s"$outputMode output mode on a streaming DataFrame/Dataset") - } else { - // Allowed when no aggregation + Update output mode - } - } else { - throwError("mapGroupsWithState is not supported with aggregation " + - "on a streaming DataFrame/Dataset") - } - - // flatMapGroupsWithState without aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && collectStreamingAggregates(plan).isEmpty => - m.outputMode match { - case InternalOutputModes.Update => - if (outputMode != InternalOutputModes.Update) { - throwError("flatMapGroupsWithState in update mode is not supported with " + + // mapGroupsWithState and flatMapGroupsWithState + case m: FlatMapGroupsWithState if m.isStreaming => + + // Check compatibility with output modes and aggregations in query + val aggsAfterFlatMapGroups = collectStreamingAggregates(plan) + + if (m.isMapGroupsWithState) { // check mapGroupsWithState + // allowed only in update query output mode and without aggregation + if (aggsAfterFlatMapGroups.nonEmpty) { + throwError( + "mapGroupsWithState is not supported with aggregation " + + "on a streaming DataFrame/Dataset") + } else if (outputMode != InternalOutputModes.Update) { + throwError( + "mapGroupsWithState is not supported with " + s"$outputMode output mode on a streaming DataFrame/Dataset") + } + } else { // check latMapGroupsWithState + if (aggsAfterFlatMapGroups.isEmpty) { + // flatMapGroupsWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "flatMapGroupsWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => } - case InternalOutputModes.Append => - if (outputMode != InternalOutputModes.Append) { - throwError("flatMapGroupsWithState in append mode is not supported with " + - s"$outputMode output mode on a streaming DataFrame/Dataset") + } else { + // flatMapGroupsWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "flatMapGroupsWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "flatMapGroupsWithState in append mode is not supported after " + + s"aggregation on a streaming DataFrame/Dataset") } + } } - // flatMapGroupsWithState(Update) with aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && m.outputMode == InternalOutputModes.Update - && collectStreamingAggregates(plan).nonEmpty => - throwError("flatMapGroupsWithState in update mode is not supported with " + - "aggregation on a streaming DataFrame/Dataset") - - // flatMapGroupsWithState(Append) with aggregation - case m: FlatMapGroupsWithState - if m.isStreaming && m.outputMode == InternalOutputModes.Append - && collectStreamingAggregates(m).nonEmpty => - throwError("flatMapGroupsWithState in append mode is not supported after " + - s"aggregation on a streaming DataFrame/Dataset") + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. + val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "[map|flatMap]GroupsWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index d1f95faf2db0cdf42e8d5f2573e497d6aa33f6f9..e0ecf8c5f2643fd70aa9cfaeebc5f969404558ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -353,9 +353,10 @@ case class MapGroups( /** Internal class representing State */ trait LogicalKeyedState[S] -/** Possible types of timeouts used in FlatMapGroupsWithState */ +/** Types of timeouts used in FlatMapGroupsWithState */ case object NoTimeout extends KeyedStateTimeout case object ProcessingTimeTimeout extends KeyedStateTimeout +case object EventTimeTimeout extends KeyedStateTimeout /** Factory for constructing new `MapGroupsWithState` nodes. */ object FlatMapGroupsWithState { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 08216e26604004a425d475d95bf91677f46acaed..8f0a0c0d99d15e5f882ce21d6cfba0efacd500f0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -345,6 +345,22 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Append, expectedMsgs = Seq("Mixing mapGroupsWithStates and flatMapGroupsWithStates")) + // mapGroupsWithState with event time timeout + watermark + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout without watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, streamRelation), + outputMode = Update, + expectedMsgs = Seq("watermark")) + + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState with event time timeout with watermark", + FlatMapGroupsWithState( + null, att, att, Seq(att), Seq(att), att, null, Update, isMapGroupsWithState = true, + EventTimeTimeout, new TestStreamingRelation(attributeWithWatermark)), + outputMode = Update) + // Deduplicate assertSupportedInStreamingPlan( "Deduplicate - Deduplicate on streaming relation before aggregation", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 9e58e8ce3d5f8741ad53f163d4adf3048170a5f3..ca2f6dd7a84b28ac6df29b413bf9ecc05058ff32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -336,8 +336,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { timeout, child) => val execPlan = FlatMapGroupsWithStateExec( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode, - timeout, batchTimestampMs = KeyedStateImpl.NO_BATCH_PROCESSING_TIMESTAMP, - planLater(child)) + timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child)) execPlan :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 991d8ef70756748920b0b993942c8c36e3273190..52ad70c7dc886afcf03734b09ae7bdc588ec5346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalKeyedState, ProcessingTimeTimeout} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, Literal, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.streaming.KeyedStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{KeyedStateTimeout, OutputMode} -import org.apache.spark.sql.types.{BooleanType, IntegerType} +import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -39,7 +40,7 @@ import org.apache.spark.util.CompletionIterator * @param outputObjAttr used to define the output object * @param stateEncoder used to serialize/deserialize state before calling `func` * @param outputMode the output mode of `func` - * @param timeout used to timeout groups that have not received data in a while + * @param timeoutConf used to timeout groups that have not received data in a while * @param batchTimestampMs processing timestamp of the current batch. */ case class FlatMapGroupsWithStateExec( @@ -52,11 +53,15 @@ case class FlatMapGroupsWithStateExec( stateId: Option[OperatorStateId], stateEncoder: ExpressionEncoder[Any], outputMode: OutputMode, - timeout: KeyedStateTimeout, - batchTimestampMs: Long, - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { + timeoutConf: KeyedStateTimeout, + batchTimestampMs: Option[Long], + override val eventTimeWatermark: Option[Long], + child: SparkPlan + ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - private val isTimeoutEnabled = timeout == ProcessingTimeTimeout + import KeyedStateImpl._ + + private val isTimeoutEnabled = timeoutConf != NoTimeout private val timestampTimeoutAttribute = AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() private val stateAttributes: Seq[Attribute] = { @@ -64,8 +69,6 @@ case class FlatMapGroupsWithStateExec( if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs } - import KeyedStateImpl._ - /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -74,9 +77,21 @@ case class FlatMapGroupsWithStateExec( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) + override def keyExpressions: Seq[Attribute] = groupingAttributes + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + // Throw errors early if parameters are not as expected + timeoutConf match { + case ProcessingTimeTimeout => + require(batchTimestampMs.nonEmpty) + case EventTimeTimeout => + require(eventTimeWatermark.nonEmpty) // watermark value has been populated + require(watermarkExpression.nonEmpty) // input schema has watermark attribute + case _ => + } + child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, getStateId.operatorId, @@ -84,15 +99,23 @@ case class FlatMapGroupsWithStateExec( groupingAttributes.toStructType, stateAttributes.toStructType, sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iterator) => + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val updater = new StateStoreUpdater(store) + // If timeout is based on event time, then filter late data based on watermark + val filteredIter = watermarkPredicateForData match { + case Some(predicate) if timeoutConf == EventTimeTimeout => + iter.filter(row => !predicate.eval(row)) + case None => + iter + } + // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(iterator) ++ updater.updateStateForTimedOutKeys() + updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -124,7 +147,7 @@ case class FlatMapGroupsWithStateExec( private val stateSerializer = { val encoderSerializer = stateEncoder.namedExpressions if (isTimeoutEnabled) { - encoderSerializer :+ Literal(KeyedStateImpl.TIMEOUT_TIMESTAMP_NOT_SET) + encoderSerializer :+ Literal(KeyedStateImpl.NO_TIMESTAMP) } else { encoderSerializer } @@ -157,16 +180,19 @@ case class FlatMapGroupsWithStateExec( /** Find the groups that have timeout set and are timing out right now, and call the function */ def updateStateForTimedOutKeys(): Iterator[InternalRow] = { if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } val timingOutKeys = store.filter { case (_, stateRow) => val timeoutTimestamp = getTimeoutTimestamp(stateRow) - timeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET && timeoutTimestamp < batchTimestampMs + timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } timingOutKeys.flatMap { case (keyRow, stateRow) => - callFunctionAndUpdateState( - keyRow, - Iterator.empty, - Some(stateRow), - hasTimedOut = true) + callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true) } } else Iterator.empty } @@ -186,7 +212,11 @@ case class FlatMapGroupsWithStateExec( val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects val stateObjOption = getStateObj(prevStateRowOption) val keyedState = new KeyedStateImpl( - stateObjOption, batchTimestampMs, isTimeoutEnabled, hasTimedOut) + stateObjOption, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut) // Call function, get the returned objects and convert them to rows val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => @@ -196,8 +226,6 @@ case class FlatMapGroupsWithStateExec( // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - // Has the timeout information changed - if (keyedState.hasRemoved) { store.remove(keyRow) numUpdatedStateRows += 1 @@ -205,26 +233,25 @@ case class FlatMapGroupsWithStateExec( } else { val previousTimeoutTimestamp = prevStateRowOption match { case Some(row) => getTimeoutTimestamp(row) - case None => TIMEOUT_TIMESTAMP_NOT_SET + case None => NO_TIMESTAMP } - + val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp val stateRowToWrite = if (keyedState.hasUpdated) { getStateRow(keyedState.get) } else { prevStateRowOption.orNull } - val hasTimeoutChanged = keyedState.getTimeoutTimestamp != previousTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged if (shouldWriteState) { if (stateRowToWrite == null) { // This should never happen because checks in KeyedStateImpl should avoid cases // where empty state would need to be written - throw new IllegalStateException( - "Attempting to write empty state") + throw new IllegalStateException("Attempting to write empty state") } - setTimeoutTimestamp(stateRowToWrite, keyedState.getTimeoutTimestamp) + setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) store.put(keyRow.copy(), stateRowToWrite.copy()) numUpdatedStateRows += 1 } @@ -247,7 +274,7 @@ case class FlatMapGroupsWithStateExec( /** Returns the timeout timestamp of a state row is set */ def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else TIMEOUT_TIMESTAMP_NOT_SET + if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP } /** Set the timestamp in a state row */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index a934c75a024572d80a6194dc2ed1467dab00b018..0f0e4a91f8cc74e8fa01e89dc30fa46103db89f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -108,7 +108,10 @@ class IncrementalExecution( case m: FlatMapGroupsWithStateExec => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - m.copy(stateId = Some(stateId), batchTimestampMs = offsetSeqMetadata.batchTimestampMs) + m.copy( + stateId = Some(stateId), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index ac421d395beb44a9236a2c1742c75ad8007f251f..edfd35bd5dd756124f75ea0413de3cefc8349013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -17,37 +17,45 @@ package org.apache.spark.sql.execution.streaming +import java.sql.Date + import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.streaming.KeyedState +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.execution.streaming.KeyedStateImpl._ +import org.apache.spark.sql.streaming.{KeyedState, KeyedStateTimeout} import org.apache.spark.unsafe.types.CalendarInterval + /** * Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. * @param optionalValue Optional value of the state * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp * for processing time timeouts - * @param isTimeoutEnabled Whether timeout is enabled. This will be used to check whether the user - * is allowed to configure timeouts. + * @param timeoutConf Type of timeout configured. Based on this, different operations will + * be supported. * @param hasTimedOut Whether the key for which this state wrapped is being created is * getting timed out or not. */ private[sql] class KeyedStateImpl[S]( optionalValue: Option[S], batchProcessingTimeMs: Long, - isTimeoutEnabled: Boolean, + eventTimeWatermarkMs: Long, + timeoutConf: KeyedStateTimeout, override val hasTimedOut: Boolean) extends KeyedState[S] { - import KeyedStateImpl._ - // Constructor to create dummy state when using mapGroupsWithState in a batch query def this(optionalValue: Option[S]) = this( - optionalValue, -1, isTimeoutEnabled = false, hasTimedOut = false) + optionalValue, + batchProcessingTimeMs = NO_TIMESTAMP, + eventTimeWatermarkMs = NO_TIMESTAMP, + timeoutConf = KeyedStateTimeout.NoTimeout, + hasTimedOut = false) private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed - private var timeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET + private var timeoutTimestamp: Long = NO_TIMESTAMP // ========= Public API ========= override def exists: Boolean = defined @@ -82,13 +90,14 @@ private[sql] class KeyedStateImpl[S]( defined = false updated = false removed = true - timeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET + timeoutTimestamp = NO_TIMESTAMP } override def setTimeoutDuration(durationMs: Long): Unit = { - if (!isTimeoutEnabled) { + if (timeoutConf != ProcessingTimeTimeout) { throw new UnsupportedOperationException( - "Cannot set timeout information without enabling timeout in map/flatMapGroupsWithState") + "Cannot set timeout duration without enabling processing time timeout in " + + "map/flatMapGroupsWithState") } if (!defined) { throw new IllegalStateException( @@ -99,7 +108,7 @@ private[sql] class KeyedStateImpl[S]( if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") } - if (!removed && batchProcessingTimeMs != NO_BATCH_PROCESSING_TIMESTAMP) { + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { timeoutTimestamp = durationMs + batchProcessingTimeMs } else { // This is being called in a batch query, hence no processing timestamp. @@ -108,29 +117,55 @@ private[sql] class KeyedStateImpl[S]( } override def setTimeoutDuration(duration: String): Unit = { - if (StringUtils.isBlank(duration)) { - throw new IllegalArgumentException( - "The window duration, slide duration and start time cannot be null or blank.") - } - val intervalString = if (duration.startsWith("interval")) { - duration - } else { - "interval " + duration + setTimeoutDuration(parseDuration(duration)) + } + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long): Unit = { + checkTimeoutTimestampAllowed() + if (timestampMs <= 0) { + throw new IllegalArgumentException("Timeout timestamp must be positive") } - val cal = CalendarInterval.fromString(intervalString) - if (cal == null) { + if (eventTimeWatermarkMs != NO_TIMESTAMP && timestampMs < eventTimeWatermarkMs) { throw new IllegalArgumentException( - s"The provided duration ($duration) is not valid.") + s"Timeout timestamp ($timestampMs) cannot be earlier than the " + + s"current watermark ($eventTimeWatermarkMs)") } - if (cal.milliseconds < 0 || cal.months < 0) { - throw new IllegalArgumentException("Timeout duration must be positive") + if (!removed && batchProcessingTimeMs != NO_TIMESTAMP) { + timeoutTimestamp = timestampMs + } else { + // This is being called in a batch query, hence no processing timestamp. + // Just ignore any attempts to set timeout. } + } - val delayMs = { - val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 - cal.milliseconds + cal.months * millisPerMonth - } - setTimeoutDuration(delayMs) + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) + } + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime) + } + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { + checkTimeoutTimestampAllowed() + setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } override def toString: String = { @@ -147,14 +182,46 @@ private[sql] class KeyedStateImpl[S]( /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ def getTimeoutTimestamp: Long = timeoutTimestamp + + private def parseDuration(duration: String): Long = { + if (StringUtils.isBlank(duration)) { + throw new IllegalArgumentException( + "Provided duration is null or blank.") + } + val intervalString = if (duration.startsWith("interval")) { + duration + } else { + "interval " + duration + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"Provided duration ($duration) is not valid.") + } + if (cal.milliseconds < 0 || cal.months < 0) { + throw new IllegalArgumentException(s"Provided duration ($duration) is not positive") + } + + val millisPerMonth = CalendarInterval.MICROS_PER_DAY / 1000 * 31 + cal.milliseconds + cal.months * millisPerMonth + } + + private def checkTimeoutTimestampAllowed(): Unit = { + if (timeoutConf != EventTimeTimeout) { + throw new UnsupportedOperationException( + "Cannot set timeout timestamp without enabling event time timeout in " + + "map/flatMapGroupsWithState") + } + if (!defined) { + throw new IllegalStateException( + "Cannot set timeout timestamp without any state value, " + + "state has either not been initialized, or has already been removed") + } + } } private[sql] object KeyedStateImpl { - // Value used in the state row to represent the lack of any timeout timestamp - val TIMEOUT_TIMESTAMP_NOT_SET = -1L - - // Value to represent that no batch processing timestamp is passed to KeyedStateImpl. This is - // used in batch queries where there are no streaming batches and timeouts. - val NO_BATCH_PROCESSING_TIMESTAMP = -1L + // Value used represent the lack of valid timestamp as a long + val NO_TIMESTAMP = -1L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 6d2de441eb44c13670a78b4cdd9cca684d014252..f72144a25d5ccd03b00f6ae4beda25198067c6eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -80,7 +80,7 @@ trait WatermarkSupport extends UnaryExecNode { /** Generate an expression that matches data older than the watermark */ lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = - keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) + child.output.find(_.metadata.contains(EventTimeWatermark.delayKey)) optionalWatermarkAttribute.map { watermarkAttribute => // If we are evicting based on a window, use the end of the window. Otherwise just @@ -101,14 +101,12 @@ trait WatermarkSupport extends UnaryExecNode { } } - /** Generate a predicate based on keys that matches data older than the watermark */ + /** Predicate based on keys that matches data older than the watermark */ lazy val watermarkPredicateForKeys: Option[Predicate] = watermarkExpression.map(newPredicate(_, keyExpressions)) - /** - * Generate a predicate based on the child output that matches data older than the watermark. - */ - lazy val watermarkPredicate: Option[Predicate] = + /** Predicate based on the child output that matches data older than the watermark. */ + lazy val watermarkPredicateForData: Option[Predicate] = watermarkExpression.map(newPredicate(_, child.output)) } @@ -218,7 +216,7 @@ case class StateStoreSaveExec( new Iterator[InternalRow] { // Filter late date using watermark if specified - private[this] val baseIterator = watermarkPredicate match { + private[this] val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) case None => iter } @@ -285,7 +283,7 @@ case class StreamingDeduplicateExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val baseIterator = watermarkPredicate match { + val baseIterator = watermarkPredicateForData match { case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala index 6b4b1ced98a34d98a21e18e1fff395d71ed37e5e..461de04f6bbe2c5211c1d251c2d5428e26f79ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/KeyedState.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * batch, nor with streaming Datasets. * - All the data will be shuffled before applying the function. * - If timeout is set, then the function will also be called with no values. - * See more details on KeyedStateTimeout` below. + * See more details on `KeyedStateTimeout` below. * * Important points to note about using `KeyedState`. * - The value of the state cannot be null. So updating state with null will throw @@ -68,20 +68,38 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * Important points to note about using `KeyedStateTimeout`. * - The timeout type is a global param across all the keys (set as `timeout` param in - * `[map|flatMap]GroupsWithState`, but the exact timeout duration is configurable per key - * (by calling `setTimeout...()` in `KeyedState`). - * - When the timeout occurs for a key, the function is called with no values, and + * `[map|flatMap]GroupsWithState`, but the exact timeout duration/timestamp is configurable per + * key by calling `setTimeout...()` in `KeyedState`. + * - Timeouts can be either based on processing time (i.e. + * [[KeyedStateTimeout.ProcessingTimeTimeout]]) or event time (i.e. + * [[KeyedStateTimeout.EventTimeTimeout]]). + * - With `ProcessingTimeTimeout`, the timeout duration can be set by calling + * `KeyedState.setTimeoutDuration`. The timeout will occur when the clock has advanced by the set + * duration. Guarantees provided by this timeout with a duration of D ms are as follows: + * - Timeout will never be occur before the clock time has advanced by D ms + * - Timeout will occur eventually when there is a trigger in the query + * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. + * For example, the trigger interval of the query will affect when the timeout actually occurs. + * If there is no data in the stream (for any key) for a while, then their will not be + * any trigger and timeout function call will not occur until there is data. + * - Since the processing time timeout is based on the clock time, it is affected by the + * variations in the system clock (i.e. time zone changes, clock skew, etc.). + * - With `EventTimeTimeout`, the user also has to specify the the the event time watermark in + * the query using `Dataset.withWatermark()`. With this setting, data that is older than the + * watermark are filtered out. The timeout can be enabled for a key by setting a timestamp using + * `KeyedState.setTimeoutTimestamp()`, and the timeout would occur when the watermark advances + * beyond the set timestamp. You can control the timeout delay by two parameters - (i) watermark + * delay and an additional duration beyond the timestamp in the event (which is guaranteed to + * > watermark due to the filtering). Guarantees provided by this timeout are as follows: + * - Timeout will never be occur before watermark has exceeded the set timeout. + * - Similar to processing time timeouts, there is a no strict upper bound on the delay when + * the timeout actually occurs. The watermark can advance only when there is data in the + * stream, and the event time of the data has actually advanced. + * - When the timeout occurs for a key, the function is called for that key with no values, and * `KeyedState.hasTimedOut()` set to true. * - The timeout is reset for key every time the function is called on the key, that is, * when the key has new data, or the key has timed out. So the user has to set the timeout * duration every time the function is called, otherwise there will not be any timeout set. - * - Guarantees provided on processing-time-based timeout of key, when timeout duration is D ms: - * - Timeout will never be called before real clock time has advanced by D ms - * - Timeout will be called eventually when there is a trigger in the query - * (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. - * For example, the trigger interval of the query will affect when the timeout is actually hit. - * If there is no data in the stream (for any key) for a while, then their will not be - * any trigger and timeout will not be hit until there is data. * * Scala example of using KeyedState in `mapGroupsWithState`: * {{{ @@ -194,7 +212,8 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** * Set the timeout duration in ms for this key. - * @note Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + * + * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") @throws[IllegalStateException]("when state is either not initialized, or already removed") @@ -204,11 +223,63 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. - * @note, Timeouts must be enabled in `[map/flatmap]GroupsWithStates`. + * + * @note, ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") def setTimeoutDuration(duration: String): Unit + + @throws[IllegalArgumentException]("if 'timestampMs' is not positive") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time. + * This timestamp cannot be older than the current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as milliseconds in epoch time and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit + + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date. + * This timestamp cannot be older than the current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestamp: java.sql.Date): Unit + + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[IllegalStateException]("when state is either not initialized, or already removed") + @throws[UnsupportedOperationException]( + "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** + * Set the timeout timestamp for this key as a java.sql.Date and an additional + * duration as a string (e.g. "1 hour", "2 days", etc.). + * The final timestamp (including the additional duration) cannot be older than the + * current watermark. + * + * @note, EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + */ + def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 7daa5e6a0f61f29eee8bd33be042d246031c470a..fe72283bb608f79b96e0dde0e208dd62ede6eef4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util +import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll @@ -44,6 +44,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf import testImplicits._ import KeyedStateImpl._ + import KeyedStateTimeout._ override def afterAll(): Unit = { super.afterAll() @@ -96,77 +97,93 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - test("KeyedState - setTimeoutDuration, hasTimedOut") { - import KeyedStateImpl._ - var state: KeyedStateImpl[Int] = null - - // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed + test("KeyedState - setTimeout**** with NoTimeout") { for (initState <- Seq(None, Some(5))) { // for different initial state - state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - intercept[UnsupportedOperationException] { - state.setTimeoutDuration(1000) - } - intercept[UnsupportedOperationException] { - state.setTimeoutDuration("1 day") - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + implicit val state = new KeyedStateImpl(initState, 1000, 1000, NoTimeout, hasTimedOut = false) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) } + } - def testTimeoutNotAllowed(): Unit = { - intercept[IllegalStateException] { - state.setTimeoutDuration(1000) - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - intercept[IllegalStateException] { - state.setTimeoutDuration("2 second") - } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - } + test("KeyedState - setTimeout**** with ProcessingTimeTimeout") { + implicit var state: KeyedStateImpl[Int] = null - // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the - // state is be defined - state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - testTimeoutNotAllowed() + state = new KeyedStateImpl[Int](None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) - // After state has been set, setTimeoutDuration() is allowed, and - // getTimeoutTimestamp returned correct timestamp state.update(5) - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(1000) assert(state.getTimeoutTimestamp === 2000) state.setTimeoutDuration("2 second") assert(state.getTimeoutTimestamp === 3000) - assert(state.hasTimedOut === false) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[IllegalStateException](state) + testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) + } + + test("KeyedState - setTimeout**** with EventTimeTimeout") { + implicit val state = new KeyedStateImpl[Int]( + None, 1000, 1000, EventTimeTimeout, hasTimedOut = false) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + + state.update(5) + state.setTimeoutTimestamp(10000) + assert(state.getTimeoutTimestamp === 10000) + state.setTimeoutTimestamp(new Date(20000)) + assert(state.getTimeoutTimestamp === 20000) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + + state.remove() + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + testTimeoutDurationNotAllowed[UnsupportedOperationException](state) + testTimeoutTimestampNotAllowed[IllegalStateException](state) + } + + test("KeyedState - illegal params to setTimeout****") { + var state: KeyedStateImpl[Int] = null - // setTimeoutDuration() with negative values or 0 is not allowed + // Test setTimeout****() with illegal values def testIllegalTimeout(body: => Unit): Unit = { intercept[IllegalArgumentException] { body } - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + assert(state.getTimeoutTimestamp === NO_TIMESTAMP) } - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) + + state = new KeyedStateImpl(Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } testIllegalTimeout { state.setTimeoutDuration(0) } testIllegalTimeout { state.setTimeoutDuration("-2 second") } testIllegalTimeout { state.setTimeoutDuration("-1 month") } testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") } - // Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false) - state.remove() - assert(state.hasTimedOut === false) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) - testTimeoutNotAllowed() - - // Test hasTimedOut = true - state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true) - assert(state.hasTimedOut === true) - assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET) + state = new KeyedStateImpl(Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + testIllegalTimeout { state.setTimeoutTimestamp(-10000) } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(10000, "1 month -1 day") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000)) } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-3 second") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "-1 month") } + testIllegalTimeout { state.setTimeoutTimestamp(new Date(-10000), "1 month -1 day") } + } + + test("KeyedState - hasTimedOut") { + for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { + for (initState <- Seq(None, Some(5))) { + val state1 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = false) + assert(state1.hasTimedOut === false) + val state2 = new KeyedStateImpl(initState, 1000, 1000, timeoutConf, hasTimedOut = true) + assert(state2.hasTimedOut === true) + } + } } test("KeyedState - primitive type") { @@ -187,133 +204,186 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } // Values used for testing StateStoreUpdater - val currentTimestamp = 1000 - val beforeCurrentTimestamp = 999 - val afterCurrentTimestamp = 1001 + val currentBatchTimestamp = 1000 + val currentBatchWatermark = 1000 + val beforeTimeoutThreshold = 999 + val afterTimeoutThreshold = 1001 + - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" - val testName = s"timeout disabled - $priorStateStr - " + val testName = s"NoTimeout - $priorStateStr - " testStateUpdateWithData( testName + "no update", stateUpdates = state => { /* do nothing */ }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change testStateUpdateWithData( testName + "state updated", stateUpdates = state => { state.update(5) }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = Some(5)) // should change testStateUpdateWithData( testName + "state removed", stateUpdates = state => { state.remove() }, - timeoutType = KeyedStateTimeout.NoTimeout, + timeoutConf = KeyedStateTimeout.NoTimeout, priorState = priorState, expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled + // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { - for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) { - var testName = s"timeout enabled - " + for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { + var testName = s"" if (priorState.nonEmpty) { testName += "prior state set, " if (priorTimeoutTimestamp == 1000) { - testName += "prior timeout set - " + testName += "prior timeout set" } else { - testName += "no prior timeout - " + testName += "no prior timeout" } } else { - testName += "no prior state - " + testName += "no prior state" + } + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + + testStateUpdateWithData( + s"$timeoutConf - $testName - no update", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = priorState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state updated", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithData( + s"$timeoutConf - $testName - state removed", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None) // state should be removed } testStateUpdateWithData( - testName + "no update", - stateUpdates = state => { /* do nothing */ }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, - priorState = priorState, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = priorState, // state should not change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset - - testStateUpdateWithData( - testName + "state updated", - stateUpdates = state => { state.update(5) }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"ProcessingTimeTimeout - $testName - state and timeout duration updated", + stateUpdates = + (state: KeyedState[Int]) => { state.update(5); state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change testStateUpdateWithData( - testName + "state removed", - stateUpdates = state => { state.remove() }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"EventTimeTimeout - $testName - state and timeout timestamp updated", + stateUpdates = + (state: KeyedState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = None) // state should be removed + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = 5000) // timestamp should change testStateUpdateWithData( - testName + "timeout and state updated", - stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) }, - timeoutType = KeyedStateTimeout.ProcessingTimeTimeout, + s"EventTimeTimeout - $testName - timeout timestamp updated to before watermark", + stateUpdates = + (state: KeyedState[Int]) => { + state.update(5) + intercept[IllegalArgumentException] { + state.setTimeoutTimestamp(currentBatchWatermark - 1) // try to set to < watermark + } + }, + timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update } } // Tests for StateStoreUpdater.updateStateForTimedOutKeys() val preTimeoutState = Some(5) + for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { + testStateUpdateWithTimeout( + s"$timeoutConf - should not timeout", + stateUpdates = state => { assert(false, "function called without timeout") }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = afterTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = afterTimeoutThreshold) // timestamp should not change + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - no update/remove", + stateUpdates = state => { /* do nothing */ }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = preTimeoutState, // state should not change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset - testStateUpdateWithTimeout( - "should not timeout", - stateUpdates = state => { assert(false, "function called without timeout") }, - priorTimeoutTimestamp = afterCurrentTimestamp, - expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - update state", + stateUpdates = state => { state.update(5) }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should be reset + + testStateUpdateWithTimeout( + s"$timeoutConf - should timeout - remove state", + stateUpdates = state => { state.remove() }, + timeoutConf = timeoutConf, + priorTimeoutTimestamp = beforeTimeoutThreshold, + expectedState = None, // state should be removed + expectedTimeoutTimestamp = NO_TIMESTAMP) + } testStateUpdateWithTimeout( - "should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "ProcessingTimeTimeout - should timeout - timeout duration updated", + stateUpdates = state => { state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - update state", - stateUpdates = state => { state.update(5) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "ProcessingTimeTimeout - should timeout - timeout duration and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, + timeoutConf = ProcessingTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = Some(5), // state should change - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset + expectedTimeoutTimestamp = currentBatchTimestamp + 2000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - remove state", - stateUpdates = state => { state.remove() }, - priorTimeoutTimestamp = beforeCurrentTimestamp, - expectedState = None, // state should be removed - expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) - - testStateUpdateWithTimeout( - "should timeout - timeout updated", - stateUpdates = state => { state.setTimeoutDuration(2000) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "EventTimeTimeout - should timeout - timeout timestamp updated", + stateUpdates = state => { state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change - expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + expectedTimeoutTimestamp = 5000) // timestamp should change testStateUpdateWithTimeout( - "should timeout - timeout and state updated", - stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) }, - priorTimeoutTimestamp = beforeCurrentTimestamp, + "EventTimeTimeout - should timeout - timeout and state updated", + stateUpdates = state => { state.update(5); state.setTimeoutTimestamp(5000) }, + timeoutConf = EventTimeTimeout, + priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = Some(5), // state should change - expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change + expectedTimeoutTimestamp = 5000) // timestamp should change test("StateStoreUpdater - rows are cloned before writing to StateStore") { // function for running count @@ -481,11 +551,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val clock = new StreamManualClock val inputData = MemoryStream[String] - val timeout = KeyedStateTimeout.ProcessingTimeTimeout val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(Update, timeout)(stateFunc) + .flatMapGroupsWithState(Update, ProcessingTimeTimeout)(stateFunc) testStream(result, Update)( StartStream(ProcessingTime("1 second"), triggerClock = clock), @@ -519,6 +588,52 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf ) } + test("flatMapGroupsWithState - streaming with event time timeout") { + // Function to maintain the max event time + // Returns the max event time in the state, or -1 if the state was removed by timeout + val stateFunc = ( + key: String, + values: Iterator[(String, Long)], + state: KeyedState[Long]) => { + val timeoutDelay = 5 + if (key != "a") { + Iterator.empty + } else { + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) + } else { + val valuesSeq = values.toSeq + val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampMs = maxEventTime + timeoutDelay + state.update(maxEventTime) + state.setTimeoutTimestamp(timeoutTimestampMs * 1000) + Iterator((key, maxEventTime.toInt)) + } + } + } + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS + .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) + + testStream(result, Update)( + StartStream(ProcessingTime("1 second")), + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... + CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckLastBatch(), // No output as data should get filtered by watermark + AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s + CheckLastBatch(), // No output as no data for "a" + AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored + CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + ) + } + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -612,7 +727,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key val inputData = MemoryStream[String] val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) - result testStream(result, Update)( AddData(inputData, "a"), CheckLastBatch("a"), @@ -649,13 +763,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithData( testName: String, stateUpdates: KeyedState[Int] => Unit, - timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, + timeoutConf: KeyedStateTimeout, priorState: Option[Int], - priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET, + priorTimeoutTimestamp: Long = NO_TIMESTAMP, expectedState: Option[Int] = None, - expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) { + if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } test(s"StateStoreUpdater - updates with data - $testName") { @@ -666,7 +780,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf Iterator.empty } testStateUpdate( - testTimeoutUpdates = false, mapGroupsFunc, timeoutType, + testTimeoutUpdates = false, mapGroupsFunc, timeoutConf, priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) } } @@ -674,9 +788,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def testStateUpdateWithTimeout( testName: String, stateUpdates: KeyedState[Int] => Unit, + timeoutConf: KeyedStateTimeout, priorTimeoutTimestamp: Long, expectedState: Option[Int], - expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = { + expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { test(s"StateStoreUpdater - updates for timeout - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => { @@ -686,16 +801,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf Iterator.empty } testStateUpdate( - testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout, - preTimeoutState, priorTimeoutTimestamp, - expectedState, expectedTimeoutTimestamp) + testTimeoutUpdates = true, mapGroupsFunc, timeoutConf = timeoutConf, + preTimeoutState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp) } } def testStateUpdate( testTimeoutUpdates: Boolean, mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], - timeoutType: KeyedStateTimeout, + timeoutConf: KeyedStateTimeout, priorState: Option[Int], priorTimeoutTimestamp: Long, expectedState: Option[Int], @@ -703,7 +817,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( - mapGroupsFunc, timeoutType, currentTimestamp) + mapGroupsFunc, timeoutConf, currentBatchTimestamp) val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) val key = intToRow(0) // Prepare store with prior state configs @@ -736,7 +850,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf def newFlatMapGroupsWithStateExec( func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int], timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout, - batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = { + batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = { MemoryStream[Int] .toDS .groupByKey(x => x) @@ -744,11 +858,31 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf .logicalPlan.collectFirst { case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) => FlatMapGroupsWithStateExec( - f, k, v, g, d, o, None, s, m, t, currentTimestamp, - RDDScanExec(g, null, "rdd")) + f, k, v, g, d, o, None, s, m, t, + Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd")) }.get } + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutDuration(1000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutDuration("2 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: KeyedStateImpl[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestamp + intercept[T] { state.setTimeoutTimestamp(2000) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } + assert(state.getTimeoutTimestamp === prevTimestamp) + intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } + assert(state.getTimeoutTimestamp === prevTimestamp) + } + def newStateStore(): StateStore = new MemoryStateStore() val intProj = UnsafeProjection.create(Array[DataType](IntegerType))