From fa757ee1d41396ad8734a3f2dd045bb09bc82a2e Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Tue, 30 May 2017 15:33:06 -0700
Subject: [PATCH] [SPARK-20883][SPARK-20376][SS] Refactored StateStore APIs and
 added conf to choose implementation

## What changes were proposed in this pull request?

A bunch of changes to the StateStore APIs and implementation.
Current state store API has a bunch of problems that causes too many transient objects causing memory pressure.

- `StateStore.get(): Option` forces creation of Some/None objects for every get. Changed this to return the row or null.
- `StateStore.iterator(): (UnsafeRow, UnsafeRow)` forces creation of new tuple for each record returned. Changed this to return a UnsafeRowTuple which can be reused across records.
- `StateStore.updates()` requires the implementation to keep track of updates, while this is used minimally (only by Append mode in streaming aggregations). Removed updates() and updated StateStoreSaveExec accordingly.
- `StateStore.filter(condition)` and `StateStore.remove(condition)` has been merge into a single API `getRange(start, end)` which allows a state store to do optimized range queries (i.e. avoid full scans). Stateful operators have been updated accordingly.
- Removed a lot of unnecessary row copies Each operator copied rows before calling StateStore.put() even if the implementation does not require it to be copied. It is left up to the implementation on whether to copy the row or not.

Additionally,
- Added a name to the StateStoreId so that each operator+partition can use multiple state stores (different names)
- Added a configuration that allows the user to specify which implementation to use.
- Added new metrics to understand the time taken to update keys, remove keys and commit all changes to the state store. These metrics will be visible on the plan diagram in the SQL tab of the UI.
- Refactored unit tests such that they can be reused to test any implementation of StateStore.

## How was this patch tested?
Old and new unit tests

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #18107 from tdas/SPARK-20376.
---
 .../apache/spark/sql/internal/SQLConf.scala   |  11 +
 .../FlatMapGroupsWithStateExec.scala          |  39 +-
 .../state/HDFSBackedStateStoreProvider.scala  | 218 +++----
 .../streaming/state/StateStore.scala          | 163 ++++--
 .../streaming/state/StateStoreConf.scala      |  28 +-
 .../streaming/state/StateStoreRDD.scala       |  11 +-
 .../execution/streaming/state/package.scala   |  11 +-
 .../streaming/statefulOperators.scala         | 142 +++--
 .../streaming/state/StateStoreRDDSuite.scala  |  41 +-
 .../streaming/state/StateStoreSuite.scala     | 534 ++++++++----------
 .../FlatMapGroupsWithStateSuite.scala         |  40 +-
 .../spark/sql/streaming/StreamSuite.scala     |  45 ++
 12 files changed, 695 insertions(+), 588 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c5d69c2046..c6f5cf641b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -552,6 +552,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val STATE_STORE_PROVIDER_CLASS =
+    buildConf("spark.sql.streaming.stateStore.providerClass")
+      .internal()
+      .doc(
+        "The class used to manage state data in stateful streaming queries. This class must " +
+          "be a subclass of StateStoreProvider, and must have a zero-arg constructor.")
+      .stringConf
+      .createOptional
+
   val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
     buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
       .internal()
@@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
+  def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS)
+
   def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 3ceb4cf84a..2aad8701a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec(
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       groupingAttributes.toStructType,
       stateAttributes.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val updater = new StateStoreUpdater(store)
@@ -191,12 +193,12 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
-        val timingOutKeys = store.filter { case (_, stateRow) =>
-          val timeoutTimestamp = getTimeoutTimestamp(stateRow)
+        val timingOutKeys = store.getRange(None, None).filter { rowPair =>
+          val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
           timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
         }
-        timingOutKeys.flatMap { case (keyRow, stateRow) =>
-          callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true)
+        timingOutKeys.flatMap { rowPair =>
+          callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
         }
       } else Iterator.empty
     }
@@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec(
      * Call the user function on a key's data, update the state store, and return the return data
      * iterator. Note that the store updating is lazy, that is, the store will be updated only
      * after the returned iterator is fully consumed.
+     *
+     * @param keyRow Row representing the key, cannot be null
+     * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
+     * @param prevStateRow Row representing the previous state, can be null
+     * @param hasTimedOut Whether this function is being called for a key timeout
      */
     private def callFunctionAndUpdateState(
         keyRow: UnsafeRow,
         valueRowIter: Iterator[InternalRow],
-        prevStateRowOption: Option[UnsafeRow],
+        prevStateRow: UnsafeRow,
         hasTimedOut: Boolean): Iterator[InternalRow] = {
 
       val keyObj = getKeyObj(keyRow)  // convert key to objects
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
-      val stateObjOption = getStateObj(prevStateRowOption)
+      val stateObj = getStateObj(prevStateRow)
       val keyedState = GroupStateImpl.createForStreaming(
-        stateObjOption,
+        Option(stateObj),
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),
         timeoutConf,
@@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec(
           numUpdatedStateRows += 1
 
         } else {
-          val previousTimeoutTimestamp = prevStateRowOption match {
-            case Some(row) => getTimeoutTimestamp(row)
-            case None => NO_TIMESTAMP
-          }
+          val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
           val stateRowToWrite = if (keyedState.hasUpdated) {
             getStateRow(keyedState.get)
           } else {
-            prevStateRowOption.orNull
+            prevStateRow
           }
 
           val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
@@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec(
               throw new IllegalStateException("Attempting to write empty state")
             }
             setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
-            store.put(keyRow.copy(), stateRowToWrite.copy())
+            store.put(keyRow, stateRowToWrite)
             numUpdatedStateRows += 1
           }
         }
@@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec(
     }
 
     /** Returns the state as Java object if defined */
-    def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = {
-      stateRowOption.map(getStateObjFromRow)
+    def getStateObj(stateRow: UnsafeRow): Any = {
+      if (stateRow != null) getStateObjFromRow(stateRow) else null
     }
 
     /** Returns the row for an updated state */
     def getStateRow(obj: Any): UnsafeRow = {
+      assert(obj != null)
       getStateRowFromObj(obj)
     }
 
     /** Returns the timeout timestamp of a state row is set */
     def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
-      if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP
+      if (isTimeoutEnabled && stateRow != null) {
+        stateRow.getLong(timeoutTimestampIndex)
+      } else NO_TIMESTAMP
     }
 
     /** Set the timestamp in a state row */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index fb2bf47d6e..67d86daf10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -67,13 +67,7 @@ import org.apache.spark.util.Utils
  * to ensure re-executed RDD operations re-apply updates on the correct past version of the
  * store.
  */
-private[state] class HDFSBackedStateStoreProvider(
-    val id: StateStoreId,
-    keySchema: StructType,
-    valueSchema: StructType,
-    storeConf: StateStoreConf,
-    hadoopConf: Configuration
-  ) extends StateStoreProvider with Logging {
+private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
 
   // ConcurrentHashMap is used because it generates fail-safe iterators on filtering
   // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in
@@ -95,92 +89,36 @@ private[state] class HDFSBackedStateStoreProvider(
     private val newVersion = version + 1
     private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
     private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true))
-    private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]()
-
     @volatile private var state: STATE = UPDATING
     @volatile private var finalDeltaFile: Path = null
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
 
