From 3457c32297e0150a4fbc80a30f84b9c62ca7c372 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Wed, 8 Mar 2017 14:30:54 -0800
Subject: [PATCH] Revert "[SPARK-19413][SS] MapGroupsWithState for arbitrary
 stateful operations for branch-2.1"

This reverts commit 502c927b8c8a99ef2adf4e6e1d7a6d9232d45ef5.

---
 .../UnsupportedOperationChecker.scala         |  11 +-
 .../sql/catalyst/plans/logical/object.scala   |  49 ---
 .../analysis/UnsupportedOperationsSuite.scala |  24 +-
 .../FlatMapGroupsWithStateFunction.java       |  38 --
 .../function/MapGroupsWithStateFunction.java  |  38 --
 .../spark/sql/KeyValueGroupedDataset.scala    | 113 ------
 .../org/apache/spark/sql/KeyedState.scala     | 142 --------
 .../spark/sql/execution/SparkStrategies.scala |  21 +-
 .../apache/spark/sql/execution/objects.scala  |  22 --
 .../streaming/IncrementalExecution.scala      |  19 +-
 .../execution/streaming/KeyedStateImpl.scala  |  80 -----
 .../streaming/ProgressReporter.scala          |   2 +-
 ...perators.scala => StatefulAggregate.scala} | 134 ++-----
 .../state/HDFSBackedStateStoreProvider.scala  |  19 -
 .../streaming/state/StateStore.scala          |   5 -
 .../execution/streaming/state/package.scala   |  11 +-
 .../apache/spark/sql/JavaDatasetSuite.java    |  32 --
 .../streaming/MapGroupsWithStateSuite.scala   | 335 ------------------
 18 files changed, 36 insertions(+), 1059 deletions(-)
 delete mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
 delete mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
 rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{statefulOperators.scala => StatefulAggregate.scala} (63%)
 delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index d8aad42edc..f4d016cb96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -46,13 +46,8 @@ object UnsupportedOperationChecker {
         "Queries without streaming sources cannot be executed with writeStream.start()")(plan)
     }
 
-    /** Collect all the streaming aggregates in a sub plan */
-    def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
-      subplan.collect { case a: Aggregate if a.isStreaming => a }
-    }
-
     // Disallow multiple streaming aggregations
-    val aggregates = collectStreamingAggregates(plan)
+    val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
 
     if (aggregates.size > 1) {
       throwError(
@@ -119,10 +114,6 @@ object UnsupportedOperationChecker {
       case _: InsertIntoTable =>
         throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets")
 
-      case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
-        throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
-          "streaming DataFrame/Dataset")
-
       case Join(left, right, joinType, _) =>
 
         joinType match {

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0be4823bbc..0ab4c90166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -313,55 +313,6 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer -/** Internal class representing State */ -trait LogicalKeyedState[S] - -/** Factory for constructing new `MapGroupsWithState` nodes. */ -object MapGroupsWithState { - def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - child: LogicalPlan): LogicalPlan = { - val mapped = new MapGroupsWithState( - func, - UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), - UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), - groupingAttributes, - dataAttributes, - CatalystSerde.generateObjAttr[U], - encoderFor[S].resolveAndBind().deserializer, - encoderFor[S].namedExpressions, - child) - CatalystSerde.serialize[U](mapped) - } -} - -/** - * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`, - * while using state data. - * Func is invoked with an object representation of the grouping key an iterator containing the - * object representation of all the rows with that key. - * - * @param keyDeserializer used to extract the key object for each group. - * @param valueDeserializer used to extract the items in the iterator from an input row. - * @param groupingAttributes used to group the data - * @param dataAttributes used to read the data - * @param outputObjAttr used to define the output object - * @param stateDeserializer used to deserialize state before calling `func` - * @param stateSerializer used to serialize updated state after calling `func` - */ -case class MapGroupsWithState( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - outputObjAttr: Attribute, - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectProducer - /** Factory for constructing new `FlatMapGroupsInR` nodes. 
*/ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 3b756e89d9..dcdb1ae089 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.{IntegerType, LongType} +import org.apache.spark.sql.types.IntegerType /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -111,24 +111,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("distinct aggregation")) - // MapGroupsWithState: Not supported after a streaming aggregation - val att = new AttributeReference(name = "a", dataType = LongType)() - assertSupportedInBatchPlan( - "mapGroupsWithState - mapGroupsWithState on batch relation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation)) - - assertSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation), - outputMode = Append) - - assertNotSupportedInStreamingPlan( - "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), - Aggregate(Nil, aggExprs("c"), streamRelation)), - outputMode = Complete, - expectedMsgs = Seq("(map/flatMap)GroupsWithState")) - // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java deleted file mode 100644 index 2570c8d02a..0000000000 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; -import java.util.Iterator; - -import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.KeyedState; - -/** - * ::Experimental:: - * Base interface for a map function used in - * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}. - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable { - Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception; -} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java deleted file mode 100644 index 614d3925e0..0000000000 --- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.api.java.function; - -import java.io.Serializable; -import java.util.Iterator; - -import org.apache.spark.annotation.Experimental; -import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.KeyedState; - -/** - * ::Experimental:: - * Base interface for a map function used in - * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)} - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable { - R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception; -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 94e689a4d5..395d709f26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -218,119 +218,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapGroups((key, data) => f.call(key, data.asJava))(encoder) } - /** - * ::Experimental:: - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 - */ - @Experimental - @InterfaceStability.Evolving - def mapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) - } - - /** - * ::Experimental:: - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. 
- * @since 2.1.1 - */ - @Experimental - @InterfaceStability.Evolving - def mapGroupsWithState[S, U]( - func: MapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) - )(stateEncoder, outputEncoder) - } - - /** - * ::Experimental:: - * (Scala-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 - */ - @Experimental - @InterfaceStability.Evolving - def flatMapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { - Dataset[U]( - sparkSession, - MapGroupsWithState[K, V, S, U]( - func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], - groupingAttributes, - dataAttributes, - logicalPlan)) - } - - /** - * ::Experimental:: - * (Java-specific) - * Applies the given function to each group of data, while maintaining a user-defined per-group - * state. The result Dataset will represent the objects returned by the function. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger, and - * updates to each group's state will be saved across invocations. - * See [[KeyedState]] for more details. - * - * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. - * @tparam U The type of the output objects. Must be encodable to Spark SQL types. - * @param func Function to be called on every group. - * @param stateEncoder Encoder for the state type. - * @param outputEncoder Encoder for the output type. - * - * See [[Encoder]] for more details on what types are encodable to Spark SQL. - * @since 2.1.1 - */ - @Experimental - @InterfaceStability.Evolving - def flatMapGroupsWithState[S, U]( - func: FlatMapGroupsWithStateFunction[K, V, S, U], - stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): Dataset[U] = { - flatMapGroupsWithState[S, U]( - (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala - )(stateEncoder, outputEncoder) - } - /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala deleted file mode 100644 index 6864b6f6b4..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import java.lang.IllegalArgumentException - -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState - -/** - * :: Experimental :: - * - * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[KeyValueGroupedDataset]]. - * - * Detail description on `[map/flatMap]GroupsWithState` operation - * ------------------------------------------------------------ - * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] - * will invoke the user-given function on each group (defined by the grouping function in - * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. - * For a static batch Dataset, the function will be invoked once per group. For a streaming - * Dataset, the function will be invoked for each group repeatedly in every trigger. - * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]], - * the function will be invoked once for each group that has data in the batch. - * - * The function is invoked with following parameters. - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * In case of a batch Dataset, there is only one invocation and state object will be empty as - * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` - * is equivalent to `[map/flatMap]Groups`. - * - * Important points to note about the function. - * - In a trigger, the function will be called only the groups present in the batch. So do not - * assume that the function will be called in every trigger for every group that has state. - * - There is no guaranteed ordering of values in the iterator in the function, neither with - * batch, nor with streaming Datasets. - * - All the data will be shuffled before applying the function. - * - * Important points to note about using KeyedState. - * - The value of the state cannot be null. So updating state with null will throw - * `IllegalArgumentException`. - * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, - * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` - * - After that, if `update(newState)` is called, then `exists()` will again return `true`, - * `get()` and `getOption()`will return the updated value. - * - * Scala example of using KeyedState in `mapGroupsWithState`: - * {{{ - * /* A mapping function that maintains an integer state for string keys and returns a string. 
*/ - * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { - * // Check if state exists - * if (state.exists) { - * val existingState = state.get // Get the existing state - * val shouldRemove = ... // Decide whether to remove the state - * if (shouldRemove) { - * state.remove() // Remove the state - * } else { - * val newState = ... - * state.update(newState) // Set the new state - * } - * } else { - * val initialState = ... - * state.update(initialState) // Set the initial state - * } - * ... // return something - * } - * - * }}} - * - * Java example of using `KeyedState`: - * {{{ - * /* A mapping function that maintains an integer state for string keys and returns a string. */ - * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction = - * new MapGroupsWithStateFunction<String, Integer, Integer, String>() { - * - * @Override - * public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) { - * if (state.exists()) { - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * } - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * } - * ... // return something - * } - * }; - * }}} - * - * @tparam S User-defined type of the state to be stored for each key. Must be encodable into - * Spark SQL types (see [[Encoder]] for more details). - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -trait KeyedState[S] extends LogicalKeyedState[S] { - - /** Whether state exists or not. */ - def exists: Boolean - - /** Get the state value if it exists, or throw NoSuchElementException. */ - @throws[NoSuchElementException]("when state does not exist") - def get: S - - /** Get the state value as a scala Option. */ - def getOption: Option[S] - - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. - */ - @throws[IllegalArgumentException]("when updating with null") - def update(newState: S): Unit - - /** Remove this keyed state. */ - def remove(): Unit -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index adea358594..ba82ec156e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -324,23 +324,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - /** - * Strategy to convert MapGroupsWithState logical operator to physical operator - * in streaming plans. 
Conversion for batch plans is handled by [[BasicOperators]]. - */ - object MapGroupsWithStateStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case MapGroupsWithState( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => - val execPlan = MapGroupsWithStateExec( - f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, - planLater(child)) - execPlan :: Nil - case _ => - Nil - } - } - // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions @@ -382,8 +365,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil - case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => - execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 199ba5ce69..fde3b2a528 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -30,8 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState -import org.apache.spark.sql.execution.streaming.KeyedStateImpl import org.apache.spark.sql.types.{DataType, ObjectType, StructType} @@ -146,11 +144,6 @@ object ObjectOperator { (i: InternalRow) => proj(i).get(0, deserializer.dataType) } - def deserializeRowToObject(deserializer: Expression): InternalRow => Any = { - val proj = GenerateSafeProjection.generate(deserializer :: Nil) - (i: InternalRow) => proj(i).get(0, deserializer.dataType) - } - def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { val proj = GenerateUnsafeProjection.generate(serializer) val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head @@ -351,21 +344,6 @@ case class MapGroupsExec( } } -object MapGroupsExec { - def apply( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - outputObjAttr: Attribute, - child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) - new MapGroupsExec(f, keyDeserializer, valueDeserializer, - groupingAttributes, dataAttributes, outputObjAttr, child) - } -} - /** * Groups the input rows together and calls the R function with each group and an iterator * containing all elements in the group. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 5c4cbfa755..6ab6fa61dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.streaming -import java.util.concurrent.atomic.AtomicInteger - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} import org.apache.spark.sql.SparkSession @@ -41,9 +39,8 @@ class IncrementalExecution( extends QueryExecution(sparkSession, logicalPlan) with Logging { // TODO: make this always part of planning. - val streamingExtraStrategies = + val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy +: - sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -52,7 +49,7 @@ class IncrementalExecution( new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - streamingExtraStrategies) + stateStrategy) /** * See [SPARK-18339] @@ -71,7 +68,7 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private val operatorId = new AtomicInteger(0) + private var operatorId = 0 /** Locates save/restore pairs surrounding aggregation. */ val state = new Rule[SparkPlan] { @@ -80,8 +77,8 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) + operatorId += 1 StateStoreSaveExec( keys, @@ -93,12 +90,6 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) - case MapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => - val stateId = - OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) - MapGroupsWithStateExec( - f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala deleted file mode 100644 index eee7ec45dd..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.apache.spark.sql.KeyedState - -/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */ -private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { - private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) - private var defined: Boolean = optionalValue.isDefined - private var updated: Boolean = false - // whether value has been updated (but not removed) - private var removed: Boolean = false // whether value has been removed - - // ========= Public API ========= - override def exists: Boolean = defined - - override def get: S = { - if (defined) { - value - } else { - throw new NoSuchElementException("State is either not defined or has already been removed") - } - } - - override def getOption: Option[S] = { - if (defined) { - Some(value) - } else { - None - } - } - - override def update(newValue: S): Unit = { - if (newValue == null) { - throw new IllegalArgumentException("'null' is not a valid state value") - } - value = newValue - defined = true - updated = true - removed = false - } - - override def remove(): Unit = { - defined = false - updated = false - removed = true - } - - override def toString: String = { - s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})" - } - - // ========= Internal API ========= - - /** Whether the state has been marked for removing */ - def isRemoved: Boolean = { - removed - } - - /** Whether the state has been been updated */ - def isUpdated: Boolean = { - updated - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 693933f95a..1f74fffbe6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -186,7 +186,7 @@ trait ProgressReporter extends Logging { // lastExecution could belong to one of the previous triggers if `!hasNewData`. // Walking the plan again should be inexpensive. 
val stateNodes = lastExecution.executedPlan.collect { - case p if p.isInstanceOf[StateStoreWriter] => p + case p if p.isInstanceOf[StateStoreSaveExec] => p } stateNodes.map { node => val numRowsUpdated = if (hasNewData) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala similarity index 63% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 1292452574..d4ccced9ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -22,16 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution -import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType -import org.apache.spark.util.CompletionIterator +import org.apache.spark.TaskContext /** Used to identify the state store for a given operator. */ @@ -41,7 +41,7 @@ case class OperatorStateId( batchId: Long) /** - * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should + * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should * be filled in by `prepareForExecution` in [[IncrementalExecution]]. */ trait StatefulOperator extends SparkPlan { @@ -54,20 +54,6 @@ trait StatefulOperator extends SparkPlan { } } -/** An operator that reads from a StateStore. */ -trait StateStoreReader extends StatefulOperator { - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) -} - -/** An operator that writes to a StateStore. */ -trait StateStoreWriter extends StatefulOperator { - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) -} - /** * For each input tuple, the key is calculated and the value from the [[StateStore]] is added * to the stream (in addition to the input tuple) if present. 
@@ -76,7 +62,10 @@ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], child: SparkPlan) - extends execution.UnaryExecNode with StateStoreReader { + extends execution.UnaryExecNode with StatefulOperator { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -113,7 +102,12 @@ case class StateStoreSaveExec( outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) - extends execution.UnaryExecNode with StateStoreWriter { + extends execution.UnaryExecNode with StatefulOperator { + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) /** Generate a predicate that matches data older than the watermark */ private lazy val watermarkPredicate: Option[Predicate] = { @@ -157,6 +151,13 @@ case class StateStoreSaveExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") + // Abort the state store in case of error + TaskContext.get().addTaskCompletionListener(_ => { + if (!store.hasCommitted) { + store.abort() + } + }) + outputMode match { // Update and output all rows in the StateStore. case Some(Complete) => @@ -183,7 +184,7 @@ case class StateStoreSaveExec( } // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval _) + store.remove(watermarkPredicate.get.eval) store.commit() numTotalStateRows += store.numKeys() @@ -206,7 +207,7 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) + if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval) store.commit() numTotalStateRows += store.numKeys() false @@ -234,90 +235,3 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } - - -/** Physical operator for executing streaming mapGroupsWithState. 
*/ -case class MapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], - keyDeserializer: Expression, - valueDeserializer: Expression, - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - outputObjAttr: Attribute, - stateId: Option[OperatorStateId], - stateDeserializer: Expression, - stateSerializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { - - override def outputPartitioning: Partitioning = child.outputPartitioning - - /** Distribute by grouping attributes */ - override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil - - /** Ordering needed for using GroupingIterator */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsWithStateStore[InternalRow]( - getStateId.checkpointLocation, - operatorId = getStateId.operatorId, - storeVersion = getStateId.batchId, - groupingAttributes.toStructType, - child.output.toStructType, - sqlContext.sessionState, - Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val numOutputRows = longMetric("numOutputRows") - - // Generate a iterator that returns the rows grouped by the grouping function - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - - // Converters to and from object and rows - val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val getStateObj = - ObjectOperator.deserializeRowToObject(stateDeserializer) - val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // For every group, get the key, values and corresponding state and call the function, - // and return an iterator of rows - val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => - - val key = keyRow.asInstanceOf[UnsafeRow] - val keyObj = getKeyObj(keyRow) // convert key to objects - val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObjOption = store.get(key).map(getStateObj) // get existing state if any - val wrappedState = new KeyedStateImpl(stateObjOption) - val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj => - numOutputRows += 1 - getOutputRow(obj) // convert back to rows - } - - // Return an iterator of rows generated this key, - // such that fully consumed, the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]]( - mappedIterator, { - // When the iterator is consumed, then write changes to state - if (wrappedState.isRemoved) { - store.remove(key) - numUpdatedStateRows += 1 - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get)) - numUpdatedStateRows += 1 - } - }) - } - - // Return an iterator of all the rows generated by all the keys, such that when fully - // consumer, all the state updates will be committed by the state store - CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, { - store.commit() - numTotalStateRows += store.numKeys() - }) - } - } -} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ab1204a750..f53b9b9a43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -147,25 +147,6 @@ private[state] class HDFSBackedStateStoreProvider( } } - /** Remove a single key. */ - override def remove(key: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or aborted") - if (mapToUpdate.containsKey(key)) { - val value = mapToUpdate.remove(key) - Option(allUpdates.get(key)) match { - case Some(ValueUpdated(_, _)) | None => - // Value existed in previous version and maybe was updated, mark removed - allUpdates.put(key, ValueRemoved(key, value)) - case Some(ValueAdded(_, _)) => - // Value did not exist in previous version and was added, should not appear in updates - allUpdates.remove(key) - case Some(ValueRemoved(_, _)) => - // Remove already in update map, no need to change - } - writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) - } - } - /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { verify(state == UPDATING, "Cannot commit after already committed or aborted") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dcb24b26f7..e61d95a1b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -58,11 +58,6 @@ trait StateStore { */ def remove(condition: UnsafeRow => Boolean): Unit - /** - * Remove a single key. - */ - def remove(key: UnsafeRow): Unit - /** * Commit all the updates that have been made to the store, and return the new version. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 589042afb1..1b56c08f72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag -import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.internal.SessionState @@ -60,18 +59,10 @@ package object state { sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { - val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) - val wrappedF = (store: StateStore, iter: Iterator[T]) => { - // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { - if (!store.hasCommitted) store.abort() - }) - cleanedF(store, iter) - } new StateStoreRDD( dataRDD, - wrappedF, + cleanedF, checkpointLocation, operatorId, storeVersion, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 5ef4e887de..8304b728aa 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -225,38 +225,6 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); - Dataset<String> mapped2 = grouped.mapGroupsWithState( - new MapGroupsWithStateFunction<Integer, String, Long, String>() { - @Override - public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception { - StringBuilder sb = new StringBuilder(key.toString()); - while (values.hasNext()) { - sb.append(values.next()); - } - return sb.toString(); - } - }, - Encoders.LONG(), - Encoders.STRING()); - - Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList())); - - Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState( - new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() { - @Override - public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) { - StringBuilder sb = new StringBuilder(key.toString()); - while (values.hasNext()) { - sb.append(values.next()); - } - return Collections.singletonList(sb.toString()).iterator(); - } - }, - Encoders.LONG(), - Encoders.STRING()); - - Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); - Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() { @Override public String call(String v1, String v2) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala deleted file mode 100644 index 0524898b15..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ /dev/null @@ -1,335 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.SparkException -import org.apache.spark.sql.KeyedState -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream} -import org.apache.spark.sql.execution.streaming.state.StateStore - -/** Class to check custom state types */ -case class RunningCount(count: Long) - -class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { - - import testImplicits._ - - override def afterAll(): Unit = { - super.afterAll() - StateStore.stop() - } - - test("KeyedState - get, exists, update, remove") { - var state: KeyedStateImpl[String] = null - - def testState( - expectedData: Option[String], - shouldBeUpdated: Boolean = false, - shouldBeRemoved: Boolean = false): Unit = { - if (expectedData.isDefined) { - assert(state.exists) - assert(state.get === expectedData.get) - } else { - assert(!state.exists) - intercept[NoSuchElementException] { - state.get - } - } - assert(state.getOption === expectedData) - assert(state.isUpdated === shouldBeUpdated) - assert(state.isRemoved === shouldBeRemoved) - } - - // Updating empty state - state = new KeyedStateImpl[String](None) - testState(None) - state.update("") - testState(Some(""), shouldBeUpdated = true) - - // Updating exiting state - state = new KeyedStateImpl[String](Some("2")) - testState(Some("2")) - state.update("3") - testState(Some("3"), shouldBeUpdated = true) - - // Removing state - state.remove() - testState(None, shouldBeRemoved = true, shouldBeUpdated = false) - state.remove() // should be still callable - state.update("4") - testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) - - // Updating by null throw exception - intercept[IllegalArgumentException] { - state.update(null) - } - } - - test("KeyedState - primitive type") { - var intState = new KeyedStateImpl[Int](None) - intercept[NoSuchElementException] { - intState.get - } - assert(intState.getOption === None) - - intState = new KeyedStateImpl[Int](Some(10)) - assert(intState.get == 10) - intState.update(0) - assert(intState.get == 0) - intState.remove() - intercept[NoSuchElementException] { - intState.get - } - } - - test("flatMapGroupsWithState - streaming") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - Iterator.empty - } else { - state.update(RunningCount(count)) - Iterator((key, count.toString)) - } - } - - val inputData = MemoryStream[String] - val 
result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) - - testStream(result, Append)( - AddData(inputData, "a"), - CheckLastBatch(("a", "1")), - assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), - assertNumStateRows(total = 1, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) - ) - } - - test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything - // Additionally, it updates state lazily as the returned iterator get consumed - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - values.flatMap { _ => - val count = state.getOption.map(_.count).getOrElse(0L) + 1 - if (count == 3) { - state.remove() - None - } else { - state.update(RunningCount(count)) - Some((key, count.toString)) - } - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) - - testStream(result, Append)( - AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) - ) - } - - test("flatMapGroupsWithState - batch") { - // Function that returns running count only if its even, otherwise does not return - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (state.exists) throw new IllegalArgumentException("state.exists should be false") - Iterator((key, values.size)) - } - checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, - Seq(("a", 2), ("b", 1)).toDF) - } - - test("mapGroupsWithState - streaming") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - (key, "-1") - } else { - state.update(RunningCount(count)) - (key, count.toString) - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - - testStream(result, Append)( - AddData(inputData, "a"), - CheckLastBatch(("a", "1")), - assertNumStateRows(total = 1, updated = 1), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(total = 2, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), // 
should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), - assertNumStateRows(total = 1, updated = 2), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), - assertNumStateRows(total = 3, updated = 2) - ) - } - - test("mapGroupsWithState - streaming + aggregation") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - if (count == 3) { - state.remove() - (key, "-1") - } else { - state.update(RunningCount(count)) - (key, count.toString) - } - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - .groupByKey(_._1) - .count() - - testStream(result, Complete)( - AddData(inputData, "a"), - CheckLastBatch(("a", 1)), - AddData(inputData, "a", "b"), - // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), - StopStream, - StartStream(), - AddData(inputData, "a", "b"), - // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; - // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), - StopStream, - StartStream(), - AddData(inputData, "a", "c"), - // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; - // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) - ) - } - - test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (state.exists) throw new IllegalArgumentException("state.exists should be false") - (key, values.size) - } - - checkAnswer( - spark.createDataset(Seq("a", "a", "b")) - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) - .toDF, - spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) - } - - testQuietly("StateStore.abort on task failure handling") { - val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") - val count = state.getOption.map(_.count).getOrElse(0L) + values.size - state.update(RunningCount(count)) - (key, count) - } - - val inputData = MemoryStream[String] - val result = - inputData.toDS() - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) - - def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => - MapGroupsWithStateSuite.failInTask = value - true - } - - testStream(result, Append)( - setFailInTask(false), - AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), - AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), - setFailInTask(true), - AddData(inputData, "a"), - ExpectFailure[SparkException](), // task should fail but should not increment count - setFailInTask(false), - StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count - ) - } - - private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => - val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert(progressWithData.stateOperators(0).numRowsTotal 
=== total, "incorrect total rows") - assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows") - true - } -} - -object MapGroupsWithStateSuite { - var failInTask = true -} -- GitLab
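For context, below is a minimal, illustrative sketch (not part of the patch) of the user-facing API that this revert removes, assembled from the reverted `KeyValueGroupedDataset`, `KeyedState`, and `MapGroupsWithStateSuite` sources above. The socket source, host, and port are placeholder choices; any streaming source that yields a `Dataset[String]` would do.

```scala
import org.apache.spark.sql.{KeyedState, SparkSession}

// Per-key state type, as in the reverted test suite.
case class RunningCount(count: Long)

object MapGroupsWithStateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sketch").getOrCreate()
    import spark.implicits._

    // Maintain a running count per key; once a key has been seen three times,
    // drop its state and emit "-1" (mirrors the "mapGroupsWithState - streaming" test).
    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
      if (count == 3) {
        state.remove()                      // exists will return false afterwards
        (key, "-1")
      } else {
        state.update(RunningCount(count))   // updating with null would throw IllegalArgumentException
        (key, count.toString)
      }
    }

    // Placeholder socket source; group by the value itself and apply the stateful function.
    val counts = spark.readStream
      .format("socket").option("host", "localhost").option("port", 9999).load()
      .as[String]
      .groupByKey(x => x)
      .mapGroupsWithState(stateFunc)

    counts.writeStream
      .outputMode("append")
      .format("console")
      .start()
      .awaitTermination()
  }
}
```

The `flatMapGroupsWithState` variant removed by the same patch takes an analogous function returning an `Iterator[U]` instead of a single value per group.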