diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index f4d016cb96711481900fc2b6e728948bda90bc17..d8aad42edcf5fdee3325fe1e4d2a58a5a39efd45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -46,8 +46,13 @@ object UnsupportedOperationChecker { "Queries without streaming sources cannot be executed with writeStream.start()")(plan) } + /** Collect all the streaming aggregates in a sub plan */ + def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = { + subplan.collect { case a: Aggregate if a.isStreaming => a } + } + // Disallow multiple streaming aggregations - val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + val aggregates = collectStreamingAggregates(plan) if (aggregates.size > 1) { throwError( @@ -114,6 +119,10 @@ object UnsupportedOperationChecker { case _: InsertIntoTable => throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") + case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty => + throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + + "streaming DataFrame/Dataset") + case Join(left, right, joinType, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0ab4c9016623e10ebf6bd0dee0a6fb12e63ea538..0be4823bbc895a66dcb7096cf822c1bbc1537133 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -313,6 +313,55 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer +/** Internal class representing State */ +trait LogicalKeyedState[S] + +/** Factory for constructing new `MapGroupsWithState` nodes. */ +object MapGroupsWithState { + def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan): LogicalPlan = { + val mapped = new MapGroupsWithState( + func, + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr[U], + encoderFor[S].resolveAndBind().deserializer, + encoderFor[S].namedExpressions, + child) + CatalystSerde.serialize[U](mapped) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + * + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. 
+ * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr used to define the output object + * @param stateDeserializer used to deserialize state before calling `func` + * @param stateSerializer used to serialize updated state after calling `func` + */ +case class MapGroupsWithState( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateDeserializer: Expression, + stateSerializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectProducer + /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index dcdb1ae089328918d5a6873f74923f8ea140a1ea..3b756e89d903611fac3b763e05f4a9251324b6b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, LongType} /** A dummy command for testing unsupported operations. 
*/ case class DummyCommand() extends Command @@ -111,6 +111,24 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("distinct aggregation")) + // MapGroupsWithState: Not supported after a streaming aggregation + val att = new AttributeReference(name = "a", dataType = LongType)() + assertSupportedInBatchPlan( + "mapGroupsWithState - mapGroupsWithState on batch relation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation)) + + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), + Aggregate(Nil, aggExprs("c"), streamRelation)), + outputMode = Complete, + expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..2570c8d02ab7cd6d358e03db46459e3e80758733 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.KeyedState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}. 
+ * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable { + Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java new file mode 100644 index 0000000000000000000000000000000000000000..614d3925e0510d9a2d84b8a0a678580084e39db0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.KeyedState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)} + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable { + R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 395d709f265913fbad5c5e4eda6e462567856c2a..94e689a4d5b97149ff07cb584da22c4f69b4ea8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -218,6 +218,119 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapGroups((key, data) => f.call(key, data.asJava))(encoder) } + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
+ * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) + } + + /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) + )(stateEncoder, outputEncoder) + } + + /** + * ::Experimental:: + * (Scala-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + Dataset[U]( + sparkSession, + MapGroupsWithState[K, V, S, U]( + func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + logicalPlan)) + } + + /** + * ::Experimental:: + * (Java-specific) + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. 
+ * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala + )(stateEncoder, outputEncoder) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala new file mode 100644 index 0000000000000000000000000000000000000000..6864b6f6b4fd2bae851de27e51cb96e84f7a1c84 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.IllegalArgumentException + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on + * [[KeyValueGroupedDataset]]. + * + * Detail description on `[map/flatMap]GroupsWithState` operation + * ------------------------------------------------------------ + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * will invoke the user-given function on each group (defined by the grouping function in + * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger. + * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]], + * the function will be invoked once for each group that has data in the batch. + * + * The function is invoked with following parameters. + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * In case of a batch Dataset, there is only one invocation and state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` + * is equivalent to `[map/flatMap]Groups`. + * + * Important points to note about the function. 
+ * - In a trigger, the function will be called only the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * + * Important points to note about using KeyedState. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. + * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - If `remove()` is called, then `exists()` will return `false`, + * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()`will return the updated value. + * + * Scala example of using KeyedState in `mapGroupsWithState`: + * {{{ + * /* A mapping function that maintains an integer state for string keys and returns a string. */ + * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example of using `KeyedState`: + * {{{ + * /* A mapping function that maintains an integer state for string keys and returns a string. */ + * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction = + * new MapGroupsWithStateFunction<String, Integer, Integer, String>() { + * + * @Override + * public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * ... // return something + * } + * }; + * }}} + * + * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * Spark SQL types (see [[Encoder]] for more details). + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +trait KeyedState[S] extends LogicalKeyedState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") + def get: S + + /** Get the state value as a scala Option. */ + def getOption: Option[S] + + /** + * Update the value of the state. Note that `null` is not a valid value, and it throws + * IllegalArgumentException. + */ + @throws[IllegalArgumentException]("when updating with null") + def update(newState: S): Unit + + /** Remove this keyed state. 
*/ + def remove(): Unit +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ba82ec156e85040e3bcc20331329ad9a38310b7b..adea358594a0fba60593772af55c2b127b721a71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -324,6 +324,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert MapGroupsWithState logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object MapGroupsWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case MapGroupsWithState( + f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => + val execPlan = MapGroupsWithStateExec( + f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + planLater(child)) + execPlan :: Nil + case _ => + Nil + } + } + // Can we automate these 'pass through' operations? 
object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions @@ -365,6 +382,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index fde3b2a528994a3c2f7a97e3dddbb9b071d0d17c..199ba5ce6969bb3e1196691d764e6c2914b42a4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState +import org.apache.spark.sql.execution.streaming.KeyedStateImpl import org.apache.spark.sql.types.{DataType, ObjectType, StructType} @@ -144,6 +146,11 @@ object ObjectOperator { (i: InternalRow) => proj(i).get(0, deserializer.dataType) } + def deserializeRowToObject(deserializer: Expression): InternalRow => Any = { + val proj = GenerateSafeProjection.generate(deserializer :: Nil) + (i: InternalRow) => proj(i).get(0, deserializer.dataType) + } + def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { val proj = GenerateUnsafeProjection.generate(serializer) val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head @@ -344,6 +351,21 @@ case class MapGroupsExec( } } +object MapGroupsExec { + def apply( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: SparkPlan): MapGroupsExec = { + val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) + new MapGroupsExec(f, keyDeserializer, valueDeserializer, + groupingAttributes, dataAttributes, outputObjAttr, child) + } +} + /** * Groups the input rows together and calls the R function with each group and an iterator * containing all elements in the group. 
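The MapGroupsExec companion object added above is what backs the new API on batch Datasets: it feeds the user function an empty KeyedStateImpl, so batch `mapGroupsWithState` degenerates to `mapGroups`. Below is a minimal usage sketch of that batch behaviour; the SparkSession setup and object/variable names are illustrative assumptions, not part of the patch.

    import org.apache.spark.sql.{KeyedState, SparkSession}

    object BatchMapGroupsWithStateSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        import spark.implicits._

        // On a batch Dataset the planner routes MapGroupsWithState through MapGroupsExec
        // with an empty KeyedStateImpl, so state.exists is always false and the call
        // behaves like plain mapGroups.
        val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Long]) => {
          assert(!state.exists)
          (key, values.size)
        }

        val counts = Seq("a", "a", "b").toDS()
          .groupByKey(x => x)
          .mapGroupsWithState(stateFunc)

        counts.show()  // expected: (a, 2), (b, 1)
        spark.stop()
      }
    }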
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 6ab6fa61dc200d7bb51054e44fb325b9fbf9c6e5..5c4cbfa7552cbcb8eb1b79f35534fa0300c2b5bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} import org.apache.spark.sql.SparkSession @@ -39,8 +41,9 @@ class IncrementalExecution( extends QueryExecution(sparkSession, logicalPlan) with Logging { // TODO: make this always part of planning. - val stateStrategy = + val streamingExtraStrategies = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -49,7 +52,7 @@ class IncrementalExecution( new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - stateStrategy) + streamingExtraStrategies) /** * See [SPARK-18339] @@ -68,7 +71,7 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private var operatorId = 0 + private val operatorId = new AtomicInteger(0) /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -77,8 +80,8 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) - operatorId += 1 + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) StateStoreSaveExec( keys, @@ -90,6 +93,12 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) + case MapGroupsWithStateExec( + f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + MapGroupsWithStateExec( + f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala new file mode 100644 index 0000000000000000000000000000000000000000..eee7ec45dd77bcfd9a22cfb398f64d52a221fd2c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.KeyedState + +/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */ +private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { + private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) + private var defined: Boolean = optionalValue.isDefined + private var updated: Boolean = false + // whether value has been updated (but not removed) + private var removed: Boolean = false // whether value has been removed + + // ========= Public API ========= + override def exists: Boolean = defined + + override def get: S = { + if (defined) { + value + } else { + throw new NoSuchElementException("State is either not defined or has already been removed") + } + } + + override def getOption: Option[S] = { + if (defined) { + Some(value) + } else { + None + } + } + + override def update(newValue: S): Unit = { + if (newValue == null) { + throw new IllegalArgumentException("'null' is not a valid state value") + } + value = newValue + defined = true + updated = true + removed = false + } + + override def remove(): Unit = { + defined = false + updated = false + removed = true + } + + override def toString: String = { + s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})" + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved: Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated: Boolean = { + updated + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 1f74fffbe6e67f55839b8aec8557cb8939d54047..693933f95a231d5b5e4e4b70494e0a01b00d9303 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -186,7 +186,7 @@ trait ProgressReporter extends Logging { // lastExecution could belong to one of the previous triggers if `!hasNewData`. // Walking the plan again should be inexpensive. val stateNodes = lastExecution.executedPlan.collect { - case p if p.isInstanceOf[StateStoreSaveExec] => p + case p if p.isInstanceOf[StateStoreWriter] => p } stateNodes.map { node => val numRowsUpdated = if (hasNewData) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 1279b71c4d6eef0fb7f610d5d9c821e5424f4e69..61eb601a18c3228f63cbe7e70ef82631ced38479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -147,6 +147,25 @@ private[state] class HDFSBackedStateStoreProvider( } } + /** Remove a single key. 
*/ + override def remove(key: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or aborted") + if (mapToUpdate.containsKey(key)) { + val value = mapToUpdate.remove(key) + Option(allUpdates.get(key)) match { + case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed + allUpdates.put(key, ValueRemoved(key, value)) + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates + allUpdates.remove(key) + case Some(ValueRemoved(_, _)) => + // Remove already in update map, no need to change + } + writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) + } + } + /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { verify(state == UPDATING, "Cannot commit after already committed or aborted") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index e61d95a1b1bb070aa13745b4202e0bc7257e98cc..dcb24b26f78f3a8abd6b2e825160d8c01cfbf189 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -58,6 +58,11 @@ trait StateStore { */ def remove(condition: UnsafeRow => Boolean): Unit + /** + * Remove a single key. + */ + def remove(key: UnsafeRow): Unit + /** * Commit all the updates that have been made to the store, and return the new version. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 1b56c08f729c6927f474c5a23442b5d216bbadd5..589042afb1e52ce646cd136675154e93c8d4dcd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.internal.SessionState @@ -59,10 +60,18 @@ package object state { sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + val wrappedF = (store: StateStore, iter: Iterator[T]) => { + // Abort the state store in case of error + TaskContext.get().addTaskCompletionListener(_ => { + if (!store.hasCommitted) store.abort() + }) + cleanedF(store, iter) + } new StateStoreRDD( dataRDD, - cleanedF, + wrappedF, checkpointLocation, operatorId, storeVersion, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala similarity index 63% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d4ccced9ac9b4c7c2a1a17ea97b363803a63b773..12924525745942242ee061a332b748e54c18008f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -22,16 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType -import org.apache.spark.TaskContext +import org.apache.spark.util.CompletionIterator /** Used to identify the state store for a given operator. */ @@ -41,7 +41,7 @@ case class OperatorStateId( batchId: Long) /** - * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should * be filled in by `prepareForExecution` in [[IncrementalExecution]]. */ trait StatefulOperator extends SparkPlan { @@ -54,6 +54,20 @@ trait StatefulOperator extends SparkPlan { } } +/** An operator that reads from a StateStore. */ +trait StateStoreReader extends StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) +} + +/** An operator that writes to a StateStore. */ +trait StateStoreWriter extends StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) +} + /** * For each input tuple, the key is calculated and the value from the [[StateStore]] is added * to the stream (in addition to the input tuple) if present. 
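One behavioural change bundled into this file is in state/package.scala above: `mapPartitionsWithStateStore` now registers a task-completion listener that aborts any store left uncommitted, instead of each stateful operator doing so itself (the corresponding listener is removed from StateStoreSaveExec in the next hunk). A rough standalone sketch of that wrapper pattern follows; `SimpleStore` and `onTaskCompletion` are hypothetical stand-ins for StateStore and TaskContext.addTaskCompletionListener.

    // Only the shape of the abort-if-uncommitted pattern matters here.
    trait SimpleStore {
      def hasCommitted: Boolean
      def abort(): Unit
    }

    object AbortIfUncommittedSketch {
      def wrap[T, U](onTaskCompletion: (() => Unit) => Unit)(
          storeUpdateFunction: (SimpleStore, Iterator[T]) => Iterator[U])
          : (SimpleStore, Iterator[T]) => Iterator[U] = {
        (store: SimpleStore, iter: Iterator[T]) => {
          // Register the hook before running the user function, so a failure inside
          // storeUpdateFunction still leads to store.abort() when the task completes.
          onTaskCompletion(() => if (!store.hasCommitted) store.abort())
          storeUpdateFunction(store, iter)
        }
      }
    }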
@@ -62,10 +76,7 @@ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], child: SparkPlan) - extends execution.UnaryExecNode with StatefulOperator { - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + extends execution.UnaryExecNode with StateStoreReader { override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -102,12 +113,7 @@ case class StateStoreSaveExec( outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) - extends execution.UnaryExecNode with StatefulOperator { - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + extends execution.UnaryExecNode with StateStoreWriter { /** Generate a predicate that matches data older than the watermark */ private lazy val watermarkPredicate: Option[Predicate] = { @@ -151,13 +157,6 @@ case class StateStoreSaveExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") - // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { - if (!store.hasCommitted) { - store.abort() - } - }) - outputMode match { // Update and output all rows in the StateStore. case Some(Complete) => @@ -184,7 +183,7 @@ case class StateStoreSaveExec( } // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval) + store.remove(watermarkPredicate.get.eval _) store.commit() numTotalStateRows += store.numKeys() @@ -207,7 +206,7 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval) + if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) store.commit() numTotalStateRows += store.numKeys() false @@ -235,3 +234,90 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } + + +/** Physical operator for executing streaming mapGroupsWithState. 
*/ +case class MapGroupsWithStateExec( + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateId: Option[OperatorStateId], + stateDeserializer: Expression, + stateSerializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { + + override def outputPartitioning: Partitioning = child.outputPartitioning + + /** Distribute by grouping attributes */ + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + /** Ordering needed for using GroupingIterator */ + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + groupingAttributes.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numOutputRows = longMetric("numOutputRows") + + // Generate an iterator that returns the rows grouped by the grouping function + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + + // Converters to and from object and rows + val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val getStateObj = + ObjectOperator.deserializeRowToObject(stateDeserializer) + val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // For every group, get the key, values and corresponding state and call the function, + // and return an iterator of rows + val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => + + val key = keyRow.asInstanceOf[UnsafeRow] + val keyObj = getKeyObj(keyRow) // convert key to objects + val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects + val stateObjOption = store.get(key).map(getStateObj) // get existing state if any + val wrappedState = new KeyedStateImpl(stateObjOption) + val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj => + numOutputRows += 1 + getOutputRow(obj) // convert back to rows + } + + // Return an iterator of the rows generated for this key, + // such that when fully consumed, the updated state value will be saved + CompletionIterator[InternalRow, Iterator[InternalRow]]( + mappedIterator, { + // When the iterator is consumed, then write changes to state + if (wrappedState.isRemoved) { + store.remove(key) + numUpdatedStateRows += 1 + } else if (wrappedState.isUpdated) { + store.put(key, outputStateObj(wrappedState.get)) + numUpdatedStateRows += 1 + } + }) + } + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumed, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, { + store.commit() + numTotalStateRows += store.numKeys() + }) + } + } +} diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 8304b728aa2389635e82991f2f36c4e6cfc91315..5ef4e887ded09dfc068dae1a75569476af500906 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -225,6 +225,38 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); + Dataset<String> mapped2 = grouped.mapGroupsWithState( + new MapGroupsWithStateFunction<Integer, String, Long, String>() { + @Override + public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, + Encoders.LONG(), + Encoders.STRING()); + + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList())); + + Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState( + new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() { + @Override + public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + Encoders.LONG(), + Encoders.STRING()); + + Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); + Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() { @Override public String call(String v1, String v2) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..0524898b15ead4ef16e47baf404ebe63725c1989 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkException +import org.apache.spark.sql.KeyedState +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream} +import org.apache.spark.sql.execution.streaming.state.StateStore + +/** Class to check custom state types */ +case class RunningCount(count: Long) + +class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("KeyedState - get, exists, update, remove") { + var state: KeyedStateImpl[String] = null + + def testState( + expectedData: Option[String], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get === expectedData.get) + } else { + assert(!state.exists) + intercept[NoSuchElementException] { + state.get + } + } + assert(state.getOption === expectedData) + assert(state.isUpdated === shouldBeUpdated) + assert(state.isRemoved === shouldBeRemoved) + } + + // Updating empty state + state = new KeyedStateImpl[String](None) + testState(None) + state.update("") + testState(Some(""), shouldBeUpdated = true) + + // Updating existing state + state = new KeyedStateImpl[String](Some("2")) + testState(Some("2")) + state.update("3") + testState(Some("3"), shouldBeUpdated = true) + + // Removing state + state.remove() + testState(None, shouldBeRemoved = true, shouldBeUpdated = false) + state.remove() // should still be callable + state.update("4") + testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) + + // Updating with null throws an exception + intercept[IllegalArgumentException] { + state.update(null) + } + } + + test("KeyedState - primitive type") { + var intState = new KeyedStateImpl[Int](None) + intercept[NoSuchElementException] { + intState.get + } + assert(intState.getOption === None) + + intState = new KeyedStateImpl[Int](Some(10)) + assert(intState.get == 10) + intState.update(0) + assert(intState.get == 0) + intState.remove() + intercept[NoSuchElementException] { + intState.get + } + } + + test("flatMapGroupsWithState - streaming") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + Iterator.empty + } else { + state.update(RunningCount(count)) + Iterator((key, count.toString)) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str) + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckLastBatch(("b", "2")), + assertNumStateRows(total = 1, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + 
CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count if state is defined, otherwise does not return anything + // Additionally, it updates state lazily as the returned iterator get consumed + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + values.flatMap { _ => + val count = state.getOption.map(_.count).getOrElse(0L) + 1 + if (count == 3) { + state.remove() + None + } else { + state.update(RunningCount(count)) + Some((key, count.toString)) + } + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + + testStream(result, Append)( + AddData(inputData, "a", "a", "b"), + CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckLastBatch(("b", "2")), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckLastBatch(("a", "1"), ("c", "1")) + ) + } + + test("flatMapGroupsWithState - batch") { + // Function that returns running count only if its even, otherwise does not return + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + Iterator((key, values.size)) + } + checkAnswer( + Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, + Seq(("a", 2), ("b", 1)).toDF) + } + + test("mapGroupsWithState - streaming") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + (key, "-1") + } else { + state.update(RunningCount(count)) + (key, count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch(("a", "1")), + assertNumStateRows(total = 1, updated = 1), + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(total = 2, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 2), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) + ) + } + + test("mapGroupsWithState - streaming + aggregation") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + + val count = 
state.getOption.map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + (key, "-1") + } else { + state.update(RunningCount(count)) + (key, count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .groupByKey(_._1) + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckLastBatch(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckLastBatch(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckLastBatch(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + ) + } + + test("mapGroupsWithState - batch") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + (key, values.size) + } + + checkAnswer( + spark.createDataset(Seq("a", "a", "b")) + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) + .toDF, + spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) + } + + testQuietly("StateStore.abort on task failure handling") { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + (key, count) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + + def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => + MapGroupsWithStateSuite.failInTask = value + true + } + + testStream(result, Append)( + setFailInTask(false), + AddData(inputData, "a"), + CheckLastBatch(("a", 1L)), + AddData(inputData, "a"), + CheckLastBatch(("a", 2L)), + setFailInTask(true), + AddData(inputData, "a"), + ExpectFailure[SparkException](), // task should fail but should not increment count + setFailInTask(false), + StartStream(), + CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + ) + } + + private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => + val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows") + assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows") + true + } +} + +object MapGroupsWithStateSuite { + var failInTask = true +}
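Finally, a rough end-to-end sketch of how the new API is meant to be wired into a streaming query, reusing the running-count logic from MapGroupsWithStateSuite above. The socket source, console sink, and session setup are illustrative assumptions, not part of the patch.

    import org.apache.spark.sql.{KeyedState, SparkSession}

    case class RunningCount(count: Long)

    object StreamingRunningCountSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("sketch").getOrCreate()
        import spark.implicits._

        // Maintain a running count per key; once it reaches 3, drop the state and emit "-1",
        // mirroring the stateFunc used in the streaming tests above.
        val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
          val count = state.getOption.map(_.count).getOrElse(0L) + values.size
          if (count == 3) {
            state.remove()
            (key, "-1")
          } else {
            state.update(RunningCount(count))
            (key, count.toString)
          }
        }

        val counts = spark.readStream
          .format("socket")
          .option("host", "localhost")
          .option("port", "9999")
          .load()
          .as[String]
          .groupByKey(x => x)
          .mapGroupsWithState(stateFunc)

        counts.writeStream
          .outputMode("append")
          .format("console")
          .start()
          .awaitTermination()
      }
    }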