From 502c927b8c8a99ef2adf4e6e1d7a6d9232d45ef5 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Wed, 8 Feb 2017 11:33:59 -0800
Subject: [PATCH] [SPARK-19413][SS] MapGroupsWithState for arbitrary stateful
 operations for branch-2.1

This is a follow-up PR for merging #16758 into the Spark 2.1 branch.

## What changes were proposed in this pull request?

`mapGroupsWithState` is a new API for arbitrary stateful operations in Structured Streaming, similar to `DStream.mapWithState`

*Requirements*
- Users should be able to specify a function that can do the following:
  - Access the input rows corresponding to a key
  - Access the previous state corresponding to a key
  - Optionally, update or remove the state
  - Output any number of new rows (or none at all)

*Proposed API*
```
// ------------ New methods on KeyValueGroupedDataset ------------
class KeyValueGroupedDataset[K, V] {
  // Scala friendly
  def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => U)
  def flatMapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], KeyedState[S]) => Iterator[U])
  // Java friendly
  def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
  def flatMapGroupsWithState[S, U](func: FlatMapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], resultEncoder: Encoder[U])
}

// ------------------- New Java-friendly function classes -------------------
public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}
public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
}

// ---------------------- Wrapper class for state data ----------------------
trait KeyedState[S] {
  def exists(): Boolean
  def get(): S                // throws NoSuchElementException if state does not exist
  def getOption(): Option[S]
  def update(newState: S): Unit
  def remove(): Unit          // exists() will be false after this
}
```

Key Semantics of the State class
- The value of the state cannot be null; updating with null throws IllegalArgumentException.
- If state.remove() is called, then state.exists() will return false, and getOption will return None.
- If state.update(newState) is subsequently called, then state.exists() will return true, and getOption will return Some(...).
- None of the operations are thread-safe. This is to avoid memory barriers.
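
These semantics are illustrated by the following sketch (illustrative only, not part of this patch; the `state` handle is supplied by the engine when it invokes the user function, and `illustrateSemantics` is a hypothetical helper):
```
import org.apache.spark.sql.KeyedState

def illustrateSemantics(state: KeyedState[Long]): Unit = {
  val previous = state.getOption.getOrElse(0L)   // None (hence 0L) if there is no prior state
  state.update(previous + 1)                     // exists() is now true
  assert(state.exists && state.get == previous + 1)
  state.remove()                                 // exists() is now false, getOption returns None
  assert(!state.exists && state.getOption.isEmpty)
  state.update(previous + 2)                     // state can be set again after removal
  assert(state.getOption.contains(previous + 2))
}
```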

*Usage*
```
val stateFunc = (word: String, words: Iterator[String], runningCount: KeyedState[Long]) => {
  val newCount = words.size + runningCount.getOption.getOrElse(0L)
  runningCount.update(newCount)
  (word, newCount)
}

dataset                                                  // type is Dataset[String]
  .groupByKey[String](w => w)                            // generates KeyValueGroupedDataset[String, String]
  .mapGroupsWithState[Long, (String, Long)](stateFunc)   // returns Dataset[(String, Long)]
```
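
For a streaming Dataset, the same kind of function can be plugged into a streaming query. A rough sketch (not taken from this patch; the socket source, console sink, and an active SparkSession `spark` with `spark.implicits._` imported are assumptions made for illustration):
```
import org.apache.spark.sql.KeyedState

val flatStateFunc = (word: String, words: Iterator[String], runningCount: KeyedState[Long]) => {
  val newCount = words.size + runningCount.getOption.getOrElse(0L)
  runningCount.update(newCount)
  Iterator((word, newCount))
}

val query = spark.readStream
  .format("socket").option("host", "localhost").option("port", "9999")
  .load()
  .as[String]                              // streaming Dataset[String]
  .groupByKey(w => w)
  .flatMapGroupsWithState(flatStateFunc)   // Dataset[(String, Long)]
  .writeStream
  .outputMode("append")
  .format("console")
  .start()
```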

## How was this patch tested?
New unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #16850 from tdas/mapWithState-branch-2.1.
---
 .../UnsupportedOperationChecker.scala         |  11 +-
 .../sql/catalyst/plans/logical/object.scala   |  49 +++
 .../analysis/UnsupportedOperationsSuite.scala |  24 +-
 .../FlatMapGroupsWithStateFunction.java       |  38 ++
 .../function/MapGroupsWithStateFunction.java  |  38 ++
 .../spark/sql/KeyValueGroupedDataset.scala    | 113 ++++++
 .../org/apache/spark/sql/KeyedState.scala     | 142 ++++++++
 .../spark/sql/execution/SparkStrategies.scala |  21 +-
 .../apache/spark/sql/execution/objects.scala  |  22 ++
 .../streaming/IncrementalExecution.scala      |  19 +-
 .../execution/streaming/KeyedStateImpl.scala  |  80 +++++
 .../streaming/ProgressReporter.scala          |   2 +-
 .../state/HDFSBackedStateStoreProvider.scala  |  19 +
 .../streaming/state/StateStore.scala          |   5 +
 .../execution/streaming/state/package.scala   |  11 +-
 ...ggregate.scala => statefulOperators.scala} | 134 +++++--
 .../apache/spark/sql/JavaDatasetSuite.java    |  32 ++
 .../streaming/MapGroupsWithStateSuite.scala   | 335 ++++++++++++++++++
 18 files changed, 1059 insertions(+), 36 deletions(-)
 create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
 create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
 rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{StatefulAggregate.scala => statefulOperators.scala} (63%)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index f4d016cb96..d8aad42edc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -46,8 +46,13 @@ object UnsupportedOperationChecker {
         "Queries without streaming sources cannot be executed with writeStream.start()")(plan)
     }
 
+    /** Collect all the streaming aggregates in a sub plan */
+    def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = {
+      subplan.collect { case a: Aggregate if a.isStreaming => a }
+    }
+
     // Disallow multiple streaming aggregations
-    val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a }
+    val aggregates = collectStreamingAggregates(plan)
 
     if (aggregates.size > 1) {
       throwError(
@@ -114,6 +119,10 @@ object UnsupportedOperationChecker {
         case _: InsertIntoTable =>
           throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets")
 
+        case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty =>
+          throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " +
+            "streaming DataFrame/Dataset")
+
         case Join(left, right, joinType, _) =>
 
           joinType match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0ab4c90166..0be4823bbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -313,6 +313,55 @@ case class MapGroups(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectProducer
 
+/** Internal class representing State */
+trait LogicalKeyedState[S]
+
+/** Factory for constructing new `MapGroupsWithState` nodes. */
+object MapGroupsWithState {
+  def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder](
+      func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      child: LogicalPlan): LogicalPlan = {
+    val mapped = new MapGroupsWithState(
+      func,
+      UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+      UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes),
+      groupingAttributes,
+      dataAttributes,
+      CatalystSerde.generateObjAttr[U],
+      encoderFor[S].resolveAndBind().deserializer,
+      encoderFor[S].namedExpressions,
+      child)
+    CatalystSerde.serialize[U](mapped)
+  }
+}
+
+/**
+ * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`,
+ * while using state data.
+ * Func is invoked with an object representation of the grouping key and an iterator containing the
+ * object representation of all the rows with that key.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
+ * @param groupingAttributes used to group the data
+ * @param dataAttributes used to read the data
+ * @param outputObjAttr used to define the output object
+ * @param stateDeserializer used to deserialize state before calling `func`
+ * @param stateSerializer used to serialize updated state after calling `func`
+ */
+case class MapGroupsWithState(
+    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    stateDeserializer: Expression,
+    stateSerializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode with ObjectProducer
+
 /** Factory for constructing new `FlatMapGroupsInR` nodes. */
 object FlatMapGroupsInR {
   def apply(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index dcdb1ae089..3b756e89d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType}
 
 /** A dummy command for testing unsupported operations. */
 case class DummyCommand() extends Command
@@ -111,6 +111,24 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
     outputMode = Complete,
     expectedMsgs = Seq("distinct aggregation"))
 
+  // MapGroupsWithState: Not supported after a streaming aggregation
+  val att = new AttributeReference(name = "a", dataType = LongType)()
+  assertSupportedInBatchPlan(
+    "mapGroupsWithState - mapGroupsWithState on batch relation",
+    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation))
+
+  assertSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation",
+    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation),
+    outputMode = Append)
+
+  assertNotSupportedInStreamingPlan(
+    "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation",
+    MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att),
+      Aggregate(Nil, aggExprs("c"), streamRelation)),
+    outputMode = Complete,
+    expectedMsgs = Seq("(map/flatMap)GroupsWithState"))
+
   // Inner joins: Stream-stream not supported
   testBinaryOperationInStreamingPlan(
     "inner join",
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
new file mode 100644
index 0000000000..2570c8d02a
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}.
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface FlatMapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+  Iterator<R> call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}
diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
new file mode 100644
index 0000000000..614d3925e0
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.KeyedState;
+
+/**
+ * ::Experimental::
+ * Base interface for a map function used in
+ * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)}
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+public interface MapGroupsWithStateFunction<K, V, S, R> extends Serializable {
+  R call(K key, Iterator<V> values, KeyedState<S> state) throws Exception;
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 395d709f26..94e689a4d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -218,6 +218,119 @@ class KeyValueGroupedDataset[K, V] private[sql](
     mapGroups((key, data) => f.call(key, data.asJava))(encoder)
   }
 
+  /**
+   * ::Experimental::
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def mapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = {
+    flatMapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s)))
+  }
+
+  /**
+   * ::Experimental::
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def mapGroupsWithState[S, U](
+      func: MapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    flatMapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s))
+    )(stateEncoder, outputEncoder)
+  }
+
+  /**
+   * ::Experimental::
+   * (Scala-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S: Encoder, U: Encoder](
+      func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = {
+    Dataset[U](
+      sparkSession,
+      MapGroupsWithState[K, V, S, U](
+        func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]],
+        groupingAttributes,
+        dataAttributes,
+        logicalPlan))
+  }
+
+  /**
+   * ::Experimental::
+   * (Java-specific)
+   * Applies the given function to each group of data, while maintaining a user-defined per-group
+   * state. The result Dataset will represent the objects returned by the function.
+   * For a static batch Dataset, the function will be invoked once per group. For a streaming
+   * Dataset, the function will be invoked for each group repeatedly in every trigger, and
+   * updates to each group's state will be saved across invocations.
+   * See [[KeyedState]] for more details.
+   *
+   * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types.
+   * @tparam U The type of the output objects. Must be encodable to Spark SQL types.
+   * @param func          Function to be called on every group.
+   * @param stateEncoder  Encoder for the state type.
+   * @param outputEncoder Encoder for the output type.
+   *
+   * See [[Encoder]] for more details on what types are encodable to Spark SQL.
+   * @since 2.1.1
+   */
+  @Experimental
+  @InterfaceStability.Evolving
+  def flatMapGroupsWithState[S, U](
+      func: FlatMapGroupsWithStateFunction[K, V, S, U],
+      stateEncoder: Encoder[S],
+      outputEncoder: Encoder[U]): Dataset[U] = {
+    flatMapGroupsWithState[S, U](
+      (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala
+    )(stateEncoder, outputEncoder)
+  }
+
   /**
    * (Scala-specific)
    * Reduces the elements of each group of data using the specified binary function.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
new file mode 100644
index 0000000000..6864b6f6b4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.IllegalArgumentException
+
+import org.apache.spark.annotation.{Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+
+/**
+ * :: Experimental ::
+ *
+ * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and
+ * `flatMapGroupsWithState` operations on
+ * [[KeyValueGroupedDataset]].
+ *
+ * Detail description on `[map/flatMap]GroupsWithState` operation
+ * ------------------------------------------------------------
+ * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]]
+ * will invoke the user-given function on each group (defined by the grouping function in
+ * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations.
+ * For a static batch Dataset, the function will be invoked once per group. For a streaming
+ * Dataset, the function will be invoked for each group repeatedly in every trigger.
+ * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]],
+ * the function will be invoked once for each group that has data in the batch.
+ *
+ * The function is invoked with following parameters.
+ *  - The key of the group.
+ *  - An iterator containing all the values for this key.
+ *  - A user-defined state object set by previous invocations of the given function.
+ * In case of a batch Dataset, there is only one invocation and state object will be empty as
+ * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState`
+ * is equivalent to `[map/flatMap]Groups`.
+ *
+ * Important points to note about the function.
+ *  - In a trigger, the function will be called only for the groups present in the batch. So do not
+ *    assume that the function will be called in every trigger for every group that has state.
+ *  - There is no guaranteed ordering of values in the iterator in the function, neither with
+ *    batch, nor with streaming Datasets.
+ *  - All the data will be shuffled before applying the function.
+ *
+ * Important points to note about using KeyedState.
+ *  - The value of the state cannot be null. So updating state with null will throw
+ *    `IllegalArgumentException`.
+ *  - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers.
+ *  - If `remove()` is called, then `exists()` will return `false`,
+ *    `get()` will throw `NoSuchElementException` and `getOption()` will return `None`
+ *  - After that, if `update(newState)` is called, then `exists()` will again return `true`,
+ *    `get()` and `getOption()` will return the updated value.
+ *
+ * Scala example of using KeyedState in `mapGroupsWithState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = {
+ *   // Check if state exists
+ *   if (state.exists) {
+ *     val existingState = state.get  // Get the existing state
+ *     val shouldRemove = ...         // Decide whether to remove the state
+ *     if (shouldRemove) {
+ *       state.remove()     // Remove the state
+ *     } else {
+ *       val newState = ...
+ *       state.update(newState)    // Set the new state
+ *     }
+ *   } else {
+ *     val initialState = ...
+ *     state.update(initialState)  // Set the initial state
+ *   }
+ *   ... // return something
+ * }
+ *
+ * }}}
+ *
+ * Java example of using `KeyedState`:
+ * {{{
+ * /* A mapping function that maintains an integer state for string keys and returns a string. */
+ * MapGroupsWithStateFunction<String, Integer, Integer, String> mappingFunction =
+ *    new MapGroupsWithStateFunction<String, Integer, Integer, String>() {
+ *
+ *      @Override
+ *      public String call(String key, Iterator<Integer> value, KeyedState<Integer> state) {
+ *        if (state.exists()) {
+ *          int existingState = state.get(); // Get the existing state
+ *          boolean shouldRemove = ...; // Decide whether to remove the state
+ *          if (shouldRemove) {
+ *            state.remove(); // Remove the state
+ *          } else {
+ *            int newState = ...;
+ *            state.update(newState); // Set the new state
+ *          }
+ *        } else {
+ *          int initialState = ...; // Set the initial state
+ *          state.update(initialState);
+ *        }
+ *        ... // return something
+ *      }
+ *    };
+ * }}}
+ *
+ * @tparam S User-defined type of the state to be stored for each key. Must be encodable into
+ *           Spark SQL types (see [[Encoder]] for more details).
+ * @since 2.1.1
+ */
+@Experimental
+@InterfaceStability.Evolving
+trait KeyedState[S] extends LogicalKeyedState[S] {
+
+  /** Whether state exists or not. */
+  def exists: Boolean
+
+  /** Get the state value if it exists, or throw NoSuchElementException. */
+  @throws[NoSuchElementException]("when state does not exist")
+  def get: S
+
+  /** Get the state value as a scala Option. */
+  def getOption: Option[S]
+
+  /**
+   * Update the value of the state. Note that `null` is not a valid value, and it throws
+   * IllegalArgumentException.
+   */
+  @throws[IllegalArgumentException]("when updating with null")
+  def update(newState: S): Unit
+
+  /** Remove this keyed state. */
+  def remove(): Unit
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ba82ec156e..adea358594 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -324,6 +324,23 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
+  /**
+   * Strategy to convert MapGroupsWithState logical operator to physical operator
+   * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
+   */
+  object MapGroupsWithStateStrategy extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case MapGroupsWithState(
+          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) =>
+        val execPlan = MapGroupsWithStateExec(
+          f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer,
+          planLater(child))
+        execPlan :: Nil
+      case _ =>
+        Nil
+    }
+  }
+
   // Can we automate these 'pass through' operations?
   object BasicOperators extends Strategy {
     def numPartitions: Int = self.numPartitions
@@ -365,6 +382,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil
       case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
         execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+      case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) =>
+        execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil
       case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
         execution.CoGroupExec(
           f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index fde3b2a528..199ba5ce69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -30,6 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState
+import org.apache.spark.sql.execution.streaming.KeyedStateImpl
 import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
 
 
@@ -144,6 +146,11 @@ object ObjectOperator {
     (i: InternalRow) => proj(i).get(0, deserializer.dataType)
   }
 
+  def deserializeRowToObject(deserializer: Expression): InternalRow => Any = {
+    val proj = GenerateSafeProjection.generate(deserializer :: Nil)
+    (i: InternalRow) => proj(i).get(0, deserializer.dataType)
+  }
+
   def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
     val proj = GenerateUnsafeProjection.generate(serializer)
     val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
@@ -344,6 +351,21 @@ case class MapGroupsExec(
   }
 }
 
+object MapGroupsExec {
+  def apply(
+      func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any],
+      keyDeserializer: Expression,
+      valueDeserializer: Expression,
+      groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
+      outputObjAttr: Attribute,
+      child: SparkPlan): MapGroupsExec = {
+    val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None))
+    new MapGroupsExec(f, keyDeserializer, valueDeserializer,
+      groupingAttributes, dataAttributes, outputObjAttr, child)
+  }
+}
+
 /**
  * Groups the input rows together and calls the R function with each group and an iterator
  * containing all elements in the group.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 6ab6fa61dc..5c4cbfa755 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.concurrent.atomic.AtomicInteger
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal}
 import org.apache.spark.sql.SparkSession
@@ -39,8 +41,9 @@ class IncrementalExecution(
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
 
   // TODO: make this always part of planning.
-  val stateStrategy =
+  val streamingExtraStrategies =
     sparkSession.sessionState.planner.StatefulAggregationStrategy +:
+    sparkSession.sessionState.planner.MapGroupsWithStateStrategy +:
     sparkSession.sessionState.planner.StreamingRelationStrategy +:
     sparkSession.sessionState.experimentalMethods.extraStrategies
 
@@ -49,7 +52,7 @@ class IncrementalExecution(
     new SparkPlanner(
       sparkSession.sparkContext,
       sparkSession.sessionState.conf,
-      stateStrategy)
+      streamingExtraStrategies)
 
   /**
    * See [SPARK-18339]
@@ -68,7 +71,7 @@ class IncrementalExecution(
    * Records the current id for a given stateful operator in the query plan as the `state`
    * preparation walks the query plan.
    */
-  private var operatorId = 0
+  private val operatorId = new AtomicInteger(0)
 
   /** Locates save/restore pairs surrounding aggregation. */
   val state = new Rule[SparkPlan] {
@@ -77,8 +80,8 @@ class IncrementalExecution(
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
                StateStoreRestoreExec(keys2, None, child))) =>
-        val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId)
-        operatorId += 1
+        val stateId =
+          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
 
         StateStoreSaveExec(
           keys,
@@ -90,6 +93,12 @@ class IncrementalExecution(
               keys,
               Some(stateId),
               child) :: Nil))
+      case MapGroupsWithStateExec(
+             f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) =>
+        val stateId =
+          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
+        MapGroupsWithStateExec(
+          f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
new file mode 100644
index 0000000000..eee7ec45dd
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.sql.KeyedState
+
+/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */
+private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] {
+  private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
+  private var defined: Boolean = optionalValue.isDefined
+  private var updated: Boolean = false
+  // whether value has been updated (but not removed)
+  private var removed: Boolean = false // whether value has been removed
+
+  // ========= Public API =========
+  override def exists: Boolean = defined
+
+  override def get: S = {
+    if (defined) {
+      value
+    } else {
+      throw new NoSuchElementException("State is either not defined or has already been removed")
+    }
+  }
+
+  override def getOption: Option[S] = {
+    if (defined) {
+      Some(value)
+    } else {
+      None
+    }
+  }
+
+  override def update(newValue: S): Unit = {
+    if (newValue == null) {
+      throw new IllegalArgumentException("'null' is not a valid state value")
+    }
+    value = newValue
+    defined = true
+    updated = true
+    removed = false
+  }
+
+  override def remove(): Unit = {
+    defined = false
+    updated = false
+    removed = true
+  }
+
+  override def toString: String = {
+    s"KeyedState(${getOption.map(_.toString).getOrElse("<undefined>")})"
+  }
+
+  // ========= Internal API =========
+
+  /** Whether the state has been marked for removal */
+  def isRemoved: Boolean = {
+    removed
+  }
+
+  /** Whether the state has been updated */
+  def isUpdated: Boolean = {
+    updated
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 1f74fffbe6..693933f95a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -186,7 +186,7 @@ trait ProgressReporter extends Logging {
     // lastExecution could belong to one of the previous triggers if `!hasNewData`.
     // Walking the plan again should be inexpensive.
     val stateNodes = lastExecution.executedPlan.collect {
-      case p if p.isInstanceOf[StateStoreSaveExec] => p
+      case p if p.isInstanceOf[StateStoreWriter] => p
     }
     stateNodes.map { node =>
       val numRowsUpdated = if (hasNewData) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 1279b71c4d..61eb601a18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -147,6 +147,25 @@ private[state] class HDFSBackedStateStoreProvider(
       }
     }
 
+    /** Remove a single key. */
+    override def remove(key: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or aborted")
+      if (mapToUpdate.containsKey(key)) {
+        val value = mapToUpdate.remove(key)
+        Option(allUpdates.get(key)) match {
+          case Some(ValueUpdated(_, _)) | None =>
+            // Value existed in previous version and maybe was updated, mark removed
+            allUpdates.put(key, ValueRemoved(key, value))
+          case Some(ValueAdded(_, _)) =>
+            // Value did not exist in previous version and was added, should not appear in updates
+            allUpdates.remove(key)
+          case Some(ValueRemoved(_, _)) =>
+          // Remove already in update map, no need to change
+        }
+        writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
+      }
+    }
+
     /** Commit all the updates that have been made to the store, and return the new version. */
     override def commit(): Long = {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index e61d95a1b1..dcb24b26f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -58,6 +58,11 @@ trait StateStore {
    */
   def remove(condition: UnsafeRow => Boolean): Unit
 
+  /**
+   * Remove a single key.
+   */
+  def remove(key: UnsafeRow): Unit
+
   /**
    * Commit all the updates that have been made to the store, and return the new version.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 1b56c08f72..589042afb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.internal.SessionState
@@ -59,10 +60,18 @@ package object state {
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
+
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
+      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
+        // Abort the state store in case of error
+        TaskContext.get().addTaskCompletionListener(_ => {
+          if (!store.hasCommitted) store.abort()
+        })
+        cleanedF(store, iter)
+      }
       new StateStoreRDD(
         dataRDD,
-        cleanedF,
+        wrappedF,
         checkpointLocation,
         operatorId,
         storeVersion,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
similarity index 63%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index d4ccced9ac..1292452574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -22,16 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState}
+import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.TaskContext
+import org.apache.spark.util.CompletionIterator
 
 
 /** Used to identify the state store for a given operator. */
@@ -41,7 +41,7 @@ case class OperatorStateId(
     batchId: Long)
 
 /**
- * An operator that saves or restores state from the [[StateStore]].  The [[OperatorStateId]] should
+ * An operator that reads or writes state from the [[StateStore]].  The [[OperatorStateId]] should
  * be filled in by `prepareForExecution` in [[IncrementalExecution]].
  */
 trait StatefulOperator extends SparkPlan {
@@ -54,6 +54,20 @@ trait StatefulOperator extends SparkPlan {
   }
 }
 
+/** An operator that reads from a StateStore. */
+trait StateStoreReader extends StatefulOperator {
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+}
+
+/** An operator that writes to a StateStore. */
+trait StateStoreWriter extends StatefulOperator {
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
+    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+}
+
 /**
  * For each input tuple, the key is calculated and the value from the [[StateStore]] is added
  * to the stream (in addition to the input tuple) if present.
@@ -62,10 +76,7 @@ case class StateStoreRestoreExec(
     keyExpressions: Seq[Attribute],
     stateId: Option[OperatorStateId],
     child: SparkPlan)
-  extends execution.UnaryExecNode with StatefulOperator {
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  extends execution.UnaryExecNode with StateStoreReader {
 
   override protected def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
@@ -102,12 +113,7 @@ case class StateStoreSaveExec(
     outputMode: Option[OutputMode] = None,
     eventTimeWatermark: Option[Long] = None,
     child: SparkPlan)
-  extends execution.UnaryExecNode with StatefulOperator {
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
-    "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+  extends execution.UnaryExecNode with StateStoreWriter {
 
   /** Generate a predicate that matches data older than the watermark */
   private lazy val watermarkPredicate: Option[Predicate] = {
@@ -151,13 +157,6 @@ case class StateStoreSaveExec(
         val numTotalStateRows = longMetric("numTotalStateRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
 
-        // Abort the state store in case of error
-        TaskContext.get().addTaskCompletionListener(_ => {
-          if (!store.hasCommitted) {
-            store.abort()
-          }
-        })
-
         outputMode match {
           // Update and output all rows in the StateStore.
           case Some(Complete) =>
@@ -184,7 +183,7 @@ case class StateStoreSaveExec(
             }
 
             // Assumption: Append mode can be done only when watermark has been specified
-            store.remove(watermarkPredicate.get.eval)
+            store.remove(watermarkPredicate.get.eval _)
             store.commit()
 
             numTotalStateRows += store.numKeys()
@@ -207,7 +206,7 @@ case class StateStoreSaveExec(
               override def hasNext: Boolean = {
                 if (!baseIterator.hasNext) {
                   // Remove old aggregates if watermark specified
-                  if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval)
+                  if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _)
                   store.commit()
                   numTotalStateRows += store.numKeys()
                   false
@@ -235,3 +234,90 @@ case class StateStoreSaveExec(
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 }
+
+
+/** Physical operator for executing streaming mapGroupsWithState. */
+case class MapGroupsWithStateExec(
+    func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any],
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
+    groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
+    outputObjAttr: Attribute,
+    stateId: Option[OperatorStateId],
+    stateDeserializer: Expression,
+    stateSerializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter {
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  /** Distribute by grouping attributes */
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingAttributes) :: Nil
+
+  /** Ordering needed for using GroupingIterator */
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsWithStateStore[InternalRow](
+      getStateId.checkpointLocation,
+      operatorId = getStateId.operatorId,
+      storeVersion = getStateId.batchId,
+      groupingAttributes.toStructType,
+      child.output.toStructType,
+      sqlContext.sessionState,
+      Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
+        val numTotalStateRows = longMetric("numTotalStateRows")
+        val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+        val numOutputRows = longMetric("numOutputRows")
+
+        // Generate an iterator that returns the rows grouped by the grouping function
+        val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
+
+        // Converters to and from object and rows
+        val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+        val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+        val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
+        val getStateObj =
+          ObjectOperator.deserializeRowToObject(stateDeserializer)
+        val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer)
+
+        // For every group, get the key, values and corresponding state and call the function,
+        // and return an iterator of rows
+        val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) =>
+
+          val key = keyRow.asInstanceOf[UnsafeRow]
+          val keyObj = getKeyObj(keyRow)                         // convert key to objects
+          val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
+          val stateObjOption = store.get(key).map(getStateObj)   // get existing state if any
+          val wrappedState = new KeyedStateImpl(stateObjOption)
+          val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj =>
+            numOutputRows += 1
+            getOutputRow(obj) // convert back to rows
+          }
+
+          // Return an iterator of rows generated for this key, such that when it is
+          // fully consumed, the updated state value will be saved
+          CompletionIterator[InternalRow, Iterator[InternalRow]](
+            mappedIterator, {
+              // When the iterator is consumed, then write changes to state
+              if (wrappedState.isRemoved) {
+                store.remove(key)
+                numUpdatedStateRows += 1
+              } else if (wrappedState.isUpdated) {
+                store.put(key, outputStateObj(wrappedState.get))
+                numUpdatedStateRows += 1
+              }
+            })
+        }
+
+        // Return an iterator of all the rows generated by all the keys, such that when fully
+        // consumed, all the state updates will be committed by the state store
+        CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, {
+          store.commit()
+          numTotalStateRows += store.numKeys()
+        })
+      }
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8304b728aa..5ef4e887de 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -225,6 +225,38 @@ public class JavaDatasetSuite implements Serializable {
 
     Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
 
+    Dataset<String> mapped2 = grouped.mapGroupsWithState(
+      new MapGroupsWithStateFunction<Integer, String, Long, String>() {
+        @Override
+        public String call(Integer key, Iterator<String> values, KeyedState<Long> s) throws Exception {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return sb.toString();
+        }
+        },
+        Encoders.LONG(),
+        Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList()));
+
+    Dataset<String> flatMapped2 = grouped.flatMapGroupsWithState(
+      new FlatMapGroupsWithStateFunction<Integer, String, Long, String>() {
+        @Override
+        public Iterator<String> call(Integer key, Iterator<String> values, KeyedState<Long> s) {
+          StringBuilder sb = new StringBuilder(key.toString());
+          while (values.hasNext()) {
+            sb.append(values.next());
+          }
+          return Collections.singletonList(sb.toString()).iterator();
+        }
+      },
+      Encoders.LONG(),
+      Encoders.STRING());
+
+    Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
+
     Dataset<Tuple2<Integer, String>> reduced = grouped.reduceGroups(new ReduceFunction<String>() {
       @Override
       public String call(String v1, String v2) throws Exception {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
new file mode 100644
index 0000000000..0524898b15
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.KeyedState
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.StateStore
+
+/** Class to check custom state types */
+case class RunningCount(count: Long)
+
+class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
+
+  import testImplicits._
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop()
+  }
+
+  test("KeyedState - get, exists, update, remove") {
+    var state: KeyedStateImpl[String] = null
+
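+    // Helper to check the state's visible contents and its updated/removed flags.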
+    def testState(
+        expectedData: Option[String],
+        shouldBeUpdated: Boolean = false,
+        shouldBeRemoved: Boolean = false): Unit = {
+      if (expectedData.isDefined) {
+        assert(state.exists)
+        assert(state.get === expectedData.get)
+      } else {
+        assert(!state.exists)
+        intercept[NoSuchElementException] {
+          state.get
+        }
+      }
+      assert(state.getOption === expectedData)
+      assert(state.isUpdated === shouldBeUpdated)
+      assert(state.isRemoved === shouldBeRemoved)
+    }
+
+    // Updating empty state
+    state = new KeyedStateImpl[String](None)
+    testState(None)
+    state.update("")
+    testState(Some(""), shouldBeUpdated = true)
+
+    // Updating existing state
+    state = new KeyedStateImpl[String](Some("2"))
+    testState(Some("2"))
+    state.update("3")
+    testState(Some("3"), shouldBeUpdated = true)
+
+    // Removing state
+    state.remove()
+    testState(None, shouldBeRemoved = true, shouldBeUpdated = false)
+    state.remove()      // should be still callable
+    state.update("4")
+    testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true)
+
+    // Updating with null should throw an exception
+    intercept[IllegalArgumentException] {
+      state.update(null)
+    }
+  }
+
+  test("KeyedState - primitive type") {
+    var intState = new KeyedStateImpl[Int](None)
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+    assert(intState.getOption === None)
+
+    intState = new KeyedStateImpl[Int](Some(10))
+    assert(intState.get == 10)
+    intState.update(0)
+    assert(intState.get == 0)
+    intState.remove()
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+  }
+
+  test("flatMapGroupsWithState - streaming") {
+    // Function to maintain the running count up to 2, and then remove the count
+    // Returns the key and the count while state is kept, otherwise does not return anything
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
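+      // Add this batch's values for the key to the running count carried in the state.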
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        Iterator.empty
+      } else {
+        state.update(RunningCount(count))
+        Iterator((key, count.toString))
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+      CheckLastBatch(("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+    // Function to maintain the running count up to 2, and then remove the count
+    // Returns the key and the count while state is kept, otherwise does not return anything
+    // Additionally, it updates state lazily as the returned iterator gets consumed
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      values.flatMap { _ =>
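+        // State is read and written once per element, so updates happen only as the returned
+        // iterator is drained.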
+        val count = state.getOption.map(_.count).getOrElse(0L) + 1
+        if (count == 3) {
+          state.remove()
+          None
+        } else {
+          state.update(RunningCount(count))
+          Some((key, count.toString))
+        }
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a", "a", "b"),
+      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a
+      CheckLastBatch(("b", "2")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and
+      CheckLastBatch(("a", "1"), ("c", "1"))
+    )
+  }
+
+  test("flatMapGroupsWithState - batch") {
+    // Function that verifies state is never available in a batch query and returns the per-key count of values
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      Iterator((key, values.size))
+    }
+    checkAnswer(
+      Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF,
+      Seq(("a", 2), ("b", 1)).toDF)
+  }
+
+  test("mapGroupsWithState - streaming") {
+    // Function to maintain the running count up to 2, and then remove the count
+    // Returns the key and the count (-1 if the count went beyond 2 and the state was just removed)
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", "1")),
+      assertNumStateRows(total = 1, updated = 1),
+      AddData(inputData, "a", "b"),
+      CheckLastBatch(("a", "2"), ("b", "1")),
+      assertNumStateRows(total = 2, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+      CheckLastBatch(("a", "-1"), ("b", "2")),
+      assertNumStateRows(total = 1, updated = 2),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1")),
+      assertNumStateRows(total = 3, updated = 2)
+    )
+  }
+
+  test("mapGroupsWithState - streaming + aggregation") {
+    // Function to maintain the running count up to 2, and then remove the count
+    // Returns the key and the count (-1 if the count went beyond 2 and the state was just removed)
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      if (count == 3) {
+        state.remove()
+        (key, "-1")
+      } else {
+        state.update(RunningCount(count))
+        (key, count.toString)
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Str)
+        .groupByKey(_._1)
+        .count()
+
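+    // The stateful map feeds a streaming aggregation, so the query runs in Complete output mode.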
+    testStream(result, Complete)(
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1)),
+      AddData(inputData, "a", "b"),
+      // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+      CheckLastBatch(("a", 2), ("b", 1)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"),
+      // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
+      // so increment a and b by 1
+      CheckLastBatch(("a", 3), ("b", 2)),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"),
+      // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
+      // so increment a and c by 1
+      CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+    )
+  }
+
+  test("mapGroupsWithState - batch") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (state.exists) throw new IllegalArgumentException("state.exists should be false")
+      (key, values.size)
+    }
+
+    checkAnswer(
+      spark.createDataset(Seq("a", "a", "b"))
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc)
+        .toDF,
+      spark.createDataset(Seq(("a", 2), ("b", 1))).toDF)
+  }
+
+  testQuietly("StateStore.abort on task failure handling") {
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      state.update(RunningCount(count))
+      (key, count)
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .mapGroupsWithState(stateFunc) // Types = State: RunningCount, Out: (Str, Long)
+
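+    // Toggle the flag that makes the state function throw, simulating a task failure.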
+    def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q =>
+      MapGroupsWithStateSuite.failInTask = value
+      true
+    }
+
+    testStream(result, Append)(
+      setFailInTask(false),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 1L)),
+      AddData(inputData, "a"),
+      CheckLastBatch(("a", 2L)),
+      setFailInTask(true),
+      AddData(inputData, "a"),
+      ExpectFailure[SparkException](),   // task should fail but should not increment count
+      setFailInTask(false),
+      StartStream(),
+      CheckLastBatch(("a", 3L))     // task should not fail, and should show correct count
+    )
+  }
+
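+  /** Asserts the total and updated state row counts reported by the last progress that had input data. */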
+  private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q =>
+    val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get
+    assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows")
+    assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updated rows")
+    true
+  }
+}
+
+object MapGroupsWithStateSuite {
+  var failInTask = true
+}
-- 
GitLab