-    override def get(key: UnsafeRow): Option[UnsafeRow] = {
-      Option(mapToUpdate.get(key))
-    }
-
-    override def filter(
-        condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
-      mapToUpdate
-        .entrySet
-        .asScala
-        .iterator
-        .filter { entry => condition(entry.getKey, entry.getValue) }
-        .map { entry => (entry.getKey, entry.getValue) }
+    override def get(key: UnsafeRow): UnsafeRow = {
+      mapToUpdate.get(key)
     }
 
     override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot put after already committed or aborted")
-
-      val isNewKey = !mapToUpdate.containsKey(key)
-      mapToUpdate.put(key, value)
-
-      Option(allUpdates.get(key)) match {
-        case Some(ValueAdded(_, _)) =>
-          // Value did not exist in previous version and was added already, keep it marked as added
-          allUpdates.put(key, ValueAdded(key, value))
-        case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) =>
-          // Value existed in previous version and updated/removed, mark it as updated
-          allUpdates.put(key, ValueUpdated(key, value))
-        case None =>
-          // There was no prior update, so mark this as added or updated according to its presence
-          // in previous version.
-          val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
-          allUpdates.put(key, update)
-      }
-      writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
+      val keyCopy = key.copy()
+      val valueCopy = value.copy()
+      mapToUpdate.put(keyCopy, valueCopy)
+      writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy)
     }
 
-    /** Remove keys that match the following condition */
-    override def remove(condition: UnsafeRow => Boolean): Unit = {
+    override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
-      val entryIter = mapToUpdate.entrySet().iterator()
-      while (entryIter.hasNext) {
-        val entry = entryIter.next
-        if (condition(entry.getKey)) {
-          val value = entry.getValue
-          val key = entry.getKey
-          entryIter.remove()
-
-          Option(allUpdates.get(key)) match {
-            case Some(ValueUpdated(_, _)) | None =>
-              // Value existed in previous version and maybe was updated, mark removed
-              allUpdates.put(key, ValueRemoved(key, value))
-            case Some(ValueAdded(_, _)) =>
-              // Value did not exist in previous version and was added, should not appear in updates
-              allUpdates.remove(key)
-            case Some(ValueRemoved(_, _)) =>
-              // Remove already in update map, no need to change
-          }
-          writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-        }
+      val prevValue = mapToUpdate.remove(key)
+      if (prevValue != null) {
+        writeRemoveToDeltaFile(tempDeltaFileStream, key)
       }
     }
 
-    /** Remove a single key. */
-    override def remove(key: UnsafeRow): Unit = {
-      verify(state == UPDATING, "Cannot remove after already committed or aborted")
-      if (mapToUpdate.containsKey(key)) {
-        val value = mapToUpdate.remove(key)
-        Option(allUpdates.get(key)) match {
-          case Some(ValueUpdated(_, _)) | None =>
-            // Value existed in previous version and maybe was updated, mark removed
-            allUpdates.put(key, ValueRemoved(key, value))
-          case Some(ValueAdded(_, _)) =>
-            // Value did not exist in previous version and was added, should not appear in updates
-            allUpdates.remove(key)
-          case Some(ValueRemoved(_, _)) =>
-          // Remove already in update map, no need to change
-        }
-        writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-      }
+    override def getRange(
+        start: Option[UnsafeRow],
+        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+      verify(state == UPDATING, "Cannot getRange after already committed or aborted")
+      iterator()
     }
 
     /** Commit all the updates that have been made to the store, and return the new version. */
@@ -227,20 +165,11 @@ private[state] class HDFSBackedStateStoreProvider(
      * Get an iterator of all the store data.
      * This can be called only after committing all the updates made in the current thread.
      */
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      verify(state == COMMITTED,
-        "Cannot get iterator of store data before committing or after aborting")
-      HDFSBackedStateStoreProvider.this.iterator(newVersion)
-    }
-
-    /**
-     * Get an iterator of all the updates made to the store in the current version.
-     * This can be called only after committing all the updates made in the current thread.
-     */
-    override def updates(): Iterator[StoreUpdate] = {
-      verify(state == COMMITTED,
-        "Cannot get iterator of updates before committing or after aborting")
-      allUpdates.values().asScala.toIterator
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      val unsafeRowPair = new UnsafeRowPair()
+      mapToUpdate.entrySet.asScala.iterator.map { entry =>
+        unsafeRowPair.withRows(entry.getKey, entry.getValue)
+      }
     }
 
     override def numKeys(): Long = mapToUpdate.size()
@@ -269,6 +198,23 @@ private[state] class HDFSBackedStateStoreProvider(
     store
   }
 
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int], // for sorting the data
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    this.stateStoreId = stateStoreId
+    this.keySchema = keySchema
+    this.valueSchema = valueSchema
+    this.storeConf = storeConf
+    this.hadoopConf = hadoopConf
+    fs.mkdirs(baseDir)
+  }
+
+  override def id: StateStoreId = stateStoreId
+
   /** Do maintenance backing data files, including creating snapshots and cleaning up old files */
   override def doMaintenance(): Unit = {
     try {
@@ -280,19 +226,27 @@ private[state] class HDFSBackedStateStoreProvider(
     }
   }
 
+  override def close(): Unit = {
+    loadedMaps.values.foreach(_.clear())
+  }
+
   override def toString(): String = {
     s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]"
   }
 
-  /* Internal classes and methods */
+  /* Internal fields and methods */
 
-  private val loadedMaps = new mutable.HashMap[Long, MapType]
-  private val baseDir =
-    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
-  private val fs = baseDir.getFileSystem(hadoopConf)
-  private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+  @volatile private var stateStoreId: StateStoreId = _
+  @volatile private var keySchema: StructType = _
+  @volatile private var valueSchema: StructType = _
+  @volatile private var storeConf: StateStoreConf = _
+  @volatile private var hadoopConf: Configuration = _
 
-  initialize()
+  private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
+  private lazy val baseDir =
+    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
+  private lazy val fs = baseDir.getFileSystem(hadoopConf)
+  private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
 
   private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
 
@@ -323,35 +277,18 @@ private[state] class HDFSBackedStateStoreProvider(
    * Get iterator of all the data of the latest version of the store.
    * Note that this will look up the files to determined the latest known version.
    */
-  private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+  private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized {
     val versionsInFiles = fetchFiles().map(_.version).toSet
     val versionsLoaded = loadedMaps.keySet
     val allKnownVersions = versionsInFiles ++ versionsLoaded
+    val unsafeRowTuple = new UnsafeRowPair()
     if (allKnownVersions.nonEmpty) {
-      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x =>
-        (x.getKey, x.getValue)
+      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry =>
+        unsafeRowTuple.withRows(entry.getKey, entry.getValue)
       }
     } else Iterator.empty
   }
 
