From 3457c32297e0150a4fbc80a30f84b9c62ca7c372 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Wed, 8 Mar 2017 14:30:54 -0800
Subject: [PATCH] Revert "[SPARK-19413][SS] MapGroupsWithState for arbitrary
 stateful operations for branch-2.1"

This reverts commit 502c927b8c8a99ef2adf4e6e1d7a6d9232d45ef5.
---
 .../UnsupportedOperationChecker.scala         |  11 +-
 .../sql/catalyst/plans/logical/object.scala   |  49 ---
 .../analysis/UnsupportedOperationsSuite.scala |  24 +-
 .../FlatMapGroupsWithStateFunction.java       |  38 --
 .../function/MapGroupsWithStateFunction.java  |  38 --
 .../spark/sql/KeyValueGroupedDataset.scala    | 113 ------
 .../org/apache/spark/sql/KeyedState.scala     | 142 --------
 .../spark/sql/execution/SparkStrategies.scala |  21 +-
 .../apache/spark/sql/execution/objects.scala  |  22 --
 .../streaming/IncrementalExecution.scala      |  19 +-
 .../execution/streaming/KeyedStateImpl.scala  |  80 -----
 .../streaming/ProgressReporter.scala          |   2 +-
 ...perators.scala => StatefulAggregate.scala} | 134 ++-----
 .../state/HDFSBackedStateStoreProvider.scala  |  19 -
 .../streaming/state/StateStore.scala          |   5 -
 .../execution/streaming/state/package.scala   |  11 +-
 .../apache/spark/sql/JavaDatasetSuite.java    |  32 --
 .../streaming/MapGroupsWithStateSuite.scala   | 335 ------------------
 18 files changed, 36 insertions(+), 1059 deletions(-)
 delete mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
 delete mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
 rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{statefulOperators.scala => StatefulAggregate.scala} (63%)
 delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index d8aad42edc..f4d016cb96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -46,13 +46,8 @@ object UnsupportedOperationChecker {
         "Queries without streaming sources cannot be executed with writeStream.start()")(plan)
     }
 
-    /** Collect all the streaming aggregates in a sub plan */
-    def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
-      subplan.collect { case a: Aggregate if a.isStreaming => a }
-    }
-
     // Disallow multiple streaming aggregations
-    val aggregates = collectStreamingAggregates(plan)
+    val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
 
     if (aggregates.size > 1) {
       throwError(
@@ -119,10 +114,6 @@ object UnsupportedOperationChecker {
         case _: InsertIntoTable =>
           throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets")
 
-        case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
-          throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
-            "streaming DataFrame/Dataset")
-
         case Join(left, right, joinType, _) =>
 
           joinType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0be4823bbc..0ab4c90166 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -313,55 +313,6 @@ case class MapGroups(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer
 
-/** Internal class representing State */
-trait LogicalKeyedState[S]
-
-/** Factory for constructing new `MapGroupsWithState` nodes. */
-object MapGroupsWithState {
-  def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
-      func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
-      groupingAttributes: Seq[Attribute],
-      dataAttributes: Seq[Attribute],
-      child: LogicalPlan): LogicalPlan = {
-    val mapped = new MapGroupsWithState(
-      func,
-      UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
-      UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
-      groupingAttributes,
-      dataAttributes,
-      CatalystSerde.generateObjAttr[U],
-      encoderFor[S].resolveAndBind().deserializer,
-      encoderFor[S].namedExpressions,
-      child)
-    CatalystSerde.serialize[U](mapped)
-  }
-}
-
-/**
- * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`,
- * while using state data.
- * Func is invoked with an object representation of the grouping key and an iterator containing
- * the object representation of all the rows with that key.
- *
- * @param keyDeserializer used to extract the key object for each group.
- * @param valueDeserializer used to extract the items in the iterator from an input row.
- * @param groupingAttributes used to group the data
- * @param dataAttributes used to read the data
- * @param outputObjAttr used to define the output object
- * @param stateDeserializer used to deserialize state before calling `func`
- * @param stateSerializer used to serialize updated state after calling `func`
- */
-case class MapGroupsWithState(
-    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
-    keyDeserializer: Expression,
-    valueDeserializer: Expression,
-    groupingAttributes: Seq[Attribute],
-    dataAttributes: Seq[Attribute],
-    outputObjAttr: Attribute,
-    stateDeserializer: Expression,
-    stateSerializer: Seq[NamedExpression],
-    child: LogicalPlan) extends UnaryNode with ObjectProducer
-
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
 object FlatMapGroupsInR {
   def apply(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 3b756e89d9..dcdb1ae089 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.{IntegerType, LongType}
+import org.apache.spark.sql.types.IntegerType
 
 /** A dummy command for testing unsupported operations. */
 case class DummyCommand() extends Command
@@ -111,24 +111,6 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     outputMode = Complete,
     expectedMsgs = Seq("distinct aggregation"))
 
-  // MapGroupsWithState: Not supported after a streaming aggregation
-  val att = new AttributeReference(name = "a", dataType = LongType)()
-  assertSupportedInBatchPlan(
-    "mapGroupsWithState - mapGroupsWithState on batch relation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
-
-  assertSupportedInStreamingPlan(
-    "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
-    outputMode = Append)
-
-  assertNotSupportedInStreamingPlan(
-    "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
-    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
-      Aggregate(Nil, aggExprs("c"), streamRelation)),
-    outputMode = Complete,
-    expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
-
   // Inner joins: Stream-stream not supported
   testBinaryOperationInStreamingPlan(
     "inner join",
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
deleted file mode 100644
index 2570c8d02a..0000000000
--- a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function;
-
-import java.io.Serializable;
-import java.util.Iterator;
-
-import org.apache.spark.annotation.Experimental;
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.KeyedState;
-
-/**
- * ::Experimental::
- * Base interface for a map function used in
- * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}.
- * @since 2.1.1
- */
-@Experimental
-@InterfaceStability.Evolving
-public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
-  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
-}
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
deleted file mode 100644
index 614d3925e0..0000000000
--- a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.api.java.function;
-
-import java.io.Serializable;
-import java.util.Iterator;
-
-import org.apache.spark.annotation.Experimental;
-import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.KeyedState;
-
-/**
- * ::Experimental::
- * Base interface for a map function used in
- * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)}
- * @since 2.1.1
- */
-@Experimental
-@InterfaceStability.Evolving
-public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
-  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 94e689a4d5..395d709f26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -218,119 +218,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
     mapGroups((key, data) => f.call(key, data.asJava))(encoder)
   }
 
-  /**
-   * ::Experimental::
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[KeyedState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.1.1
-   */
-  @Experimental
-  @InterfaceStability.Evolving
-  def mapGroupsWithState[S: Encoder, U: Encoder](
-      func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
-  }
-
-  /**
-   * ::Experimental::
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[KeyedState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.1.1
-   */
-  @Experimental
-  @InterfaceStability.Evolving
-  def mapGroupsWithState[S, U](
-      func: MapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
-    )(stateEncoder, outputEncoder)
-  }
-
-  /**
-   * ::Experimental::
-   * (Scala-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[KeyedState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.1.1
-   */
-  @Experimental
-  @InterfaceStability.Evolving
-  def flatMapGroupsWithState[S: Encoder, U: Encoder](
-      func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
-    Dataset[U](
-      sparkSession,
-      MapGroupsWithState[K, V, S, U](
-        func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
-        groupingAttributes,
-        dataAttributes,
-        logicalPlan))
-  }
-
-  /**
-   * ::Experimental::
-   * (Java-specific)
-   * Applies the given function to each group of data, while maintaining a user-defined per-group
-   * state. The result Dataset will represent the objects returned by the function.
-   * For a static batch Dataset, the function will be invoked once per group. For a streaming
-   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
-   * updates to each group's state will be saved across invocations.
-   * See [[KeyedState]] for more details.
-   *
-   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
-   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
-   * @param func          Function to be called on every group.
-   * @param stateEncoder  Encoder for the state type.
-   * @param outputEncoder Encoder for the output type.
-   *
-   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
-   * @since 2.1.1
-   */
-  @Experimental
-  @InterfaceStability.Evolving
-  def flatMapGroupsWithState[S, U](
-      func: FlatMapGroupsWithStateFunction[K, V, S, U],
-      stateEncoder: Encoder[S],
-      outputEncoder: Encoder[U]): Dataset[U] = {
-    flatMapGroupsWithState[S, U](
-      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
-    )(stateEncoder, outputEncoder)
-  }
-
   /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
deleted file mode 100644
index 6864b6f6b4..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import java.lang.IllegalArgumentException
-
-import org.apache.spark.annotation.{Experimental, InterfaceStability}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
-
-/**
- * :: Experimental ::
- *
- * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
- * `flatMapGroupsWithState` operations on
- * [[KeyValueGroupedDataset]].
- *
- * Detailed description of the `[map/flatMap]GroupsWithState` operation
- * ------------------------------------------------------------
- * Both `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
- * will invoke the user-given function on each group (defined by the grouping function in
- * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
- * For a static batch Dataset, the function will be invoked once per group. For a streaming
- * Dataset, the function will be invoked for each group repeatedly in every trigger.
- * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]],
- * the function will be invoked once for each group that has data in the batch.
- *
- * The function is invoked with the following parameters.
- *  - The key of the group.
- *  - An iterator containing all the values for this key.
- *  - A user-defined state object set by previous invocations of the given function.
- * In case of a batch Dataset, there is only one invocation and the state object will be empty as
- * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
- * is equivalent to `[map/flatMap]Groups`.
- *
- * Important points to note about the function.
- *  - In a trigger, the function will be called only for the groups present in the batch. So do not
- *    assume that the function will be called in every trigger for every group that has state.
- *  - There is no guaranteed ordering of values in the iterator in the function, neither with
- *    batch, nor with streaming Datasets.
- *  - All the data will be shuffled before applying the function.
- *
- * Important points to note about using KeyedState.
- *  - The value of the state cannot be null. So updating state with null will throw
- *    `IllegalArgumentException`.
- *  - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
- *  - If `remove()` is called, then `exists()` will return `false`,
- *    `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
- *  - After that, if `update(newState)` is called, then `exists()` will again return `true`, and
- *    `get()` and `getOption()` will return the updated value.
- *
- * Scala example of using KeyedState in `mapGroupsWithState`:
- * {{{
- * /* A mapping function that maintains an integer state for string keys and returns a string. */
- * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
- *   // Check if state exists
- *   if (state.exists) {
- *     val existingState = state.get  // Get the existing state
- *     val shouldRemove = ...         // Decide whether to remove the state
- *     if (shouldRemove) {
- *       state.remove()     // Remove the state
- *     } else {
- *       val newState = ...
- *       state.update(newState)    // Set the new state
- *     }
- *   } else {
- *     val initialState = ...
- *     state.update(initialState)  // Set the initial state
- *   }
- *   ... // return something
- * }
- *
- * }}}
- *
- * Java example of using `KeyedState`:
- * {{{
- * /* A mapping function that maintains an integer state for string keys and returns a string. */
- * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
- *    new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
- *
- *      @Override
- *      public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
- *        if (state.exists()) {
- *          int existingState = state.get(); // Get the existing state
- *          boolean shouldRemove = ...; // Decide whether to remove the state
- *          if (shouldRemove) {
- *            state.remove(); // Remove the state
- *          } else {
- *            int newState = ...;
- *            state.update(newState); // Set the new state
- *          }
- *        } else {
- *          int initialState = ...; // Set the initial state
- *          state.update(initialState);
- *        }
- *        ... // return something
- *      }
- *    };
- * }}}
- *
- * @tparam S User-defined type of the state to be stored for each key. Must be encodable into
- *           Spark SQL types (see [[Encoder]] for more details).
- * @since 2.1.1
- */
-@Experimental
-@InterfaceStability.Evolving
-trait KeyedState[S] extends LogicalKeyedState[S] {
-
-  /** Whether state exists or not. */
-  def exists: Boolean
-
-  /** Get the state value if it exists, or throw NoSuchElementException. */
-  @throws[NoSuchElementException]("when state does not exist")
-  def get: S
-
-  /** Get the state value as a scala Option. */
-  def getOption: Option[S]
-
-  /**
-   * Update the value of the state. Note that `null` is not a valid value, and it throws
-   * IllegalArgumentException.
-   */
-  @throws[IllegalArgumentException]("when updating with null")
-  def update(newState: S): Unit
-
-  /** Remove this keyed state. */
-  def remove(): Unit
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index adea358594..ba82ec156e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -324,23 +324,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  /**
-   * Strategy to convert MapGroupsWithState logical operator to physical operator
-   * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
-   */
-  object MapGroupsWithStateStrategy extends Strategy {
-    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case MapGroupsWithState(
-          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
-        val execPlan = MapGroupsWithStateExec(
-          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
-          planLater(child))
-        execPlan :: Nil
-      case _ =>
-        Nil
-    }
-  }
-
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
     def numPartitions: Int = self.numPartitions
@@ -382,8 +365,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
-      case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
-        execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 199ba5ce69..fde3b2a528 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -30,8 +30,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
-import org.apache.spark.sql.execution.streaming.KeyedStateImpl
 import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
 
 
@@ -146,11 +144,6 @@ object ObjectOperator {
     (i: InternalRow) => proj(i).get(0, deserializer.dataType)
   }
 
-  def deserializeRowToObject(deserializer: Expression): InternalRow => Any = {
-    val proj = GenerateSafeProjection.generate(deserializer :: Nil)
-    (i: InternalRow) => proj(i).get(0, deserializer.dataType)
-  }
-
   def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
     val proj = GenerateUnsafeProjection.generate(serializer)
     val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
@@ -351,21 +344,6 @@ case class MapGroupsExec(
   }
 }
 
-object MapGroupsExec {
-  def apply(
-      func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
-      keyDeserializer: Expression,
-      valueDeserializer: Expression,
-      groupingAttributes: Seq[Attribute],
-      dataAttributes: Seq[Attribute],
-      outputObjAttr: Attribute,
-      child: SparkPlan): MapGroupsExec = {
-    val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None))
-    new MapGroupsExec(f, keyDeserializer, valueDeserializer,
-      groupingAttributes, dataAttributes, outputObjAttr, child)
-  }
-}
-
 /**
  * Groups the input rows together and calls the R function with each group and an iterator
  * containing all elements in the group.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 5c4cbfa755..6ab6fa61dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution.streaming
 
-import java.util.concurrent.atomic.AtomicInteger
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
 import org.apache.spark.sql.SparkSession
@@ -41,9 +39,8 @@ class IncrementalExecution(
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
   // TODO: make this always part of planning.
-  val streamingExtraStrategies =
+  val stateStrategy =
     sparkSession.sessionState.planner.StatefulAggregationStrategy +:
-    sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
     sparkSession.sessionState.planner.StreamingRelationStrategy +:
     sparkSession.sessionState.experimentalMethods.extraStrategies
 
@@ -52,7 +49,7 @@ class IncrementalExecution(
     new SparkPlanner(
       sparkSession.sparkContext,
       sparkSession.sessionState.conf,
-      streamingExtraStrategies)
+      stateStrategy)
 
   /**
    * See [SPARK-18339]
@@ -71,7 +68,7 @@ class IncrementalExecution(
    * Records the current id for a given stateful operator in the query plan as the `state`
    * preparation walks the query plan.
    */
-  private val operatorId = new AtomicInteger(0)
+  private var operatorId = 0
 
   /** Locates save/restore pairs surrounding aggregation. */
   val state = new Rule[SparkPlan] {
@@ -80,8 +77,8 @@ class IncrementalExecution(
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
                StateStoreRestoreExec(keys2, None, child))) =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
+        val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId)
+        operatorId += 1
 
         StateStoreSaveExec(
           keys,
@@ -93,12 +90,6 @@ class IncrementalExecution(
               keys,
               Some(stateId),
               child) :: Nil))
-      case MapGroupsWithStateExec(
-             f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
-        MapGroupsWithStateExec(
-          f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
deleted file mode 100644
index eee7ec45dd..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.apache.spark.sql.KeyedState
-
-/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */
-private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] {
-  private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
-  private var defined: Boolean = optionalValue.isDefined
-  private var updated: Boolean = false
-  // whether value has been updated (but not removed)
-  private var removed: Boolean = false // whether value has been removed
-
-  // ========= Public API =========
-  override def exists: Boolean = defined
-
-  override def get: S = {
-    if (defined) {
-      value
-    } else {
-      throw new NoSuchElementException("State is either not defined or has already been removed")
-    }
-  }
-
-  override def getOption: Option[S] = {
-    if (defined) {
-      Some(value)
-    } else {
-      None
-    }
-  }
-
-  override def update(newValue: S): Unit = {
-    if (newValue == null) {
-      throw new IllegalArgumentException("'null' is not a valid state value")
-    }
-    value = newValue
-    defined = true
-    updated = true
-    removed = false
-  }
-
-  override def remove(): Unit = {
-    defined = false
-    updated = false
-    removed = true
-  }
-
-  override def toString: String = {
-    s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})"
-  }
-
-  // ========= Internal API =========
-
-  /** Whether the state has been marked for removal */
-  def isRemoved: Boolean = {
-    removed
-  }
-
-  /** Whether the state has been updated */
-  def isUpdated: Boolean = {
-    updated
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 693933f95a..1f74fffbe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -186,7 +186,7 @@ trait ProgressReporter extends Logging {
     // lastExecution could belong to one of the previous triggers if `!hasNewData`.
     // Walking the plan again should be inexpensive.
     val stateNodes = lastExecution.executedPlan.collect {
-      case p if p.isInstanceOf[StateStoreWriter] => p
+      case p if p.isInstanceOf[StateStoreSaveExec] => p
     }
     stateNodes.map { node =>
       val numRowsUpdated = if (hasNewData) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
similarity index 63%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
index 1292452574..d4ccced9ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
@@ -22,16 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.streaming.state._
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.TaskContext
 
 
 /** Used to identify the state store for a given operator. */
@@ -41,7 +41,7 @@ case class OperatorStateId(
     batchId: Long)
 
 /**
- * An operator that reads or writes state from the [[StateStore]].  The [[OperatorStateId]] should
+ * An operator that saves or restores state from the [[StateStore]].  The [[OperatorStateId]] should
  * be filled in by `prepareForExecution` in [[IncrementalExecution]].
  */
 trait StatefulOperator extends SparkPlan {
@@ -54,20 +54,6 @@ trait StatefulOperator extends SparkPlan {
   }
 }
 
-/** An operator that reads from a StateStore. */
-trait StateStoreReader extends StatefulOperator {
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-}
-
-/** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator {
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
-}
-
 /**
  * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
  * to the stream (in addition to the input tuple) if present.
@@ -76,7 +62,10 @@ case class StateStoreRestoreExec(
     keyExpressions: Seq[Attribute],
     stateId: Option[OperatorStateId],
     child: SparkPlan)
-  extends execution.UnaryExecNode with StateStoreReader {
+  extends execution.UnaryExecNode with StatefulOperator {
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override protected def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
@@ -113,7 +102,12 @@ case class StateStoreSaveExec(
     outputMode: Option[OutputMode] = None,
     eventTimeWatermark: Option[Long] = None,
     child: SparkPlan)
-  extends execution.UnaryExecNode with StateStoreWriter {
+  extends execution.UnaryExecNode with StatefulOperator {
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
+    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
 
   /** Generate a predicate that matches data older than the watermark */
   private lazy val watermarkPredicate: Option[Predicate] = {
@@ -157,6 +151,13 @@ case class StateStoreSaveExec(
         val numTotalStateRows = longMetric("numTotalStateRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
 
+        // Abort the state store in case of error
+        TaskContext.get().addTaskCompletionListener(_ => {
+          if (!store.hasCommitted) {
+            store.abort()
+          }
+        })
+
         outputMode match {
           // Update and output all rows in the StateStore.
           case Some(Complete) =>
@@ -183,7 +184,7 @@ case class StateStoreSaveExec(
             }
 
             // Assumption: Append mode can be done only when watermark has been specified
-            store.remove(watermarkPredicate.get.eval _)
+            store.remove(watermarkPredicate.get.eval)
             store.commit()
 
             numTotalStateRows += store.numKeys()
@@ -206,7 +207,7 @@ case class StateStoreSaveExec(
               override def hasNext: Boolean = {
                 if (!baseIterator.hasNext) {
                   // Remove old aggregates if watermark specified
-                  if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
+                  if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval)
                   store.commit()
                   numTotalStateRows += store.numKeys()
                   false
@@ -234,90 +235,3 @@ case class StateStoreSaveExec(
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 }
-
-
-/** Physical operator for executing streaming mapGroupsWithState. */
-case class MapGroupsWithStateExec(
-    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
-    keyDeserializer: Expression,
-    valueDeserializer: Expression,
-    groupingAttributes: Seq[Attribute],
-    dataAttributes: Seq[Attribute],
-    outputObjAttr: Attribute,
-    stateId: Option[OperatorStateId],
-    stateDeserializer: Expression,
-    stateSerializer: Seq[NamedExpression],
-    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
-
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-
-  /** Distribute by grouping attributes */
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(groupingAttributes) :: Nil
-
-  /** Ordering needed for using GroupingIterator */
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsWithStateStore[InternalRow](
-      getStateId.checkpointLocation,
-      operatorId = getStateId.operatorId,
-      storeVersion = getStateId.batchId,
-      groupingAttributes.toStructType,
-      child.output.toStructType,
-      sqlContext.sessionState,
-      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
-        val numTotalStateRows = longMetric("numTotalStateRows")
-        val numUpdatedStateRows = longMetric("numUpdatedStateRows")
-        val numOutputRows = longMetric("numOutputRows")
-
-        // Generate an iterator that returns the rows grouped by the grouping function
-        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-
-        // Converters to and from object and rows
-        val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-        val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
-        val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
-        val getStateObj =
-          ObjectOperator.deserializeRowToObject(stateDeserializer)
-        val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
-
-        // For every group, get the key, values and corresponding state and call the function,
-        // and return an iterator of rows
-        val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
-
-          val key = keyRow.asInstanceOf[UnsafeRow]
-          val keyObj = getKeyObj(keyRow)                         // convert key to objects
-          val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
-          val stateObjOption = store.get(key).map(getStateObj)   // get existing state if any
-          val wrappedState = new KeyedStateImpl(stateObjOption)
-          val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
-            numOutputRows += 1
-            getOutputRow(obj) // convert back to rows
-          }
-
-          // Return an iterator of rows generated for this key, such that when it is fully
-          // consumed, the updated state value will be saved
-          CompletionIterator[InternalRow, Iterator[InternalRow]](
-            mappedIterator, {
-              // When the iterator is consumed, then write changes to state
-              if (wrappedState.isRemoved) {
-                store.remove(key)
-                numUpdatedStateRows += 1
-              } else if (wrappedState.isUpdated) {
-                store.put(key, outputStateObj(wrappedState.get))
-                numUpdatedStateRows += 1
-              }
-            })
-        }
-
-        // Return an iterator of all the rows generated by all the keys, such that when fully
-        // consumed, all the state updates will be committed by the state store
-        CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
-          store.commit()
-          numTotalStateRows += store.numKeys()
-        })
-      }
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index ab1204a750..f53b9b9a43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -147,25 +147,6 @@ private[state] class HDFSBackedStateStoreProvider(
       }
     }
 
-    /** Remove a single key. */
-    override def remove(key: UnsafeRow): Unit = {
-      verify(state == UPDATING, "Cannot remove after already committed or aborted")
-      if (mapToUpdate.containsKey(key)) {
-        val value = mapToUpdate.remove(key)
-        Option(allUpdates.get(key)) match {
-          case Some(ValueUpdated(_, _)) | None =>
-            // Value existed in previous version and maybe was updated, mark removed
-            allUpdates.put(key, ValueRemoved(key, value))
-          case Some(ValueAdded(_, _)) =>
-            // Value did not exist in previous version and was added, should not appear in updates
-            allUpdates.remove(key)
-          case Some(ValueRemoved(_, _)) =>
-          // Remove already in update map, no need to change
-        }
-        writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-      }
-    }
-
     /** Commit all the updates that have been made to the store, and return the new version. */
     override def commit(): Long = {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index dcb24b26f7..e61d95a1b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -58,11 +58,6 @@ trait StateStore {
    */
   def remove(condition: UnsafeRow => Boolean): Unit
 
-  /**
-   * Remove a single key.
-   */
-  def remove(key: UnsafeRow): Unit
-
   /**
    * Commit all the updates that have been made to the store, and return the new version.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 589042afb1..1b56c08f72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.internal.SessionState
@@ -60,18 +59,10 @@ package object state {
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
-
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
-      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
-        // Abort the state store in case of error
-        TaskContext.get().addTaskCompletionListener(_ => {
-          if (!store.hasCommitted) store.abort()
-        })
-        cleanedF(store, iter)
-      }
       new StateStoreRDD(
         dataRDD,
-        wrappedF,
+        cleanedF,
         checkpointLocation,
         operatorId,
         storeVersion,
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 5ef4e887de..8304b728aa 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -225,38 +225,6 @@ public class JavaDatasetSuite implements Serializable {
 
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
 
-    Dataset<String> mapped2 = grouped.mapGroupsWithState(
-      new MapGroupsWithStateFunction<Integer, String, Long, String>() {
-        @Override
-        public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception {
-          StringBuilder sb = new StringBuilder(key.toString());
-          while (values.hasNext()) {
-            sb.append(values.next());
-          }
-          return sb.toString();
-        }
-        },
-        Encoders.LONG(),
-        Encoders.STRING());
-
-    Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
-
-    Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
-      new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
-        @Override
-        public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
-          StringBuilder sb = new StringBuilder(key.toString());
-          while (values.hasNext()) {
-            sb.append(values.next());
-          }
-          return Collections.singletonList(sb.toString()).iterator();
-        }
-      },
-      Encoders.LONG(),
-      Encoders.STRING());
-
-    Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
-
     Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
       @Override
       public String call(String v1, String v2) throws Exception {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
deleted file mode 100644
index 0524898b15..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
+++ /dev/null
@@ -1,335 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.SparkException
-import org.apache.spark.sql.KeyedState
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.StateStore
-
-/** Class to check custom state types */
-case class RunningCount(count: Long)
-
-class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
-
-  import testImplicits._
-
-  override def afterAll(): Unit = {
-    super.afterAll()
-    StateStore.stop()
-  }
-
-  test("KeyedState - get, exists, update, remove") {
-    var state: KeyedStateImpl[String] = null
-
-    def testState(
-        expectedData: Option[String],
-        shouldBeUpdated: Boolean = false,
-        shouldBeRemoved: Boolean = false): Unit = {
-      if (expectedData.isDefined) {
-        assert(state.exists)
-        assert(state.get === expectedData.get)
-      } else {
-        assert(!state.exists)
-        intercept[NoSuchElementException] {
-          state.get
-        }
-      }
-      assert(state.getOption === expectedData)
-      assert(state.isUpdated === shouldBeUpdated)
-      assert(state.isRemoved === shouldBeRemoved)
-    }
-
-    // Updating empty state
-    state = new KeyedStateImpl[String](None)
-    testState(None)
-    state.update("")
-    testState(Some(""), shouldBeUpdated = true)
-
-    // Updating existing state
-    state = new KeyedStateImpl[String](Some("2"))
-    testState(Some("2"))
-    state.update("3")
-    testState(Some("3"), shouldBeUpdated = true)
-
-    // Removing state
-    state.remove()
-    testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
-    state.remove()      // should be still callable
-    state.update("4")
-    testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
-
-    // Updating with null throws an exception
-    intercept[IllegalArgumentException] {
-      state.update(null)
-    }
-  }
-
-  test("KeyedState - primitive type") {
-    var intState = new KeyedStateImpl[Int](None)
-    intercept[NoSuchElementException] {
-      intState.get
-    }
-    assert(intState.getOption === None)
-
-    intState = new KeyedStateImpl[Int](Some(10))
-    assert(intState.get == 10)
-    intState.update(0)
-    assert(intState.get == 0)
-    intState.remove()
-    intercept[NoSuchElementException] {
-      intState.get
-    }
-  }
-
-  test("flatMapGroupsWithState - streaming") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count if state is defined, otherwise does not return anything
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      if (count == 3) {
-        state.remove()
-        Iterator.empty
-      } else {
-        state.update(RunningCount(count))
-        Iterator((key, count.toString))
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str)
-
-    testStream(result, Append)(
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", "1")),
-      assertNumStateRows(total = 1, updated = 1),
-      AddData(inputData, "a", "b"),
-      CheckLastBatch(("a", "2"), ("b", "1")),
-      assertNumStateRows(total = 2, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
-      CheckLastBatch(("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
-      CheckLastBatch(("a", "1"), ("c", "1")),
-      assertNumStateRows(total = 3, updated = 2)
-    )
-  }
-
-  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count if state is defined, otherwise does not return anything
-    // Additionally, it updates state lazily as the returned iterator get consumed
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      values.flatMap { _ =>
-        val count = state.getOption.map(_.count).getOrElse(0L) + 1
-        if (count == 3) {
-          state.remove()
-          None
-        } else {
-          state.update(RunningCount(count))
-          Some((key, count.toString))
-        }
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
-
-    testStream(result, Append)(
-      AddData(inputData, "a", "a", "b"),
-      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
-      CheckLastBatch(("b", "2")),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
-      CheckLastBatch(("a", "1"), ("c", "1"))
-    )
-  }
-
-  test("flatMapGroupsWithState - batch") {
-    // Function that asserts state is not defined in batch queries and returns the count of values
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
-      Iterator((key, values.size))
-    }
-    checkAnswer(
-      Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
-      Seq(("a", 2), ("b", 1)).toDF)
-  }
-
-  test("mapGroupsWithState - streaming") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count (-1 if the count exceeds 2 and the state was just removed)
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      if (count == 3) {
-        state.remove()
-        (key, "-1")
-      } else {
-        state.update(RunningCount(count))
-        (key, count.toString)
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
-
-    testStream(result, Append)(
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", "1")),
-      assertNumStateRows(total = 1, updated = 1),
-      AddData(inputData, "a", "b"),
-      CheckLastBatch(("a", "2"), ("b", "1")),
-      assertNumStateRows(total = 2, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
-      CheckLastBatch(("a", "-1"), ("b", "2")),
-      assertNumStateRows(total = 1, updated = 2),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
-      CheckLastBatch(("a", "1"), ("c", "1")),
-      assertNumStateRows(total = 3, updated = 2)
-    )
-  }
-
-  test("mapGroupsWithState - streaming + aggregation") {
-    // Function to maintain running count up to 2, and then remove the count
-    // Returns the data and the count (-1 if the count exceeds 2 and the state was just removed)
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      if (count == 3) {
-        state.remove()
-        (key, "-1")
-      } else {
-        state.update(RunningCount(count))
-        (key, count.toString)
-      }
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
-        .groupByKey(_._1)
-        .count()
-
-    testStream(result, Complete)(
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", 1)),
-      AddData(inputData, "a", "b"),
-      // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
-      CheckLastBatch(("a", 2), ("b", 1)),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "b"),
-      // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
-      // so increment a and b by 1
-      CheckLastBatch(("a", 3), ("b", 2)),
-      StopStream,
-      StartStream(),
-      AddData(inputData, "a", "c"),
-      // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
-      // so increment a and c by 1
-      CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
-    )
-  }
-
-  test("mapGroupsWithState - batch") {
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
-      (key, values.size)
-    }
-
-    checkAnswer(
-      spark.createDataset(Seq("a", "a", "b"))
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc)
-        .toDF,
-      spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
-  }
-
-  testQuietly("StateStore.abort on task failure handling") {
-    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
-      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
-      state.update(RunningCount(count))
-      (key, count)
-    }
-
-    val inputData = MemoryStream[String]
-    val result =
-      inputData.toDS()
-        .groupByKey(x => x)
-        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Long)
-
-    def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
-      MapGroupsWithStateSuite.failInTask = value
-      true
-    }
-
-    testStream(result, Append)(
-      setFailInTask(false),
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", 1L)),
-      AddData(inputData, "a"),
-      CheckLastBatch(("a", 2L)),
-      setFailInTask(true),
-      AddData(inputData, "a"),
-      ExpectFailure[SparkException](),   // task should fail but should not increment count
-      setFailInTask(false),
-      StartStream(),
-      CheckLastBatch(("a", 3L))     // task should not fail, and should show correct count
-    )
-  }
-
-  private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
-    val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
-    assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
-    assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updated rows")
-    true
-  }
-}
-
-object MapGroupsWithStateSuite {
-  var failInTask = true
-}
-- 
GitLab
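
Note for reviewers of this revert: the suite deleted above exercised the KeyedState / (flat)mapGroupsWithState API that this patch removes from branch-2.1. The minimal Scala sketch below is distilled from those deleted tests and illustrates the usage pattern that no longer compiles on this branch once the revert lands. KeyedState, mapGroupsWithState, and the RunningCount state class all come from the reverted code; the standalone wrapper object and its main method are illustrative only.

import org.apache.spark.sql.{KeyedState, SparkSession}

// Per-key state type, as defined in the deleted suite.
case class RunningCount(count: Long)

object RunningCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("running-count-sketch")
      .getOrCreate()
    import spark.implicits._

    // Same logic as the deleted tests: keep a running count per key up to 2,
    // then remove the state and emit -1 once the third occurrence arrives.
    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
      if (count == 3) {
        state.remove()
        (key, "-1")
      } else {
        state.update(RunningCount(count))
        (key, count.toString)
      }
    }

    // In a batch query the state starts (and stays) empty for every key, so this
    // simply yields ("a", "2") and ("b", "1"); in a streaming query the state is
    // carried across triggers, as the deleted streaming tests verified.
    Seq("a", "a", "b").toDS()
      .groupByKey(x => x)
      .mapGroupsWithState(stateFunc)
      .show()

    spark.stop()
  }
}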