-  /** Get iterator of a specific version of the store */
-  private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
-    loadMap(version).entrySet().iterator().asScala.map { x =>
-      (x.getKey, x.getValue)
-    }
-  }
-
-  /** Initialize the store provider */
-  private def initialize(): Unit = {
-    try {
-      fs.mkdirs(baseDir)
-    } catch {
-      case e: IOException =>
-        throw new IllegalStateException(
-          s"Cannot use ${id.checkpointLocation} for storing state data for $this: $e ", e)
-    }
-  }
-
   /** Load the required version of the map data from the backing files */
   private def loadMap(version: Long): MapType = {
     if (version <= 0) return new MapType
@@ -367,32 +304,23 @@ private[state] class HDFSBackedStateStoreProvider(
     }
   }
 
-  private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = {
-
-    def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = {
-      val keyBytes = key.getBytes()
-      val valueBytes = value.getBytes()
-      output.writeInt(keyBytes.size)
-      output.write(keyBytes)
-      output.writeInt(valueBytes.size)
-      output.write(valueBytes)
-    }
-
-    def writeRemove(key: UnsafeRow): Unit = {
-      val keyBytes = key.getBytes()
-      output.writeInt(keyBytes.size)
-      output.write(keyBytes)
-      output.writeInt(-1)
-    }
+  private def writeUpdateToDeltaFile(
+      output: DataOutputStream,
+      key: UnsafeRow,
+      value: UnsafeRow): Unit = {
+    val keyBytes = key.getBytes()
+    val valueBytes = value.getBytes()
+    output.writeInt(keyBytes.size)
+    output.write(keyBytes)
+    output.writeInt(valueBytes.size)
+    output.write(valueBytes)
+  }
 
-    update match {
-      case ValueAdded(key, value) =>
-        writeUpdate(key, value)
-      case ValueUpdated(key, value) =>
-        writeUpdate(key, value)
-      case ValueRemoved(key, value) =>
-        writeRemove(key)
-    }
+  private def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = {
+    val keyBytes = key.getBytes()
+    output.writeInt(keyBytes.size)
+    output.write(keyBytes)
+    output.writeInt(-1)
   }
 
   private def finalizeDeltaFile(output: DataOutputStream): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index eaa558eb6d..29c456f86e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -29,15 +29,12 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ThreadUtils
-
-
-/** Unique identifier for a [[StateStore]] */
-case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 
 /**
- * Base trait for a versioned key-value store used for streaming aggregations
+ * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific
+ * version of state data, and such instances are created through a [[StateStoreProvider]].
  */
 trait StateStore {
 
@@ -47,50 +44,54 @@ trait StateStore {
   /** Version of the data in this store before committing updates. */
   def version: Long
 
-  /** Get the current value of a key. */
-  def get(key: UnsafeRow): Option[UnsafeRow]
-
   /**
-   * Return an iterator of key-value pairs that satisfy a certain condition.
-   * Note that the iterator must be fail-safe towards modification to the store, that is,
-   * it must be based on the snapshot of store the time of this call, and any change made to the
-   * store while iterating through iterator should not cause the iterator to fail or have
-   * any affect on the values in the iterator.
+   * Get the current value of a non-null key.
+   * @return a non-null row if the key exists in the store, otherwise null.
    */
-  def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)]
+  def get(key: UnsafeRow): UnsafeRow
 
-  /** Put a new value for a key. */
+  /**
+   * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in
+   * the params can be reused, and must make copies of the data as needed for persistence.
+   */
   def put(key: UnsafeRow, value: UnsafeRow): Unit
 
   /**
-   * Remove keys that match the following condition.
+   * Remove a single non-null key.
    */
-  def remove(condition: UnsafeRow => Boolean): Unit
+  def remove(key: UnsafeRow): Unit
 
   /**
-   * Remove a single key.
+   * Get key value pairs with optional approximate `start` and `end` extents.
+   * If the State Store implementation maintains indices for the data based on the optional
+   * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use
+   * `start` and `end` to make a best-effort scan over the data. Default implementation returns
+   * the full data scan iterator, which is correct but inefficient. Custom implementations must
+   * ensure that updates (puts, removes) can be made while iterating over this iterator.
+   *
+   * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value.
+   * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value.
+   * @return An iterator of key-value pairs that is guaranteed not miss any key between start and
+   *         end, both inclusive.
    */
-  def remove(key: UnsafeRow): Unit
+  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+    iterator()
+  }
 
   /**
    * Commit all the updates that have been made to the store, and return the new version.
+   * Implementations should ensure that no more updates (puts, removes) can be after a commit in
+   * order to avoid incorrect usage.
    */
   def commit(): Long
 
-  /** Abort all the updates that have been made to the store. */
-  def abort(): Unit
-
   /**
-   * Iterator of store data after a set of updates have been committed.
-   * This can be called only after committing all the updates made in the current thread.
+   * Abort all the updates that have been made to the store. Implementations should ensure that
+   * no more updates (puts, removes) can be after an abort in order to avoid incorrect usage.
    */
-  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
+  def abort(): Unit
 
-  /**
-   * Iterator of the updates that have been committed.
-   * This can be called only after committing all the updates made in the current thread.
-   */
-  def updates(): Iterator[StoreUpdate]
+  def iterator(): Iterator[UnsafeRowPair]
 
   /** Number of keys in the state store */
   def numKeys(): Long
@@ -102,28 +103,98 @@ trait StateStore {
 }
 
 
-/** Trait representing a provider of a specific version of a [[StateStore]]. */
+/**
+ * Trait representing a provider that provide [[StateStore]] instances representing
+ * versions of state data.
+ *
+ * The life cycle of a provider and its provide stores are as follows.
+ *
+ * - A StateStoreProvider is created in a executor for each unique [[StateStoreId]] when
+ *   the first batch of a streaming query is executed on the executor. All subsequent batches reuse
+ *   this provider instance until the query is stopped.
+ *
+ * - Every batch of streaming data request a specific version of the state data by invoking
+ *   `getStore(version)` which returns an instance of [[StateStore]] through which the required
+ *   version of the data can be accessed. It is the responsible of the provider to populate
+ *   this store with context information like the schema of keys and values, etc.
+ *
+ * - After the streaming query is stopped, the created provider instances are lazily disposed off.
+ */
 trait StateStoreProvider {
 
-  /** Get the store with the existing version. */
+  /**
+   * Initialize the provide with more contextual information from the SQL operator.
+   * This method will be called first after creating an instance of the StateStoreProvider by
+   * reflection.
+   *
+   * @param stateStoreId Id of the versioned StateStores that this provider will generate
+   * @param keySchema Schema of keys to be stored
+   * @param valueSchema Schema of value to be stored
+   * @param keyIndexOrdinal Optional column (represent as the ordinal of the field in keySchema) by
+   *                        which the StateStore implementation could index the data.
+   * @param storeConfs Configurations used by the StateStores
+   * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
+   */
+  def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyIndexOrdinal: Option[Int], // for sorting the data by their keys
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit
+
+  /**
+   * Return the id of the StateStores this provider will generate.
+   * Should be the same as the one passed in init().
+   */
+  def id: StateStoreId
+
+  /** Called when the provider instance is unloaded from the executor */
+  def close(): Unit
+
+  /** Return an instance of [[StateStore]] representing state data of the given version */
   def getStore(version: Long): StateStore
 
-  /** Optional method for providers to allow for background maintenance */
+  /** Optional method for providers to allow for background maintenance (e.g. compactions) */
   def doMaintenance(): Unit = { }
 }
 
-
-/** Trait representing updates made to a [[StateStore]]. */
-sealed trait StoreUpdate {
-  def key: UnsafeRow
-  def value: UnsafeRow
+object StateStoreProvider {
+  /**
+   * Return a provider instance of the given provider class.
+   * The instance will be already initialized.
+   */
+  def instantiate(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int], // for sorting the data
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): StateStoreProvider = {
+    val providerClass = storeConf.providerClass.map(Utils.classForName)
+        .getOrElse(classOf[HDFSBackedStateStoreProvider])
+    val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
+    provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+    provider
+  }
 }
 
-case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
 
-case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+/** Unique identifier for a bunch of keyed state data. */
+case class StateStoreId(
+    checkpointLocation: String,
+    operatorId: Long,
+    partitionId: Int,
+    name: String = "")
 
-case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+/** Mutable, and reusable class for representing a pair of UnsafeRows. */
+class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
+  def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = {
+    this.key = key
+    this.value = value
+    this
+  }
+}
 
 
 /**
@@ -185,6 +256,7 @@ object StateStore extends Logging {
       storeId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
+      indexOrdinal: Option[Int],
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
@@ -193,7 +265,9 @@ object StateStore extends Logging {
       startMaintenanceIfNeeded()
       val provider = loadedProviders.getOrElseUpdate(
         storeId,
-        new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
+        StateStoreProvider.instantiate(
+          storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+      )
       reportActiveStoreInstance(storeId)
       provider
     }
@@ -202,7 +276,7 @@ object StateStore extends Logging {
 
   /** Unload a state store provider */
   def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
-    loadedProviders.remove(storeId)
+    loadedProviders.remove(storeId).foreach(_.close())
   }
 
   /** Whether a state store provider is loaded or not */
@@ -216,6 +290,7 @@ object StateStore extends Logging {
 
   /** Unload and stop all state store providers */
   def stop(): Unit = loadedProviders.synchronized {
+    loadedProviders.keySet.foreach { key => unload(key) }
     loadedProviders.clear()
     _coordRef = null
     if (maintenanceTask != null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index acfaa8e5eb..bab297c7df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,16 +20,34 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+class StateStoreConf(@transient private val sqlConf: SQLConf)
+  extends Serializable {
 
   def this() = this(new SQLConf)
 
-  val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot
-
-  val minVersionsToRetain = conf.minBatchesToRetain
+  /**
+   * Minimum number of delta files in a chain after which HDFSBackedStateStore will
+   * consider generating a snapshot.
+   */
+  val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot
+
+  /** Minimum versions a State Store implementation should retain to allow rollbacks */
+  val minVersionsToRetain: Int = sqlConf.minBatchesToRetain
+
+  /**
+   * Optional fully qualified name of the subclass of [[StateStoreProvider]]
+   * managing state data. That is, the implementation of the State Store to use.
+   */
+  val providerClass: Option[String] = sqlConf.stateStoreProviderClass
+
+  /**
+   * Additional configurations related to state store. This will capture all configs in
+   * SQLConf that start with `spark.sql.streaming.stateStore.` */
+  val confs: Map[String, String] =
+    sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
 }
 
-private[streaming] object StateStoreConf {
+object StateStoreConf {
   val empty = new StateStoreConf()
 
   def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index e16dda8a5b..b744c25dc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -35,9 +35,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
     checkpointLocation: String,
     operatorId: Long,
+    storeName: String,
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
+    indexOrdinal: Option[Int],
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
   extends RDD[U](dataRDD) {
@@ -45,21 +47,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
   private val storeConf = new StateStoreConf(sessionState.conf)
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
-  private val confBroadcast = dataRDD.context.broadcast(
+  private val hadoopConfBroadcast = dataRDD.context.broadcast(
     new SerializableConfiguration(sessionState.newHadoopConf()))
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
   override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName)
     storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
   }
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     var store: StateStore = null
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName)
     store = StateStore.get(
-      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+      storeId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 589042afb1..228fe86d59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -34,17 +34,21 @@ package object state {
         sqlContext: SQLContext,
         checkpointLocation: String,
         operatorId: Long,
+        storeName: String,
         storeVersion: Long,
         keySchema: StructType,
-        valueSchema: StructType)(
+        valueSchema: StructType,
+        indexOrdinal: Option[Int])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
         checkpointLocation,
         operatorId,
+        storeName,
         storeVersion,
         keySchema,
         valueSchema,
+        indexOrdinal,
         sqlContext.sessionState,
         Some(sqlContext.streams.stateStoreCoordinator))(
         storeUpdateFunction)
@@ -54,9 +58,11 @@ package object state {
     private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
         checkpointLocation: String,
         operatorId: Long,
+        storeName: String,
         storeVersion: Long,
         keySchema: StructType,
         valueSchema: StructType,
+        indexOrdinal: Option[Int],
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
@@ -69,14 +75,17 @@ package object state {
         })
         cleanedF(store, iter)
       }
+
       new StateStoreRDD(
         dataRDD,
         wrappedF,
         checkpointLocation,
         operatorId,
+        storeName,
         storeVersion,
         keySchema,
         valueSchema,
+        indexOrdinal,
         sessionState,
         storeCoordinator)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 8dbda298c8..3e57f3fbad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.concurrent.TimeUnit._
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types._
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.{CompletionIterator, NextIterator}
 
 
 /** Used to identify the state store for a given operator. */
@@ -61,11 +62,24 @@ trait StateStoreReader extends StatefulOperator {
 }
 
 /** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator {
+trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
+
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
+    "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"),
+    "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"),
+    "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes")
+  )
+
+  /** Records the duration of running `body` for the next query progress update. */
+  protected def timeTakenMs(body: => Unit): Long = {
+    val startTime = System.nanoTime()
+    val result = body
+    val endTime = System.nanoTime()
+    math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
+  }
 }
 
 /** An operator that supports watermark. */
@@ -108,6 +122,16 @@ trait WatermarkSupport extends UnaryExecNode {
   /** Predicate based on the child output that matches data older than the watermark. */
   lazy val watermarkPredicateForData: Option[Predicate] =
     watermarkExpression.map(newPredicate(_, child.output))
+
+  protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
+    if (watermarkPredicateForKeys.nonEmpty) {
+      store.getRange(None, None).foreach { rowPair =>
+        if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+          store.remove(rowPair.key)
+        }
+      }
+    }
+  }
 }
 
 /**
@@ -126,9 +150,11 @@ case class StateStoreRestoreExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       operatorId = getStateId.operatorId,
+      storeName = "default",
       storeVersion = getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
@@ -136,7 +162,7 @@ case class StateStoreRestoreExec(
           val key = getKey(row)
           val savedState = store.get(key)
           numOutputRows += 1
-          row +: savedState.toSeq
+          row +: Option(savedState).toSeq
         }
     }
   }
@@ -165,54 +191,88 @@ case class StateStoreSaveExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
         val numOutputRows = longMetric("numOutputRows")
         val numTotalStateRows = longMetric("numTotalStateRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+        val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+        val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+        val commitTimeMs = longMetric("commitTimeMs")
 
         outputMode match {
           // Update and output all rows in the StateStore.
           case Some(Complete) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
+            allUpdatesTimeMs += timeTakenMs {
+              while (iter.hasNext) {
+                val row = iter.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numUpdatedStateRows += 1
+              }
+            }
+            allRemovalsTimeMs += 0
+            commitTimeMs += timeTakenMs {
+              store.commit()
             }
-            store.commit()
             numTotalStateRows += store.numKeys()
-            store.iterator().map { case (k, v) =>
+            store.iterator().map { rowPair =>
               numOutputRows += 1
-              v.asInstanceOf[InternalRow]
+              rowPair.value
             }
 
           // Update and output only rows being evicted from the StateStore
+          // Assumption: watermark predicates must be non-empty if append mode is allowed
           case Some(Append) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
+            allUpdatesTimeMs += timeTakenMs {
+              val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
+              while (filteredIter.hasNext) {
+                val row = filteredIter.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numUpdatedStateRows += 1
+              }
             }
 
-            // Assumption: Append mode can be done only when watermark has been specified
-            store.remove(watermarkPredicateForKeys.get.eval _)
-            store.commit()
+            val removalStartTimeNs = System.nanoTime
+            val rangeIter = store.getRange(None, None)
+
+            new NextIterator[InternalRow] {
+              override protected def getNext(): InternalRow = {
+                var removedValueRow: InternalRow = null
+                while(rangeIter.hasNext && removedValueRow == null) {
+                  val rowPair = rangeIter.next()
+                  if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+                    store.remove(rowPair.key)
+                    removedValueRow = rowPair.value
+                  }
+                }
+                if (removedValueRow == null) {
+                  finished = true
+                  null
+                } else {
+                  removedValueRow
+                }
+              }
 
-            numTotalStateRows += store.numKeys()
-            store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed =>
-              numOutputRows += 1
-              removed.value.asInstanceOf[InternalRow]
+              override protected def close(): Unit = {
+                allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
+                commitTimeMs += timeTakenMs { store.commit() }
+                numTotalStateRows += store.numKeys()
+              }
             }
 
           // Update and output modified rows from the StateStore.
           case Some(Update) =>
 
+            val updatesStartTimeNs = System.nanoTime
+
             new Iterator[InternalRow] {
 
               // Filter late date using watermark if specified
@@ -223,11 +283,11 @@ case class StateStoreSaveExec(
 
               override def hasNext: Boolean = {
                 if (!baseIterator.hasNext) {
+                  allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
                   // Remove old aggregates if watermark specified
-                  if (watermarkPredicateForKeys.nonEmpty) {
-                    store.remove(watermarkPredicateForKeys.get.eval _)
-                  }
-                  store.commit()
+                  allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+                  commitTimeMs += timeTakenMs { store.commit() }
                   numTotalStateRows += store.numKeys()
                   false
                 } else {
@@ -238,7 +298,7 @@ case class StateStoreSaveExec(
               override def next(): InternalRow = {
                 val row = baseIterator.next().asInstanceOf[UnsafeRow]
                 val key = getKey(row)
-                store.put(key.copy(), row.copy())
+                store.put(key, row)
                 numOutputRows += 1
                 numUpdatedStateRows += 1
                 row
@@ -273,27 +333,34 @@ case class StreamingDeduplicateExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
       val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
       val numOutputRows = longMetric("numOutputRows")
       val numTotalStateRows = longMetric("numTotalStateRows")
       val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+      val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+      val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+      val commitTimeMs = longMetric("commitTimeMs")
 
       val baseIterator = watermarkPredicateForData match {
         case Some(predicate) => iter.filter(row => !predicate.eval(row))
         case None => iter
       }
 
+      val updatesStartTimeNs = System.nanoTime
+
       val result = baseIterator.filter { r =>
         val row = r.asInstanceOf[UnsafeRow]
         val key = getKey(row)
         val value = store.get(key)
-        if (value.isEmpty) {
-          store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW)
+        if (value == null) {
+          store.put(key, StreamingDeduplicateExec.EMPTY_ROW)
           numUpdatedStateRows += 1
           numOutputRows += 1
           true
@@ -304,8 +371,9 @@ case class StreamingDeduplicateExec(
       }
 
       CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
-        watermarkPredicateForKeys.foreach(f => store.remove(f.eval _))
-        store.commit()
+        allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+        allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+        commitTimeMs += timeTakenMs { store.commit() }
         numTotalStateRows += store.numKeys()
       })
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index bd197be655..4a1a089af5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -38,13 +38,13 @@ import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
 
+  import StateStoreTestsHelper._
+
   private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
-  private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
+  private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
   private val keySchema = StructType(Seq(StructField("key", StringType, true)))
   private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
-  import StateStoreSuite._
-
   after {
     StateStore.stop()
   }
@@ -60,13 +60,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       val opId = 0
       val rdd1 =
         makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-            spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+            spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
             increment)
       assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+        spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(
+        increment)
       assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
       // Make sure the previous RDD still has the same data.
@@ -84,7 +85,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         storeVersion: Int): RDD[(String, Int)] = {
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+        sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment)
     }
 
     // Generate RDDs and state store data
@@ -110,7 +111,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
         val resIterator = iter.map { s =>
           val key = stringToRow(s)
-          val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+          val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
           val newValue = oldValue + 1
           store.put(key, intToRow(newValue))
           (s, newValue)
@@ -125,21 +126,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
           iter: Iterator[String]): Iterator[(String, Option[Int])] = {
         iter.map { s =>
           val key = stringToRow(s)
-          val value = store.get(key).map(rowToInt)
+          val value = Option(store.get(key)).map(rowToInt)
           (s, value)
         }
       }
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+        spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+        iteratorOfGets)
       assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+        sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+        iteratorOfPuts)
       assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+        sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(
+        iteratorOfGets)
       assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
     }
   }
@@ -152,15 +156,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
         implicit val sqlContext = spark.sqlContext
         val coordinatorRef = sqlContext.streams.stateStoreCoordinator
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2")
 
         assert(
-          coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+          coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
         val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+          increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -208,7 +213,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
       val key = stringToRow(s)
-      val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+      val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
       store.put(key, intToRow(oldValue + 1))
     }
     store.commit()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index cc09b2d5b7..af2b9f1c11 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -40,15 +40,15 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
+  with BeforeAndAfter with PrivateMethodTester {
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
 
   import StateStoreCoordinatorSuite._
-  import StateStoreSuite._
+  import StateStoreTestsHelper._
 
-  private val tempDir = Utils.createTempDir().toString
-  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+  val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
   before {
     StateStore.stop()
@@ -60,186 +60,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     require(!StateStore.isMaintenanceRunning)
   }
 
-  test("get, put, remove, commit, and all data iterator") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator().isEmpty)
-
-    val store = provider.getStore(0)
-    assert(!store.hasCommitted)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-
-    // Verify state after updating
-    put(store, "a", 1)
-    assert(store.numKeys() === 1)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-    assert(provider.latestIterator().isEmpty)
-
-    // Make updates, commit and then verify state
-    put(store, "b", 2)
-    put(store, "aa", 3)
-    assert(store.numKeys() === 3)
-    remove(store, _.startsWith("a"))
-    assert(store.numKeys() === 1)
-    assert(store.commit() === 1)
-
-    assert(store.hasCommitted)
-    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
-    assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
-    assert(fileExists(provider, version = 1, isSnapshot = false))
-
-    assert(getDataFromFiles(provider) === Set("b" -> 2))
-
-    // Trying to get newer versions should fail
-    intercept[Exception] {
-      provider.getStore(2)
-    }
-    intercept[Exception] {
-      getDataFromFiles(provider, 2)
-    }
-
-    // New updates to the reloaded store with new version, and does not change old version
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
-    val reloadedStore = reloadedProvider.getStore(1)
-    assert(reloadedStore.numKeys() === 1)
-    put(reloadedStore, "c", 4)
-    assert(reloadedStore.numKeys() === 2)
-    assert(reloadedStore.commit() === 2)
-    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
-    assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
-  }
-
-  test("filter and concurrent updates") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator.isEmpty)
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    put(store, "b", 2)
-
-    // Updates should work while iterating of filtered entries
-    val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" }
-    filtered.foreach { case (keyRow, valueRow) =>
-      store.put(keyRow, intToRow(rowToInt(valueRow) + 1))
-    }
-    assert(get(store, "a") === Some(2))
-
-    // Removes should work while iterating of filtered entries
-    val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" }
-    filtered2.foreach { case (keyRow, _) =>
-      store.remove(keyRow)
-    }
-    assert(get(store, "b") === None)
-  }
-
-  test("updates iterator with all combos of updates and removes") {
-    val provider = newStoreProvider()
-    var currentVersion: Int = 0
-
-    def withStore(body: StateStore => Unit): Unit = {
-      val store = provider.getStore(currentVersion)
-      body(store)
-      currentVersion += 1
-    }
-
-    // New data should be seen in updates as value added, even if they had multiple updates
-    withStore { store =>
-      put(store, "a", 1)
-      put(store, "aa", 1)
-      put(store, "aa", 2)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
-    }
-
-    // Multiple updates to same key should be collapsed in the updates as a single value update
-    // Keys that have not been updated should not appear in the updates
-    withStore { store =>
-      put(store, "a", 4)
-      put(store, "a", 6)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Keys added, updated and finally removed before commit should not appear in updates
-    withStore { store =>
-      put(store, "b", 4)     // Added, finally removed
-      put(store, "bb", 5)    // Added, updated, finally removed
-      put(store, "bb", 6)
-      remove(store, _.startsWith("b"))
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set.empty)
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Removed data should be seen in updates as a key removed
-    // Removed, but re-added data should be seen in updates as a value update
-    withStore { store =>
-      remove(store, _.startsWith("a"))
-      put(store, "a", 10)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 10))
-    }
-  }
-
-  test("cancel") {
-    val provider = newStoreProvider()
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    store.commit()
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    // cancelUpdates should not change the data in the files
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    store1.abort()
-    assert(getDataFromFiles(provider) === Set("a" -> 1))
-  }
-
-  test("getStore with unexpected versions") {
-    val provider = newStoreProvider()
-
-    intercept[IllegalArgumentException] {
-      provider.getStore(-1)
-    }
-
-    // Prepare some data in the store
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    assert(store.commit() === 1)
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    intercept[IllegalStateException] {
-      provider.getStore(2)
-    }
-
-    // Update store version with some data
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    assert(store1.commit() === 2)
-    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
-    assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
-  }
-
   test("snapshotting") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     var currentVersion = 0
     def updateVersionTo(targetVersion: Int): Unit = {
@@ -253,9 +75,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     updateVersionTo(2)
-    require(getDataFromFiles(provider) === Set("a" -> 2))
+    require(getData(provider) === Set("a" -> 2))
     provider.doMaintenance()               // should not generate snapshot files
-    assert(getDataFromFiles(provider) === Set("a" -> 2))
+    assert(getData(provider) === Set("a" -> 2))
 
     for (i <- 1 to currentVersion) {
       assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
@@ -264,22 +86,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // After version 6, snapshotting should generate one snapshot file
     updateVersionTo(6)
-    require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 6), "store not updated correctly")
     provider.doMaintenance()       // should generate snapshot files
 
     val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
     assert(snapshotVersion.nonEmpty, "snapshot file not generated")
     deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
     assert(
-      getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
       "snapshotting messed up the data of the snapshotted version")
     assert(
-      getDataFromFiles(provider) === Set("a" -> 6),
+      getData(provider) === Set("a" -> 6),
       "snapshotting messed up the data of the final version")
 
     // After version 20, snapshotting should generate newer snapshot files
     updateVersionTo(20)
-    require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 20), "store not updated correctly")
     provider.doMaintenance()       // do snapshot
 
     val latestSnapshotVersion = (0 to 20).filter(version =>
@@ -288,11 +110,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
 
     deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+    assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data")
   }
 
   test("cleaning") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
@@ -307,8 +129,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
 
     // last couple of versions should be retrievable
-    assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
-    assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+    assert(getData(provider, 20) === Set("a" -> 20))
+    assert(getData(provider, 19) === Set("a" -> 19))
   }
 
   test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
@@ -316,7 +138,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName)
     conf.set("fs.defaultFS", "fake:///")
 
-    val provider = newStoreProvider(hadoopConf = conf)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf)
     provider.getStore(0).commit()
     provider.getStore(0).commit()
 
@@ -327,7 +149,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("corrupted file handling") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
       put(store, "a", i)
@@ -338,62 +160,75 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
 
     // Corrupt snapshot file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion))
     corruptFile(provider, snapshotVersion, isSnapshot = true)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion)
+      getData(provider, snapshotVersion)
     }
 
     // Corrupt delta file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
     corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
 
     // Delete delta file and verify that it throws error
     deleteFilesEarlierThanVersion(provider, snapshotVersion)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
   }
 
   test("StateStore.get") {
     quietly {
-      val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+      val dir = newDir()
       val storeId = StateStoreId(dir, 0, 0)
       val storeConf = StateStoreConf.empty
       val hadoopConf = new Configuration()
 
-
       // Verify that trying to get incorrect versions throw errors
       intercept[IllegalArgumentException] {
-        StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf)
       }
       assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
-        StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       }
 
-      // Increase version of the store
-      val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      // Increase version of the store and try to get again
+      val store0 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
       assert(store0.version === 0)
       put(store0, "a", 1)
       store0.commit()
 
-      assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
-      assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+      val store1 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+      assert(StateStore.isLoaded(storeId))
+      assert(store1.version === 1)
+      assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+
+      // Verify that you can also load older version
+      val store0reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+      assert(store0reloaded.version === 0)
+      assert(rowsToSet(store0reloaded.iterator()) === Set.empty)
 
       // Verify that you can remove the store and still reload and use it
       StateStore.unload(storeId)
       assert(!StateStore.isLoaded(storeId))
 
-      val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      val store1reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
-      put(store1, "a", 2)
-      assert(store1.commit() === 2)
-      assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+      assert(store1reloaded.version === 1)
+      put(store1reloaded, "a", 2)
+      assert(store1reloaded.commit() === 2)
+      assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2))
     }
   }
 
@@ -407,21 +242,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       // fails to talk to the StateStoreCoordinator and unloads all the StateStores
       .set("spark.rpc.numRetries", "1")
     val opId = 0
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, opId, 0)
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
     val storeConf = StateStoreConf(sqlConf)
     val hadoopConf = new Configuration()
-    val provider = new HDFSBackedStateStoreProvider(
-      storeId, keySchema, valueSchema, storeConf, hadoopConf)
+    val provider = newStoreProvider(storeId)
 
     var latestStoreVersion = 0
 
     def generateStoreVersions() {
       for (i <- 1 to 20) {
-        val store = StateStore.get(
-          storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+        val store = StateStore.get(storeId, keySchema, valueSchema, None,
+          latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", i)
         store.commit()
         latestStoreVersion += 1
@@ -469,7 +303,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeId))
 
           // If some other executor loads the store, then this instance should be unloaded
@@ -479,7 +314,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeId))
         }
       }
@@ -495,10 +331,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
   test("SPARK-18342: commit fails when rename fails") {
     import RenameReturnsFalseFileSystem._
-    val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath
+    val dir = scheme + "://" + newDir()
     val conf = new Configuration()
     conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName)
-    val provider = newStoreProvider(dir = dir, hadoopConf = conf)
+    val provider = newStoreProvider(
+      opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf)
     val store = provider.getStore(0)
     put(store, "a", 0)
     val e = intercept[IllegalStateException](store.commit())
@@ -506,7 +343,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("SPARK-18416: do not create temp delta file until the store is updated") {
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, 0, 0)
     val storeConf = StateStoreConf.empty
     val hadoopConf = new Configuration()
@@ -533,7 +370,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf)
     }
 
     // Put should create a temp file
@@ -548,7 +386,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf)
     }
     remove(store1, _ == "a")
     assert(numTempFiles === 1)
@@ -561,31 +400,55 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf)
     }
     store2.commit()
     assert(numTempFiles === 0)
     assert(numDeltaFiles === 3)
   }
 
-  def getDataFromFiles(
-      provider: HDFSBackedStateStoreProvider,
+  override def newStoreProvider(): HDFSBackedStateStoreProvider = {
+    newStoreProvider(opId = Random.nextInt(), partition = 0)
+  }
+
+  override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = {
+    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation)
+  }
+
+  override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
+    getData(storeProvider)
+  }
+
+  override def getData(
+    provider: HDFSBackedStateStoreProvider,
     version: Int = -1): Set[(String, Int)] = {
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    val reloadedProvider = newStoreProvider(provider.id)
     if (version < 0) {
       reloadedProvider.latestIterator().map(rowsToStringInt).toSet
     } else {
-      reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+      reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet
     }
   }
 
-  def assertMap(
-      testMapOption: Option[MapType],
-      expectedMap: Map[String, Int]): Unit = {
-    assert(testMapOption.nonEmpty, "no map present")
-    val convertedMap = testMapOption.get.map(rowsToStringInt)
-    assert(convertedMap === expectedMap)
+  def newStoreProvider(
+      opId: Long,
+      partition: Int,
+      dir: String = newDir(),
+      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
+    val sqlConf = new SQLConf()
+    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+    val provider = new HDFSBackedStateStoreProvider()
+    provider.init(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      indexOrdinal = None,
+      new StateStoreConf(sqlConf),
+      hadoopConf)
+    provider
   }
 
   def fileExists(
@@ -622,56 +485,150 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     filePath.delete()
     filePath.createNewFile()
   }
+}
 
-  def storeLoaded(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
-    val loadedStores = StateStore invokePrivate method()
-    loadedStores.contains(storeId)
-  }
+abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
+  extends SparkFunSuite {
+  import StateStoreTestsHelper._
 
-  def unloadStore(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod('remove)
-    StateStore invokePrivate method(storeId)
-  }
+  test("get, put, remove, commit, and all data iterator") {
+    val provider = newStoreProvider()
 
-  def newStoreProvider(
-      opId: Long = Random.nextLong,
-      partition: Int = 0,
-      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
-      dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString,
-      hadoopConf: Configuration = new Configuration()
-    ): HDFSBackedStateStoreProvider = {
-    val sqlConf = new SQLConf()
-    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    new HDFSBackedStateStoreProvider(
-      StateStoreId(dir, opId, partition),
-      keySchema,
-      valueSchema,
-      new StateStoreConf(sqlConf),
-      hadoopConf)
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+
+    val store = provider.getStore(0)
+    assert(!store.hasCommitted)
+    assert(get(store, "a") === None)
+    assert(store.iterator().isEmpty)
+    assert(store.numKeys() === 0)
+
+    // Verify state after updating
+    put(store, "a", 1)
+    assert(get(store, "a") === Some(1))
+    assert(store.numKeys() === 1)
+
+    assert(store.iterator().nonEmpty)
+    assert(getLatestData(provider).isEmpty)
+
+    // Make updates, commit and then verify state
+    put(store, "b", 2)
+    put(store, "aa", 3)
+    assert(store.numKeys() === 3)
+    remove(store, _.startsWith("a"))
+    assert(store.numKeys() === 1)
+    assert(store.commit() === 1)
+
+    assert(store.hasCommitted)
+    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+    assert(getLatestData(provider) === Set("b" -> 2))
+
+    // Trying to get newer versions should fail
+    intercept[Exception] {
+      provider.getStore(2)
+    }
+    intercept[Exception] {
+      getData(provider, 2)
+    }
+
+    // New updates to the reloaded store with new version, and does not change old version
+    val reloadedProvider = newStoreProvider(store.id)
+    val reloadedStore = reloadedProvider.getStore(1)
+    assert(reloadedStore.numKeys() === 1)
+    put(reloadedStore, "c", 4)
+    assert(reloadedStore.numKeys() === 2)
+    assert(reloadedStore.commit() === 2)
+    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+    assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
+    assert(getData(provider, version = 1) === Set("b" -> 2))
   }
 
-  def remove(store: StateStore, condition: String => Boolean): Unit = {
-    store.remove(row => condition(rowToString(row)))
+  test("removing while iterating") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    put(store, "b", 2)
+
+    // Updates should work while iterating of filtered entries
+    val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" }
+    filtered.foreach { tuple =>
+      store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1))
+    }
+    assert(get(store, "a") === Some(2))
+
+    // Removes should work while iterating of filtered entries
+    val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" }
+    filtered2.foreach { tuple => store.remove(tuple.key) }
+    assert(get(store, "b") === None)
   }
 
-  private def put(store: StateStore, key: String, value: Int): Unit = {
-    store.put(stringToRow(key), intToRow(value))
+  test("abort") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    // cancelUpdates should not change the data in the files
+    val store1 = provider.getStore(1)
+    put(store1, "b", 1)
+    store1.abort()
   }
 
-  private def get(store: StateStore, key: String): Option[Int] = {
-    store.get(stringToRow(key)).map(rowToInt)
+  test("getStore with invalid versions") {
+    val provider = newStoreProvider()
+
+    def checkInvalidVersion(version: Int): Unit = {
+      intercept[Exception] {
+        provider.getStore(version)
+      }
+    }
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(1)
+
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    assert(store.commit() === 1)
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    val store1_ = provider.getStore(1)
+    assert(rowsToSet(store1_.iterator()) === Set("a" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(2)
+
+    // Update store version with some data
+    val store1 = provider.getStore(1)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+    put(store1, "b", 1)
+    assert(store1.commit() === 2)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(3)
   }
-}
 
-private[state] object StateStoreSuite {
+  /** Return a new provider with a random id */
+  def newStoreProvider(): ProviderClass
+
+  /** Return a new provider with the given id */
+  def newStoreProvider(storeId: StateStoreId): ProviderClass
+
+  /** Get the latest data referred to by the given provider but not using this provider */
+  def getLatestData(storeProvider: ProviderClass): Set[(String, Int)]
+
+  /**
+   * Get a specific version of data referred to by the given provider but not using
+   * this provider
+   */
+  def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)]
+}
 
-  /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
-  trait TestUpdate
-  case class Added(key: String, value: Int) extends TestUpdate
-  case class Updated(key: String, value: Int) extends TestUpdate
-  case class Removed(key: String) extends TestUpdate
+object StateStoreTestsHelper {
 
   val strProj = UnsafeProjection.create(Array[DataType](StringType))
   val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
@@ -692,26 +649,29 @@ private[state] object StateStoreSuite {
     row.getInt(0)
   }
 
-  def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
-    (rowToInt(row._1), rowToInt(row._2))
+  def rowsToStringInt(row: UnsafeRowPair): (String, Int) = {
+    (rowToString(row.key), rowToInt(row.value))
   }
 
+  def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = {
+    iterator.map(rowsToStringInt).toSet
+  }
 
-  def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
-    (rowToString(row._1), rowToInt(row._2))
+  def remove(store: StateStore, condition: String => Boolean): Unit = {
+    store.getRange(None, None).foreach { rowPair =>
+      if (condition(rowToString(rowPair.key))) store.remove(rowPair.key)
+    }
   }
 
-  def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
-    iterator.map(rowsToStringInt).toSet
+  def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
   }
 
-  def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
-    iterator.map {
-      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
-      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
-      case ValueRemoved(key, _) => Removed(rowToString(key))
-    }.toSet
+  def get(store: StateStore, key: String): Option[Int] = {
+    Option(store.get(stringToRow(key))).map(rowToInt)
   }
+
+  def newDir(): String = Utils.createTempDir().toString
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6bb9408ce9..0d9ca81349 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -508,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     expectedState = Some(5),                                  // state should change
     expectedTimeoutTimestamp = 5000)                          // timestamp should change
 
-  test("StateStoreUpdater - rows are cloned before writing to StateStore") {
-    // function for running count
-    val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
-      state.update(state.getOption.getOrElse(0) + values.size)
-      Iterator.empty
-    }
-    val store = newStateStore()
-    val plan = newFlatMapGroupsWithStateExec(func)
-    val updater = new plan.StateStoreUpdater(store)
-    val data = Seq(1, 1, 2)
-    val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow))
-    returnIter.size // consume the iterator to force store updates
-    val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet
-    assert(storeData === Set((1, 2), (2, 1)))
-  }
-
   test("flatMapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count if state is defined, otherwise does not return anything
@@ -1016,11 +1000,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       callFunction()
       val updatedStateRow = store.get(key)
       assert(
-        updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
+        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState,
         "final state not as expected")
-      if (updatedStateRow.nonEmpty) {
+      if (updatedStateRow != null) {
         assert(
-          updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
+          updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp,
           "final timeout timestamp not as expected")
       }
     }
@@ -1080,25 +1064,19 @@ object FlatMapGroupsWithStateSuite {
     import scala.collection.JavaConverters._
     private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) }
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
     }
 
-    override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
-      iterator.filter { case (k, v) => c(k, v) }
+    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
+      map.put(key.copy(), newValue.copy())
     }
-
-    override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key))
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue)
     override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def remove(condition: (UnsafeRow) => Boolean): Unit = {
-      iterator.map(_._1).filter(condition).foreach(map.remove)
-    }
     override def commit(): Long = version + 1
     override def abort(): Unit = { }
     override def id: StateStoreId = null
     override def version: Long = 0
-    override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException }
     override def numKeys(): Long = map.size
     override def hasCommitted: Boolean = true
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 1fc062974e..280f2dc27b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
 import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -31,6 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -614,6 +616,30 @@ class StreamSuite extends StreamTest {
     assertDescContainsQueryNameAnd(batch = 2)
     query.stop()
   }
+
+  testQuietly("specify custom state store provider") {
+    val queryName = "memStream"
+    val providerClassName = classOf[TestStateStoreProvider].getCanonicalName
+    withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) {
+      val input = MemoryStream[Int]
+      val query = input
+        .toDS()
+        .groupBy()
+        .count()
+        .writeStream
+        .outputMode("complete")
+        .format("memory")
+        .queryName(queryName)
+        .start()
+      input.addData(1, 2, 3)
+      val e = intercept[Exception] {
+        query.awaitTermination()
+      }
+
+      assert(e.getMessage.contains(providerClassName))
+      assert(e.getMessage.contains("instantiated"))
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
@@ -719,3 +745,22 @@ object ThrowingInterruptedIOException {
    */
   @volatile var createSourceLatch: CountDownLatch = null
 }
+
+class TestStateStoreProvider extends StateStoreProvider {
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    throw new Exception("Successfully instantiated")
+  }
+
+  override def id: StateStoreId = null
+
+  override def close(): Unit = { }
+
+  override def getStore(version: Long): StateStore = null
+}
-- 
GitLab