diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index c5d69c204642ede6e840a60e650a1708e9f29aea..c6f5cf641b8d535a552f7712b364efc16c4adb54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -552,6 +552,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val STATE_STORE_PROVIDER_CLASS =
+    buildConf("spark.sql.streaming.stateStore.providerClass")
+      .internal()
+      .doc(
+        "The class used to manage state data in stateful streaming queries. This class must " +
+          "be a subclass of StateStoreProvider, and must have a zero-arg constructor.")
+      .stringConf
+      .createOptional
+
   val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
     buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
       .internal()
@@ -828,6 +837,8 @@ class SQLConf extends Serializable with Logging {
 
   def optimizerInSetConversionThreshold: Int = getConf(OPTIMIZER_INSET_CONVERSION_THRESHOLD)
 
+  def stateStoreProviderClass: Option[String] = getConf(STATE_STORE_PROVIDER_CLASS)
+
   def stateStoreMinDeltasForSnapshot: Int = getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
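
A minimal usage sketch of the new config (not part of this patch): it assumes an active SparkSession named `spark`, and the provider class name below is purely illustrative.

// Hypothetical: route stateful streaming queries to a custom provider implementation.
// The class must extend StateStoreProvider and expose a zero-arg constructor.
spark.conf.set(
  "spark.sql.streaming.stateStore.providerClass",
  "com.example.streaming.MyStateStoreProvider")
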
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 3ceb4cf84a413ae7942ae9dfd73a3da2991866a1..2aad8701a4eca7a799a0507f058bcf2f510381b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -109,9 +109,11 @@ case class FlatMapGroupsWithStateExec(
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       groupingAttributes.toStructType,
       stateAttributes.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val updater = new StateStoreUpdater(store)
@@ -191,12 +193,12 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
-        val timingOutKeys = store.filter { case (_, stateRow) =>
-          val timeoutTimestamp = getTimeoutTimestamp(stateRow)
+        val timingOutKeys = store.getRange(None, None).filter { rowPair =>
+          val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
           timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
         }
-        timingOutKeys.flatMap { case (keyRow, stateRow) =>
-          callFunctionAndUpdateState(keyRow, Iterator.empty, Some(stateRow), hasTimedOut = true)
+        timingOutKeys.flatMap { rowPair =>
+          callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
         }
       } else Iterator.empty
     }
@@ -205,18 +207,23 @@ case class FlatMapGroupsWithStateExec(
      * Call the user function on a key's data, update the state store, and return the return data
      * iterator. Note that the store updating is lazy, that is, the store will be updated only
      * after the returned iterator is fully consumed.
+     *
+     * @param keyRow Row representing the key, cannot be null
+     * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
+     * @param prevStateRow Row representing the previous state, can be null
+     * @param hasTimedOut Whether this function is being called for a key timeout
      */
     private def callFunctionAndUpdateState(
         keyRow: UnsafeRow,
         valueRowIter: Iterator[InternalRow],
-        prevStateRowOption: Option[UnsafeRow],
+        prevStateRow: UnsafeRow,
         hasTimedOut: Boolean): Iterator[InternalRow] = {
 
       val keyObj = getKeyObj(keyRow)  // convert key to objects
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
-      val stateObjOption = getStateObj(prevStateRowOption)
+      val stateObj = getStateObj(prevStateRow)
       val keyedState = GroupStateImpl.createForStreaming(
-        stateObjOption,
+        Option(stateObj),
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),
         timeoutConf,
@@ -249,14 +256,11 @@ case class FlatMapGroupsWithStateExec(
           numUpdatedStateRows += 1
 
         } else {
-          val previousTimeoutTimestamp = prevStateRowOption match {
-            case Some(row) => getTimeoutTimestamp(row)
-            case None => NO_TIMESTAMP
-          }
+          val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
           val stateRowToWrite = if (keyedState.hasUpdated) {
             getStateRow(keyedState.get)
           } else {
-            prevStateRowOption.orNull
+            prevStateRow
           }
 
           val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
@@ -269,7 +273,7 @@ case class FlatMapGroupsWithStateExec(
               throw new IllegalStateException("Attempting to write empty state")
             }
             setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
-            store.put(keyRow.copy(), stateRowToWrite.copy())
+            store.put(keyRow, stateRowToWrite)
             numUpdatedStateRows += 1
           }
         }
@@ -280,18 +284,21 @@ case class FlatMapGroupsWithStateExec(
     }
 
     /** Returns the state as Java object if defined */
-    def getStateObj(stateRowOption: Option[UnsafeRow]): Option[Any] = {
-      stateRowOption.map(getStateObjFromRow)
+    def getStateObj(stateRow: UnsafeRow): Any = {
+      if (stateRow != null) getStateObjFromRow(stateRow) else null
     }
 
     /** Returns the row for an updated state */
     def getStateRow(obj: Any): UnsafeRow = {
+      assert(obj != null)
       getStateRowFromObj(obj)
     }
 
     /** Returns the timeout timestamp of a state row is set */
     def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
-      if (isTimeoutEnabled) stateRow.getLong(timeoutTimestampIndex) else NO_TIMESTAMP
+      if (isTimeoutEnabled && stateRow != null) {
+        stateRow.getLong(timeoutTimestampIndex)
+      } else NO_TIMESTAMP
     }
 
     /** Set the timestamp in a state row */
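
The changes above move FlatMapGroupsWithStateExec from Option-based state handling to the nullable-row contract of the revised StateStore API. A minimal sketch of that calling convention (illustration only; the helper name is made up):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.StateStore

// get() now returns null rather than None when the key is absent; callers either test for
// null directly or wrap the result in Option(...) where Scala-friendly handling is preferred.
def lookupState(store: StateStore, key: UnsafeRow): Option[UnsafeRow] = {
  val prevStateRow: UnsafeRow = store.get(key)   // null if the key is absent
  Option(prevStateRow)                           // None iff prevStateRow is null
}
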
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index fb2bf47d6e83bfaab6648d7575c62b15ed9d042a..67d86daf10812e668d0163b65d3af75110305962 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -67,13 +67,7 @@ import org.apache.spark.util.Utils
  * to ensure re-executed RDD operations re-apply updates on the correct past version of the
  * store.
  */
-private[state] class HDFSBackedStateStoreProvider(
-    val id: StateStoreId,
-    keySchema: StructType,
-    valueSchema: StructType,
-    storeConf: StateStoreConf,
-    hadoopConf: Configuration
-  ) extends StateStoreProvider with Logging {
+private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
 
   // ConcurrentHashMap is used because it generates fail-safe iterators on filtering
   // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in
@@ -95,92 +89,36 @@ private[state] class HDFSBackedStateStoreProvider(
     private val newVersion = version + 1
     private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}")
     private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true))
-    private val allUpdates = new java.util.HashMap[UnsafeRow, StoreUpdate]()
-
     @volatile private var state: STATE = UPDATING
     @volatile private var finalDeltaFile: Path = null
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
 
-    override def get(key: UnsafeRow): Option[UnsafeRow] = {
-      Option(mapToUpdate.get(key))
-    }
-
-    override def filter(
-        condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
-      mapToUpdate
-        .entrySet
-        .asScala
-        .iterator
-        .filter { entry => condition(entry.getKey, entry.getValue) }
-        .map { entry => (entry.getKey, entry.getValue) }
+    override def get(key: UnsafeRow): UnsafeRow = {
+      mapToUpdate.get(key)
     }
 
     override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot put after already committed or aborted")
-
-      val isNewKey = !mapToUpdate.containsKey(key)
-      mapToUpdate.put(key, value)
-
-      Option(allUpdates.get(key)) match {
-        case Some(ValueAdded(_, _)) =>
-          // Value did not exist in previous version and was added already, keep it marked as added
-          allUpdates.put(key, ValueAdded(key, value))
-        case Some(ValueUpdated(_, _)) | Some(ValueRemoved(_, _)) =>
-          // Value existed in previous version and updated/removed, mark it as updated
-          allUpdates.put(key, ValueUpdated(key, value))
-        case None =>
-          // There was no prior update, so mark this as added or updated according to its presence
-          // in previous version.
-          val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value)
-          allUpdates.put(key, update)
-      }
-      writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value))
+      val keyCopy = key.copy()
+      val valueCopy = value.copy()
+      mapToUpdate.put(keyCopy, valueCopy)
+      writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy)
     }
 
-    /** Remove keys that match the following condition */
-    override def remove(condition: UnsafeRow => Boolean): Unit = {
+    override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
-      val entryIter = mapToUpdate.entrySet().iterator()
-      while (entryIter.hasNext) {
-        val entry = entryIter.next
-        if (condition(entry.getKey)) {
-          val value = entry.getValue
-          val key = entry.getKey
-          entryIter.remove()
-
-          Option(allUpdates.get(key)) match {
-            case Some(ValueUpdated(_, _)) | None =>
-              // Value existed in previous version and maybe was updated, mark removed
-              allUpdates.put(key, ValueRemoved(key, value))
-            case Some(ValueAdded(_, _)) =>
-              // Value did not exist in previous version and was added, should not appear in updates
-              allUpdates.remove(key)
-            case Some(ValueRemoved(_, _)) =>
-              // Remove already in update map, no need to change
-          }
-          writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-        }
+      val prevValue = mapToUpdate.remove(key)
+      if (prevValue != null) {
+        writeRemoveToDeltaFile(tempDeltaFileStream, key)
       }
     }
 
-    /** Remove a single key. */
-    override def remove(key: UnsafeRow): Unit = {
-      verify(state == UPDATING, "Cannot remove after already committed or aborted")
-      if (mapToUpdate.containsKey(key)) {
-        val value = mapToUpdate.remove(key)
-        Option(allUpdates.get(key)) match {
-          case Some(ValueUpdated(_, _)) | None =>
-            // Value existed in previous version and maybe was updated, mark removed
-            allUpdates.put(key, ValueRemoved(key, value))
-          case Some(ValueAdded(_, _)) =>
-            // Value did not exist in previous version and was added, should not appear in updates
-            allUpdates.remove(key)
-          case Some(ValueRemoved(_, _)) =>
-          // Remove already in update map, no need to change
-        }
-        writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value))
-      }
+    override def getRange(
+        start: Option[UnsafeRow],
+        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+      verify(state == UPDATING, "Cannot getRange after already committed or aborted")
+      iterator()
     }
 
     /** Commit all the updates that have been made to the store, and return the new version. */
@@ -227,20 +165,11 @@ private[state] class HDFSBackedStateStoreProvider(
      * Get an iterator of all the store data.
      * This can be called only after committing all the updates made in the current thread.
      */
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      verify(state == COMMITTED,
-        "Cannot get iterator of store data before committing or after aborting")
-      HDFSBackedStateStoreProvider.this.iterator(newVersion)
-    }
-
-    /**
-     * Get an iterator of all the updates made to the store in the current version.
-     * This can be called only after committing all the updates made in the current thread.
-     */
-    override def updates(): Iterator[StoreUpdate] = {
-      verify(state == COMMITTED,
-        "Cannot get iterator of updates before committing or after aborting")
-      allUpdates.values().asScala.toIterator
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      val unsafeRowPair = new UnsafeRowPair()
+      mapToUpdate.entrySet.asScala.iterator.map { entry =>
+        unsafeRowPair.withRows(entry.getKey, entry.getValue)
+      }
     }
 
     override def numKeys(): Long = mapToUpdate.size()
@@ -269,6 +198,23 @@ private[state] class HDFSBackedStateStoreProvider(
     store
   }
 
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int], // for sorting the data
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    this.stateStoreId = stateStoreId
+    this.keySchema = keySchema
+    this.valueSchema = valueSchema
+    this.storeConf = storeConf
+    this.hadoopConf = hadoopConf
+    fs.mkdirs(baseDir)
+  }
+
+  override def id: StateStoreId = stateStoreId
+
   /** Do maintenance backing data files, including creating snapshots and cleaning up old files */
   override def doMaintenance(): Unit = {
     try {
@@ -280,19 +226,27 @@ private[state] class HDFSBackedStateStoreProvider(
     }
   }
 
+  override def close(): Unit = {
+    loadedMaps.values.foreach(_.clear())
+  }
+
   override def toString(): String = {
     s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]"
   }
 
-  /* Internal classes and methods */
+  /* Internal fields and methods */
 
-  private val loadedMaps = new mutable.HashMap[Long, MapType]
-  private val baseDir =
-    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
-  private val fs = baseDir.getFileSystem(hadoopConf)
-  private val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+  @volatile private var stateStoreId: StateStoreId = _
+  @volatile private var keySchema: StructType = _
+  @volatile private var valueSchema: StructType = _
+  @volatile private var storeConf: StateStoreConf = _
+  @volatile private var hadoopConf: Configuration = _
 
-  initialize()
+  private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
+  private lazy val baseDir =
+    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
+  private lazy val fs = baseDir.getFileSystem(hadoopConf)
+  private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
 
   private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
 
@@ -323,35 +277,18 @@ private[state] class HDFSBackedStateStoreProvider(
    * Get iterator of all the data of the latest version of the store.
    * Note that this will look up the files to determined the latest known version.
    */
-  private[state] def latestIterator(): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
+  private[state] def latestIterator(): Iterator[UnsafeRowPair] = synchronized {
     val versionsInFiles = fetchFiles().map(_.version).toSet
     val versionsLoaded = loadedMaps.keySet
     val allKnownVersions = versionsInFiles ++ versionsLoaded
+    val unsafeRowTuple = new UnsafeRowPair()
     if (allKnownVersions.nonEmpty) {
-      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { x =>
-        (x.getKey, x.getValue)
+      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry =>
+        unsafeRowTuple.withRows(entry.getKey, entry.getValue)
       }
     } else Iterator.empty
   }
 
-  /** Get iterator of a specific version of the store */
-  private[state] def iterator(version: Long): Iterator[(UnsafeRow, UnsafeRow)] = synchronized {
-    loadMap(version).entrySet().iterator().asScala.map { x =>
-      (x.getKey, x.getValue)
-    }
-  }
-
-  /** Initialize the store provider */
-  private def initialize(): Unit = {
-    try {
-      fs.mkdirs(baseDir)
-    } catch {
-      case e: IOException =>
-        throw new IllegalStateException(
-          s"Cannot use ${id.checkpointLocation} for storing state data for $this: $e ", e)
-    }
-  }
-
   /** Load the required version of the map data from the backing files */
   private def loadMap(version: Long): MapType = {
     if (version <= 0) return new MapType
@@ -367,32 +304,23 @@ private[state] class HDFSBackedStateStoreProvider(
     }
   }
 
-  private def writeToDeltaFile(output: DataOutputStream, update: StoreUpdate): Unit = {
-
-    def writeUpdate(key: UnsafeRow, value: UnsafeRow): Unit = {
-      val keyBytes = key.getBytes()
-      val valueBytes = value.getBytes()
-      output.writeInt(keyBytes.size)
-      output.write(keyBytes)
-      output.writeInt(valueBytes.size)
-      output.write(valueBytes)
-    }
-
-    def writeRemove(key: UnsafeRow): Unit = {
-      val keyBytes = key.getBytes()
-      output.writeInt(keyBytes.size)
-      output.write(keyBytes)
-      output.writeInt(-1)
-    }
+  private def writeUpdateToDeltaFile(
+      output: DataOutputStream,
+      key: UnsafeRow,
+      value: UnsafeRow): Unit = {
+    val keyBytes = key.getBytes()
+    val valueBytes = value.getBytes()
+    output.writeInt(keyBytes.size)
+    output.write(keyBytes)
+    output.writeInt(valueBytes.size)
+    output.write(valueBytes)
+  }
 
-    update match {
-      case ValueAdded(key, value) =>
-        writeUpdate(key, value)
-      case ValueUpdated(key, value) =>
-        writeUpdate(key, value)
-      case ValueRemoved(key, value) =>
-        writeRemove(key)
-    }
+  private def writeRemoveToDeltaFile(output: DataOutputStream, key: UnsafeRow): Unit = {
+    val keyBytes = key.getBytes()
+    output.writeInt(keyBytes.size)
+    output.write(keyBytes)
+    output.writeInt(-1)
   }
 
   private def finalizeDeltaFile(output: DataOutputStream): Unit = {
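
For reference, a sketch of reading back the records produced by writeUpdateToDeltaFile and writeRemoveToDeltaFile above. It assumes the on-disk layout shown in this hunk (key length, key bytes, value length, value bytes; a negative value length marks a removal), that finalizeDeltaFile terminates the stream with a negative key length (its body is not shown in this diff), and that the DataInputStream already wraps the decompressed file contents.

import java.io.DataInputStream

// Reads one record; returns None at the assumed end-of-file marker. The value is None for removals.
def readDeltaRecord(input: DataInputStream): Option[(Array[Byte], Option[Array[Byte]])] = {
  val keySize = input.readInt()
  if (keySize < 0) {
    None                                       // assumed end-of-file marker
  } else {
    val keyBytes = new Array[Byte](keySize)
    input.readFully(keyBytes)
    val valueSize = input.readInt()
    if (valueSize < 0) {
      Some((keyBytes, None))                   // removal record
    } else {
      val valueBytes = new Array[Byte](valueSize)
      input.readFully(valueBytes)
      Some((keyBytes, Some(valueBytes)))       // update record
    }
  }
}
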
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index eaa558eb6d0ed4263b9d315c9ed1504ba52f77b2..29c456f86e1edcfd65236092bd4b3b592910d26f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -29,15 +29,12 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ThreadUtils
-
-
-/** Unique identifier for a [[StateStore]] */
-case class StateStoreId(checkpointLocation: String, operatorId: Long, partitionId: Int)
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 
 /**
- * Base trait for a versioned key-value store used for streaming aggregations
+ * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific
+ * version of state data, and such instances are created through a [[StateStoreProvider]].
  */
 trait StateStore {
 
@@ -47,50 +44,54 @@ trait StateStore {
   /** Version of the data in this store before committing updates. */
   def version: Long
 
-  /** Get the current value of a key. */
-  def get(key: UnsafeRow): Option[UnsafeRow]
-
   /**
-   * Return an iterator of key-value pairs that satisfy a certain condition.
-   * Note that the iterator must be fail-safe towards modification to the store, that is,
-   * it must be based on the snapshot of store the time of this call, and any change made to the
-   * store while iterating through iterator should not cause the iterator to fail or have
-   * any affect on the values in the iterator.
+   * Get the current value of a non-null key.
+   * @return a non-null row if the key exists in the store, otherwise null.
    */
-  def filter(condition: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)]
+  def get(key: UnsafeRow): UnsafeRow
 
-  /** Put a new value for a key. */
+  /**
+   * Put a new value for a non-null key. Implementations must be aware that the UnsafeRows in
+   * the params can be reused, and must make copies of the data as needed for persistence.
+   */
   def put(key: UnsafeRow, value: UnsafeRow): Unit
 
   /**
-   * Remove keys that match the following condition.
+   * Remove a single non-null key.
    */
-  def remove(condition: UnsafeRow => Boolean): Unit
+  def remove(key: UnsafeRow): Unit
 
   /**
-   * Remove a single key.
+   * Get key-value pairs with optional approximate `start` and `end` extents.
+   * If the State Store implementation maintains indices for the data based on the optional
+   * `keyIndexOrdinal` over fields of `keySchema` (see `StateStoreProvider.init()`), then it can
+   * use `start` and `end` to make a best-effort scan over the data. The default implementation
+   * returns a full scan of the data, which is correct but inefficient. Custom implementations
+   * must ensure that updates (puts, removes) can be made while iterating over this iterator.
+   *
+   * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value.
+   * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value.
+   * @return An iterator of key-value pairs that is guaranteed not to miss any key between `start`
+   *         and `end`, both inclusive.
    */
-  def remove(key: UnsafeRow): Unit
+  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+    iterator()
+  }
 
   /**
    * Commit all the updates that have been made to the store, and return the new version.
+   * Implementations should ensure that no more updates (puts, removes) can be made after a
+   * commit in order to avoid incorrect usage.
    */
   def commit(): Long
 
-  /** Abort all the updates that have been made to the store. */
-  def abort(): Unit
-
   /**
-   * Iterator of store data after a set of updates have been committed.
-   * This can be called only after committing all the updates made in the current thread.
+   * Abort all the updates that have been made to the store. Implementations should ensure that
+   * no more updates (puts, removes) can be made after an abort in order to avoid incorrect usage.
    */
-  def iterator(): Iterator[(UnsafeRow, UnsafeRow)]
+  def abort(): Unit
 
-  /**
-   * Iterator of the updates that have been committed.
-   * This can be called only after committing all the updates made in the current thread.
-   */
-  def updates(): Iterator[StoreUpdate]
+  def iterator(): Iterator[UnsafeRowPair]
 
   /** Number of keys in the state store */
   def numKeys(): Long
@@ -102,28 +103,98 @@ trait StateStore {
 }
 
 
-/** Trait representing a provider of a specific version of a [[StateStore]]. */
+/**
+ * Trait representing a provider of [[StateStore]] instances, each of which represents a
+ * specific version of the state data.
+ *
+ * The life cycle of a provider and its provided stores is as follows.
+ *
+ * - A StateStoreProvider is created in an executor for each unique [[StateStoreId]] when
+ *   the first batch of a streaming query is executed on the executor. All subsequent batches reuse
+ *   this provider instance until the query is stopped.
+ *
+ * - Every batch of streaming data requests a specific version of the state data by invoking
+ *   `getStore(version)`, which returns an instance of [[StateStore]] through which the required
+ *   version of the data can be accessed. It is the responsibility of the provider to populate
+ *   this store with context information like the schema of keys and values, etc.
+ *
+ * - After the streaming query is stopped, the created provider instances are lazily disposed of.
+ */
 trait StateStoreProvider {
 
-  /** Get the store with the existing version. */
+  /**
+   * Initialize the provider with contextual information from the SQL operator.
+   * This method will be called first, right after an instance of the StateStoreProvider is
+   * created by reflection.
+   *
+   * @param stateStoreId Id of the versioned StateStores that this provider will generate
+   * @param keySchema Schema of keys to be stored
+   * @param valueSchema Schema of values to be stored
+   * @param keyIndexOrdinal Optional column (represented as the ordinal of the field in keySchema)
+   *                        by which the StateStore implementation could index the data.
+   * @param storeConfs Configurations used by the StateStores
+   * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
+   */
+  def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keyIndexOrdinal: Option[Int], // for sorting the data by their keys
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit
+
+  /**
+   * Return the id of the StateStores this provider will generate.
+   * Should be the same as the one passed in init().
+   */
+  def id: StateStoreId
+
+  /** Called when the provider instance is unloaded from the executor */
+  def close(): Unit
+
+  /** Return an instance of [[StateStore]] representing state data of the given version */
   def getStore(version: Long): StateStore
 
-  /** Optional method for providers to allow for background maintenance */
+  /** Optional method for providers to allow for background maintenance (e.g. compactions) */
   def doMaintenance(): Unit = { }
 }
 
-
-/** Trait representing updates made to a [[StateStore]]. */
-sealed trait StoreUpdate {
-  def key: UnsafeRow
-  def value: UnsafeRow
+object StateStoreProvider {
+  /**
+   * Return a provider instance of the given provider class.
+   * The returned instance will already be initialized.
+   */
+  def instantiate(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int], // for sorting the data
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): StateStoreProvider = {
+    val providerClass = storeConf.providerClass.map(Utils.classForName)
+        .getOrElse(classOf[HDFSBackedStateStoreProvider])
+    val provider = providerClass.newInstance().asInstanceOf[StateStoreProvider]
+    provider.init(stateStoreId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+    provider
+  }
 }
 
-case class ValueAdded(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
 
-case class ValueUpdated(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+/** Unique identifier for a set of keyed state data. */
+case class StateStoreId(
+    checkpointLocation: String,
+    operatorId: Long,
+    partitionId: Int,
+    name: String = "")
 
-case class ValueRemoved(key: UnsafeRow, value: UnsafeRow) extends StoreUpdate
+/** Mutable and reusable class for representing a pair of UnsafeRows. */
+class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
+  def withRows(key: UnsafeRow, value: UnsafeRow): UnsafeRowPair = {
+    this.key = key
+    this.value = value
+    this
+  }
+}
 
 
 /**
@@ -185,6 +256,7 @@ object StateStore extends Logging {
       storeId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
+      indexOrdinal: Option[Int],
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
@@ -193,7 +265,9 @@ object StateStore extends Logging {
       startMaintenanceIfNeeded()
       val provider = loadedProviders.getOrElseUpdate(
         storeId,
-        new HDFSBackedStateStoreProvider(storeId, keySchema, valueSchema, storeConf, hadoopConf))
+        StateStoreProvider.instantiate(
+          storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+      )
       reportActiveStoreInstance(storeId)
       provider
     }
@@ -202,7 +276,7 @@ object StateStore extends Logging {
 
   /** Unload a state store provider */
   def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
-    loadedProviders.remove(storeId)
+    loadedProviders.remove(storeId).foreach(_.close())
   }
 
   /** Whether a state store provider is loaded or not */
@@ -216,6 +290,7 @@ object StateStore extends Logging {
 
   /** Unload and stop all state store providers */
   def stop(): Unit = loadedProviders.synchronized {
+    loadedProviders.keySet.foreach { key => unload(key) }
     loadedProviders.clear()
     _coordRef = null
     if (maintenanceTask != null) {
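
To illustrate the new pluggable interface, here is a minimal in-memory sketch of a custom provider (illustration only, not part of this patch). It keeps every committed version in executor memory, so it is neither fault tolerant nor memory bounded; the class name is hypothetical, and trait members not shown in this diff (e.g. hasCommitted, which the tests reference) are assumed from context.

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.StructType

class InMemoryStateStoreProvider extends StateStoreProvider {

  private var providerId: StateStoreId = _
  private val committedVersions = new mutable.HashMap[Long, Map[UnsafeRow, UnsafeRow]]

  override def init(
      stateStoreId: StateStoreId,
      keySchema: StructType,
      valueSchema: StructType,
      keyIndexOrdinal: Option[Int],
      storeConfs: StateStoreConf,
      hadoopConf: Configuration): Unit = {
    providerId = stateStoreId
  }

  override def id: StateStoreId = providerId

  override def close(): Unit = committedVersions.clear()

  override def getStore(loadVersion: Long): StateStore = new StateStore {
    private val map = mutable.HashMap[UnsafeRow, UnsafeRow]() ++
      committedVersions.getOrElse(loadVersion, Map.empty)
    private var committed = false

    override def id: StateStoreId = providerId
    override def version: Long = loadVersion
    override def get(key: UnsafeRow): UnsafeRow = map.get(key).orNull
    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
      map.put(key.copy(), value.copy())     // incoming rows may be reused by the caller, so copy
    }
    override def remove(key: UnsafeRow): Unit = map.remove(key)
    override def commit(): Long = {
      committedVersions.put(loadVersion + 1, map.toMap)
      committed = true
      loadVersion + 1
    }
    override def abort(): Unit = map.clear()
    override def iterator(): Iterator[UnsafeRowPair] = {
      val pair = new UnsafeRowPair()        // reused object, so callers must copy rows they retain
      map.iterator.map { case (k, v) => pair.withRows(k, v) }
    }
    override def numKeys(): Long = map.size.toLong
    override def hasCommitted: Boolean = committed   // assumed trait member, used by the tests
  }
}
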
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index acfaa8e5eb3c4bcf2bf61e3a771c5efe082c874f..bab297c7df594e26479e4b1965af7ba87d94eaf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -20,16 +20,34 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.internal.SQLConf
 
 /** A class that contains configuration parameters for [[StateStore]]s. */
-private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable {
+class StateStoreConf(@transient private val sqlConf: SQLConf)
+  extends Serializable {
 
   def this() = this(new SQLConf)
 
-  val minDeltasForSnapshot = conf.stateStoreMinDeltasForSnapshot
-
-  val minVersionsToRetain = conf.minBatchesToRetain
+  /**
+   * Minimum number of delta files in a chain after which HDFSBackedStateStore will
+   * consider generating a snapshot.
+   */
+  val minDeltasForSnapshot: Int = sqlConf.stateStoreMinDeltasForSnapshot
+
+  /** Minimum number of versions a State Store implementation should retain to allow rollbacks */
+  val minVersionsToRetain: Int = sqlConf.minBatchesToRetain
+
+  /**
+   * Optional fully qualified name of the subclass of [[StateStoreProvider]]
+   * managing state data. That is, the implementation of the State Store to use.
+   */
+  val providerClass: Option[String] = sqlConf.stateStoreProviderClass
+
+  /**
+   * Additional configurations related to state store. This will capture all configs in
+   * SQLConf that start with `spark.sql.streaming.stateStore.` */
+  val confs: Map[String, String] =
+    sqlConf.getAllConfs.filter(_._1.startsWith("spark.sql.streaming.stateStore."))
 }
 
-private[streaming] object StateStoreConf {
+object StateStoreConf {
   val empty = new StateStoreConf()
 
   def apply(conf: SQLConf): StateStoreConf = new StateStoreConf(conf)
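
A hedged sketch of how the new `confs` map reaches a custom provider: any session config under the `spark.sql.streaming.stateStore.` prefix is carried along in StateStoreConf and can be read back inside `init()`. The config key below is made up for illustration, and `spark` is assumed to be an active SparkSession.

// Driver side: set a provider-specific config under the captured prefix.
spark.conf.set("spark.sql.streaming.stateStore.example.cacheSizeMB", "64")

// Executor side, inside a custom provider's init(storeConfs = ...):
//   val cacheSizeMB =
//     storeConfs.confs.getOrElse("spark.sql.streaming.stateStore.example.cacheSizeMB", "16").toInt
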
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index e16dda8a5b5640f9da9ae2dddf3298eb5ca3e47d..b744c25dc97a81e4ac741447cffb41c5ebd5e54a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -35,9 +35,11 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
     checkpointLocation: String,
     operatorId: Long,
+    storeName: String,
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
+    indexOrdinal: Option[Int],
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef])
   extends RDD[U](dataRDD) {
@@ -45,21 +47,22 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
   private val storeConf = new StateStoreConf(sessionState.conf)
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it
-  private val confBroadcast = dataRDD.context.broadcast(
+  private val hadoopConfBroadcast = dataRDD.context.broadcast(
     new SerializableConfiguration(sessionState.newHadoopConf()))
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
   override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName)
     storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
   }
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     var store: StateStore = null
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index)
+    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName)
     store = StateStore.get(
-      storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value)
+      storeId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 589042afb1e52ce646cd136675154e93c8d4dcd5..228fe86d59940416509f41e342a08fb36aeb0220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -34,17 +34,21 @@ package object state {
         sqlContext: SQLContext,
         checkpointLocation: String,
         operatorId: Long,
+        storeName: String,
         storeVersion: Long,
         keySchema: StructType,
-        valueSchema: StructType)(
+        valueSchema: StructType,
+        indexOrdinal: Option[Int])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
         checkpointLocation,
         operatorId,
+        storeName,
         storeVersion,
         keySchema,
         valueSchema,
+        indexOrdinal,
         sqlContext.sessionState,
         Some(sqlContext.streams.stateStoreCoordinator))(
         storeUpdateFunction)
@@ -54,9 +58,11 @@ package object state {
     private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
         checkpointLocation: String,
         operatorId: Long,
+        storeName: String,
         storeVersion: Long,
         keySchema: StructType,
         valueSchema: StructType,
+        indexOrdinal: Option[Int],
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
@@ -69,14 +75,17 @@ package object state {
         })
         cleanedF(store, iter)
       }
+
       new StateStoreRDD(
         dataRDD,
         wrappedF,
         checkpointLocation,
         operatorId,
+        storeName,
         storeVersion,
         keySchema,
         valueSchema,
+        indexOrdinal,
         sessionState,
         storeCoordinator)
     }
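
A condensed sketch of a call site for the extended signature (illustration only): it assumes code compiled inside the org.apache.spark.sql.execution.streaming package (the implicit class is private[streaming]), and takes the RDD, SQLContext, and schemas as parameters rather than inventing them. The checkpoint path and helper name are placeholders.

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.types.StructType

def buildStateStoreRDD(
    dataRDD: RDD[String],
    sqlContext: SQLContext,
    keySchema: StructType,
    valueSchema: StructType): RDD[(String, Int)] = {
  dataRDD.mapPartitionsWithStateStore[(String, Int)](
    sqlContext,
    "/tmp/checkpoint",    // checkpointLocation: placeholder path
    0L,                   // operatorId
    "default",            // storeName
    0L,                   // storeVersion
    keySchema,
    valueSchema,
    None                  // indexOrdinal: no index requested
  ) { (store, iter) =>
    // Update `store` from `iter` here, commit, and emit whatever rows should flow downstream.
    store.commit()
    Iterator.empty
  }
}
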
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 8dbda298c87bcae220dd33b830614beaa6e5c6e3..3e57f3fbada32ecd1cd93efe323a043aef38573b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -17,21 +17,22 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.concurrent.TimeUnit._
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalGroupState, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.execution.streaming.state._
-import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types._
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.{CompletionIterator, NextIterator}
 
 
 /** Used to identify the state store for a given operator. */
@@ -61,11 +62,24 @@ trait StateStoreReader extends StatefulOperator {
 }
 
 /** An operator that writes to a StateStore. */
-trait StateStoreWriter extends StatefulOperator {
+trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
+
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"),
-    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"))
+    "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"),
+    "allUpdatesTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to update rows"),
+    "allRemovalsTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "total time to remove rows"),
+    "commitTimeMs" -> SQLMetrics.createTimingMetric(sparkContext, "time to commit changes")
+  )
+
+  /** Records the duration of running `body` for the next query progress update. */
+  protected def timeTakenMs(body: => Unit): Long = {
+    val startTime = System.nanoTime()
+    body
+    val endTime = System.nanoTime()
+    math.max(NANOSECONDS.toMillis(endTime - startTime), 0)
+  }
 }
 
 /** An operator that supports watermark. */
@@ -108,6 +122,16 @@ trait WatermarkSupport extends UnaryExecNode {
   /** Predicate based on the child output that matches data older than the watermark. */
   lazy val watermarkPredicateForData: Option[Predicate] =
     watermarkExpression.map(newPredicate(_, child.output))
+
+  protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
+    if (watermarkPredicateForKeys.nonEmpty) {
+      store.getRange(None, None).foreach { rowPair =>
+        if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+          store.remove(rowPair.key)
+        }
+      }
+    }
+  }
 }
 
 /**
@@ -126,9 +150,11 @@ case class StateStoreRestoreExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       operatorId = getStateId.operatorId,
+      storeName = "default",
       storeVersion = getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
@@ -136,7 +162,7 @@ case class StateStoreRestoreExec(
           val key = getKey(row)
           val savedState = store.get(key)
           numOutputRows += 1
-          row +: savedState.toSeq
+          row +: Option(savedState).toSeq
         }
     }
   }
@@ -165,54 +191,88 @@ case class StateStoreSaveExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
         val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
         val numOutputRows = longMetric("numOutputRows")
         val numTotalStateRows = longMetric("numTotalStateRows")
         val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+        val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+        val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+        val commitTimeMs = longMetric("commitTimeMs")
 
         outputMode match {
           // Update and output all rows in the StateStore.
           case Some(Complete) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
+            allUpdatesTimeMs += timeTakenMs {
+              while (iter.hasNext) {
+                val row = iter.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numUpdatedStateRows += 1
+              }
+            }
+            allRemovalsTimeMs += 0
+            commitTimeMs += timeTakenMs {
+              store.commit()
             }
-            store.commit()
             numTotalStateRows += store.numKeys()
-            store.iterator().map { case (k, v) =>
+            store.iterator().map { rowPair =>
               numOutputRows += 1
-              v.asInstanceOf[InternalRow]
+              rowPair.value
             }
 
           // Update and output only rows being evicted from the StateStore
+          // Assumption: watermark predicates must be non-empty if append mode is allowed
           case Some(Append) =>
-            while (iter.hasNext) {
-              val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              store.put(key.copy(), row.copy())
-              numUpdatedStateRows += 1
+            allUpdatesTimeMs += timeTakenMs {
+              val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
+              while (filteredIter.hasNext) {
+                val row = filteredIter.next().asInstanceOf[UnsafeRow]
+                val key = getKey(row)
+                store.put(key, row)
+                numUpdatedStateRows += 1
+              }
             }
 
-            // Assumption: Append mode can be done only when watermark has been specified
-            store.remove(watermarkPredicateForKeys.get.eval _)
-            store.commit()
+            val removalStartTimeNs = System.nanoTime
+            val rangeIter = store.getRange(None, None)
+
+            new NextIterator[InternalRow] {
+              override protected def getNext(): InternalRow = {
+                var removedValueRow: InternalRow = null
+                while (rangeIter.hasNext && removedValueRow == null) {
+                  val rowPair = rangeIter.next()
+                  if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
+                    store.remove(rowPair.key)
+                    removedValueRow = rowPair.value
+                  }
+                }
+                if (removedValueRow == null) {
+                  finished = true
+                  null
+                } else {
+                  removedValueRow
+                }
+              }
 
-            numTotalStateRows += store.numKeys()
-            store.updates().filter(_.isInstanceOf[ValueRemoved]).map { removed =>
-              numOutputRows += 1
-              removed.value.asInstanceOf[InternalRow]
+              override protected def close(): Unit = {
+                allRemovalsTimeMs += NANOSECONDS.toMillis(System.nanoTime - removalStartTimeNs)
+                commitTimeMs += timeTakenMs { store.commit() }
+                numTotalStateRows += store.numKeys()
+              }
             }
 
           // Update and output modified rows from the StateStore.
           case Some(Update) =>
 
+            val updatesStartTimeNs = System.nanoTime
+
             new Iterator[InternalRow] {
 
               // Filter late date using watermark if specified
@@ -223,11 +283,11 @@ case class StateStoreSaveExec(
 
               override def hasNext: Boolean = {
                 if (!baseIterator.hasNext) {
+                  allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
                   // Remove old aggregates if watermark specified
-                  if (watermarkPredicateForKeys.nonEmpty) {
-                    store.remove(watermarkPredicateForKeys.get.eval _)
-                  }
-                  store.commit()
+                  allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+                  commitTimeMs += timeTakenMs { store.commit() }
                   numTotalStateRows += store.numKeys()
                   false
                 } else {
@@ -238,7 +298,7 @@ case class StateStoreSaveExec(
               override def next(): InternalRow = {
                 val row = baseIterator.next().asInstanceOf[UnsafeRow]
                 val key = getKey(row)
-                store.put(key.copy(), row.copy())
+                store.put(key, row)
                 numOutputRows += 1
                 numUpdatedStateRows += 1
                 row
@@ -273,27 +333,34 @@ case class StreamingDeduplicateExec(
     child.execute().mapPartitionsWithStateStore(
       getStateId.checkpointLocation,
       getStateId.operatorId,
+      storeName = "default",
       getStateId.batchId,
       keyExpressions.toStructType,
       child.output.toStructType,
+      indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
       val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
       val numOutputRows = longMetric("numOutputRows")
       val numTotalStateRows = longMetric("numTotalStateRows")
       val numUpdatedStateRows = longMetric("numUpdatedStateRows")
+      val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
+      val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")
+      val commitTimeMs = longMetric("commitTimeMs")
 
       val baseIterator = watermarkPredicateForData match {
         case Some(predicate) => iter.filter(row => !predicate.eval(row))
         case None => iter
       }
 
+      val updatesStartTimeNs = System.nanoTime
+
       val result = baseIterator.filter { r =>
         val row = r.asInstanceOf[UnsafeRow]
         val key = getKey(row)
         val value = store.get(key)
-        if (value.isEmpty) {
-          store.put(key.copy(), StreamingDeduplicateExec.EMPTY_ROW)
+        if (value == null) {
+          store.put(key, StreamingDeduplicateExec.EMPTY_ROW)
           numUpdatedStateRows += 1
           numOutputRows += 1
           true
@@ -304,8 +371,9 @@ case class StreamingDeduplicateExec(
       }
 
       CompletionIterator[InternalRow, Iterator[InternalRow]](result, {
-        watermarkPredicateForKeys.foreach(f => store.remove(f.eval _))
-        store.commit()
+        allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+        allRemovalsTimeMs += timeTakenMs { removeKeysOlderThanWatermark(store) }
+        commitTimeMs += timeTakenMs { store.commit() }
         numTotalStateRows += store.numKeys()
       })
     }
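
The Append branch above switches to Spark's internal NextIterator so that eviction, commit, and metric updates can be deferred until the output iterator is fully drained. A minimal sketch of that pattern (NextIterator is private[spark], so this assumes compilation inside a Spark package; the function name is illustrative):

import org.apache.spark.util.NextIterator

// Wraps an iterator so that `onDrained` runs exactly once after the last element is consumed,
// mirroring how the Append branch commits the store and records metrics in close().
def withCompletion[A](underlying: Iterator[A])(onDrained: => Unit): Iterator[A] =
  new NextIterator[A] {
    override protected def getNext(): A = {
      if (underlying.hasNext) {
        underlying.next()
      } else {
        finished = true
        null.asInstanceOf[A]    // return value is ignored once `finished` is set
      }
    }
    override protected def close(): Unit = onDrained
  }
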
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index bd197be655d5862337c8305b13e0e615dbd10d97..4a1a089af54c2178c233fb9f0f413bfd59c05be2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -38,13 +38,13 @@ import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
 
+  import StateStoreTestsHelper._
+
   private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
-  private var tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
+  private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
   private val keySchema = StructType(Seq(StructField("key", StringType, true)))
   private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
-  import StateStoreSuite._
-
   after {
     StateStore.stop()
   }
@@ -60,13 +60,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       val opId = 0
       val rdd1 =
         makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-            spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(
+            spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
             increment)
       assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+        spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(
+        increment)
       assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
       // Make sure the previous RDD still has the same data.
@@ -84,7 +85,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         storeVersion: Int): RDD[(String, Int)] = {
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment)
+        sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment)
     }
 
     // Generate RDDs and state store data
@@ -110,7 +111,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
         val resIterator = iter.map { s =>
           val key = stringToRow(s)
-          val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+          val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
           val newValue = oldValue + 1
           store.put(key, intToRow(newValue))
           (s, newValue)
@@ -125,21 +126,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
           iter: Iterator[String]): Iterator[(String, Option[Int])] = {
         iter.map { s =>
           val key = stringToRow(s)
-          val value = store.get(key).map(rowToInt)
+          val value = Option(store.get(key)).map(rowToInt)
           (s, value)
         }
       }
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets)
+        spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+        iteratorOfGets)
       assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts)
+        sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+        iteratorOfPuts)
       assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets)
+        sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(
+        iteratorOfGets)
       assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
     }
   }
@@ -152,15 +156,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
         implicit val sqlContext = spark.sqlContext
         val coordinatorRef = sqlContext.streams.stateStoreCoordinator
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1")
+        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2")
 
         assert(
-          coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+          coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
         val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+          increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -187,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment)
+          sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -208,7 +213,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
       val key = stringToRow(s)
-      val oldValue = store.get(key).map(rowToInt).getOrElse(0)
+      val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
       store.put(key, intToRow(oldValue + 1))
     }
     store.commit()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index cc09b2d5b77638afc3ed2b7ae647174d4b676671..af2b9f1c11fb6a6bd09f70cf0e6fe877fe2c3057 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -40,15 +40,15 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester {
+class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
+  with BeforeAndAfter with PrivateMethodTester {
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
 
   import StateStoreCoordinatorSuite._
-  import StateStoreSuite._
+  import StateStoreTestsHelper._
 
-  private val tempDir = Utils.createTempDir().toString
-  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+  val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
 
   before {
     StateStore.stop()
@@ -60,186 +60,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     require(!StateStore.isMaintenanceRunning)
   }
 
-  test("get, put, remove, commit, and all data iterator") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator().isEmpty)
-
-    val store = provider.getStore(0)
-    assert(!store.hasCommitted)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-
-    // Verify state after updating
-    put(store, "a", 1)
-    assert(store.numKeys() === 1)
-    intercept[IllegalStateException] {
-      store.iterator()
-    }
-    intercept[IllegalStateException] {
-      store.updates()
-    }
-    assert(provider.latestIterator().isEmpty)
-
-    // Make updates, commit and then verify state
-    put(store, "b", 2)
-    put(store, "aa", 3)
-    assert(store.numKeys() === 3)
-    remove(store, _.startsWith("a"))
-    assert(store.numKeys() === 1)
-    assert(store.commit() === 1)
-
-    assert(store.hasCommitted)
-    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
-    assert(rowsToSet(provider.latestIterator()) === Set("b" -> 2))
-    assert(fileExists(provider, version = 1, isSnapshot = false))
-
-    assert(getDataFromFiles(provider) === Set("b" -> 2))
-
-    // Trying to get newer versions should fail
-    intercept[Exception] {
-      provider.getStore(2)
-    }
-    intercept[Exception] {
-      getDataFromFiles(provider, 2)
-    }
-
-    // New updates to the reloaded store with new version, and does not change old version
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
-    val reloadedStore = reloadedProvider.getStore(1)
-    assert(reloadedStore.numKeys() === 1)
-    put(reloadedStore, "c", 4)
-    assert(reloadedStore.numKeys() === 2)
-    assert(reloadedStore.commit() === 2)
-    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4))
-    assert(getDataFromFiles(provider, version = 1) === Set("b" -> 2))
-    assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
-  }
-
-  test("filter and concurrent updates") {
-    val provider = newStoreProvider()
-
-    // Verify state before starting a new set of updates
-    assert(provider.latestIterator.isEmpty)
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    put(store, "b", 2)
-
-    // Updates should work while iterating of filtered entries
-    val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" }
-    filtered.foreach { case (keyRow, valueRow) =>
-      store.put(keyRow, intToRow(rowToInt(valueRow) + 1))
-    }
-    assert(get(store, "a") === Some(2))
-
-    // Removes should work while iterating of filtered entries
-    val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" }
-    filtered2.foreach { case (keyRow, _) =>
-      store.remove(keyRow)
-    }
-    assert(get(store, "b") === None)
-  }
-
-  test("updates iterator with all combos of updates and removes") {
-    val provider = newStoreProvider()
-    var currentVersion: Int = 0
-
-    def withStore(body: StateStore => Unit): Unit = {
-      val store = provider.getStore(currentVersion)
-      body(store)
-      currentVersion += 1
-    }
-
-    // New data should be seen in updates as value added, even if they had multiple updates
-    withStore { store =>
-      put(store, "a", 1)
-      put(store, "aa", 1)
-      put(store, "aa", 2)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2))
-    }
-
-    // Multiple updates to same key should be collapsed in the updates as a single value update
-    // Keys that have not been updated should not appear in the updates
-    withStore { store =>
-      put(store, "a", 4)
-      put(store, "a", 6)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 6)))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Keys added, updated and finally removed before commit should not appear in updates
-    withStore { store =>
-      put(store, "b", 4)     // Added, finally removed
-      put(store, "bb", 5)    // Added, updated, finally removed
-      put(store, "bb", 6)
-      remove(store, _.startsWith("b"))
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set.empty)
-      assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2))
-    }
-
-    // Removed data should be seen in updates as a key removed
-    // Removed, but re-added data should be seen in updates as a value update
-    withStore { store =>
-      remove(store, _.startsWith("a"))
-      put(store, "a", 10)
-      store.commit()
-      assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa")))
-      assert(rowsToSet(store.iterator()) === Set("a" -> 10))
-    }
-  }
-
-  test("cancel") {
-    val provider = newStoreProvider()
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    store.commit()
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    // cancelUpdates should not change the data in the files
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    store1.abort()
-    assert(getDataFromFiles(provider) === Set("a" -> 1))
-  }
-
-  test("getStore with unexpected versions") {
-    val provider = newStoreProvider()
-
-    intercept[IllegalArgumentException] {
-      provider.getStore(-1)
-    }
-
-    // Prepare some data in the store
-    val store = provider.getStore(0)
-    put(store, "a", 1)
-    assert(store.commit() === 1)
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
-
-    intercept[IllegalStateException] {
-      provider.getStore(2)
-    }
-
-    // Update store version with some data
-    val store1 = provider.getStore(1)
-    put(store1, "b", 1)
-    assert(store1.commit() === 2)
-    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
-    assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1))
-  }
-
   test("snapshotting") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     var currentVersion = 0
     def updateVersionTo(targetVersion: Int): Unit = {
@@ -253,9 +75,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     }
 
     updateVersionTo(2)
-    require(getDataFromFiles(provider) === Set("a" -> 2))
+    require(getData(provider) === Set("a" -> 2))
     provider.doMaintenance()               // should not generate snapshot files
-    assert(getDataFromFiles(provider) === Set("a" -> 2))
+    assert(getData(provider) === Set("a" -> 2))
 
     for (i <- 1 to currentVersion) {
       assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
@@ -264,22 +86,22 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // After version 6, snapshotting should generate one snapshot file
     updateVersionTo(6)
-    require(getDataFromFiles(provider) === Set("a" -> 6), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 6), "store not updated correctly")
     provider.doMaintenance()       // should generate snapshot files
 
     val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
     assert(snapshotVersion.nonEmpty, "snapshot file not generated")
     deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
     assert(
-      getDataFromFiles(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
       "snapshotting messed up the data of the snapshotted version")
     assert(
-      getDataFromFiles(provider) === Set("a" -> 6),
+      getData(provider) === Set("a" -> 6),
       "snapshotting messed up the data of the final version")
 
     // After version 20, snapshotting should generate newer snapshot files
     updateVersionTo(20)
-    require(getDataFromFiles(provider) === Set("a" -> 20), "store not updated correctly")
+    require(getData(provider) === Set("a" -> 20), "store not updated correctly")
     provider.doMaintenance()       // do snapshot
 
     val latestSnapshotVersion = (0 to 20).filter(version =>
@@ -288,11 +110,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
 
     deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getDataFromFiles(provider) === Set("a" -> 20), "snapshotting messed up the data")
+    assert(getData(provider) === Set("a" -> 20), "snapshotting messed up the data")
   }
 
   test("cleaning") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
@@ -307,8 +129,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
 
     // last couple of versions should be retrievable
-    assert(getDataFromFiles(provider, 20) === Set("a" -> 20))
-    assert(getDataFromFiles(provider, 19) === Set("a" -> 19))
+    assert(getData(provider, 20) === Set("a" -> 20))
+    assert(getData(provider, 19) === Set("a" -> 19))
   }
 
   test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
@@ -316,7 +138,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName)
     conf.set("fs.defaultFS", "fake:///")
 
-    val provider = newStoreProvider(hadoopConf = conf)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, hadoopConf = conf)
     provider.getStore(0).commit()
     provider.getStore(0).commit()
 
@@ -327,7 +149,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("corrupted file handling") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5)
+    val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
       put(store, "a", i)
@@ -338,62 +160,75 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
 
     // Corrupt snapshot file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion))
     corruptFile(provider, snapshotVersion, isSnapshot = true)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion)
+      getData(provider, snapshotVersion)
     }
 
     // Corrupt delta file and verify that it throws error
-    assert(getDataFromFiles(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
     corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
 
     // Delete delta file and verify that it throws error
     deleteFilesEarlierThanVersion(provider, snapshotVersion)
     intercept[Exception] {
-      getDataFromFiles(provider, snapshotVersion - 1)
+      getData(provider, snapshotVersion - 1)
     }
   }
 
   test("StateStore.get") {
     quietly {
-      val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+      val dir = newDir()
       val storeId = StateStoreId(dir, 0, 0)
       val storeConf = StateStoreConf.empty
       val hadoopConf = new Configuration()
 
-
       // Verify that trying to get incorrect versions throw errors
       intercept[IllegalArgumentException] {
-        StateStore.get(storeId, keySchema, valueSchema, -1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf)
       }
       assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
-        StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+        StateStore.get(
+          storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       }
 
-      // Increase version of the store
-      val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      // Increase version of the store and try to get again
+      val store0 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
       assert(store0.version === 0)
       put(store0, "a", 1)
       store0.commit()
 
-      assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1)
-      assert(StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf).version == 0)
+      val store1 = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+      assert(StateStore.isLoaded(storeId))
+      assert(store1.version === 1)
+      assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+
+      // Verify that you can also load an older version
+      val store0reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+      assert(store0reloaded.version === 0)
+      assert(rowsToSet(store0reloaded.iterator()) === Set.empty)
 
       // Verify that you can remove the store and still reload and use it
       StateStore.unload(storeId)
       assert(!StateStore.isLoaded(storeId))
 
-      val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      val store1reloaded = StateStore.get(
+        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
-      put(store1, "a", 2)
-      assert(store1.commit() === 2)
-      assert(rowsToSet(store1.iterator()) === Set("a" -> 2))
+      assert(store1reloaded.version === 1)
+      put(store1reloaded, "a", 2)
+      assert(store1reloaded.commit() === 2)
+      assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2))
     }
   }
 
@@ -407,21 +242,20 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
       // fails to talk to the StateStoreCoordinator and unloads all the StateStores
       .set("spark.rpc.numRetries", "1")
     val opId = 0
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, opId, 0)
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
     val storeConf = StateStoreConf(sqlConf)
     val hadoopConf = new Configuration()
-    val provider = new HDFSBackedStateStoreProvider(
-      storeId, keySchema, valueSchema, storeConf, hadoopConf)
+    val provider = newStoreProvider(storeId)
 
     var latestStoreVersion = 0
 
     def generateStoreVersions() {
       for (i <- 1 to 20) {
-        val store = StateStore.get(
-          storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+        val store = StateStore.get(storeId, keySchema, valueSchema, None,
+          latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", i)
         store.commit()
         latestStoreVersion += 1
@@ -469,7 +303,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeId))
 
           // If some other executor loads the store, then this instance should be unloaded
@@ -479,7 +314,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, latestStoreVersion, storeConf, hadoopConf)
+          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeId))
         }
       }
@@ -495,10 +331,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
   test("SPARK-18342: commit fails when rename fails") {
     import RenameReturnsFalseFileSystem._
-    val dir = scheme + "://" + Utils.createDirectory(tempDir, Random.nextString(5)).toURI.getPath
+    val dir = scheme + "://" + newDir()
     val conf = new Configuration()
     conf.set(s"fs.$scheme.impl", classOf[RenameReturnsFalseFileSystem].getName)
-    val provider = newStoreProvider(dir = dir, hadoopConf = conf)
+    val provider = newStoreProvider(
+      opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf)
     val store = provider.getStore(0)
     put(store, "a", 0)
     val e = intercept[IllegalStateException](store.commit())
@@ -506,7 +343,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
   }
 
   test("SPARK-18416: do not create temp delta file until the store is updated") {
-    val dir = Utils.createDirectory(tempDir, Random.nextString(5)).toString
+    val dir = newDir()
     val storeId = StateStoreId(dir, 0, 0)
     val storeConf = StateStoreConf.empty
     val hadoopConf = new Configuration()
@@ -533,7 +370,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf)
     }
 
     // Put should create a temp file
@@ -548,7 +386,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf)
     }
     remove(store1, _ == "a")
     assert(numTempFiles === 1)
@@ -561,31 +400,55 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
 
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
-      StateStore.get(storeId, keySchema, valueSchema, 2, storeConf, hadoopConf)
+      StateStore.get(
+        storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf)
     }
     store2.commit()
     assert(numTempFiles === 0)
     assert(numDeltaFiles === 3)
   }
 
-  def getDataFromFiles(
-      provider: HDFSBackedStateStoreProvider,
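+  // Hooks required by StateStoreSuiteBase, implemented for the HDFS-backed provider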
+  override def newStoreProvider(): HDFSBackedStateStoreProvider = {
+    newStoreProvider(opId = Random.nextInt(), partition = 0)
+  }
+
+  override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = {
+    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation)
+  }
+
+  override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
+    getData(storeProvider)
+  }
+
+  override def getData(
+    provider: HDFSBackedStateStoreProvider,
     version: Int = -1): Set[(String, Int)] = {
-    val reloadedProvider = new HDFSBackedStateStoreProvider(
-      provider.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration)
+    val reloadedProvider = newStoreProvider(provider.id)
     if (version < 0) {
       reloadedProvider.latestIterator().map(rowsToStringInt).toSet
     } else {
-      reloadedProvider.iterator(version).map(rowsToStringInt).toSet
+      reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet
     }
   }
 
-  def assertMap(
-      testMapOption: Option[MapType],
-      expectedMap: Map[String, Int]): Unit = {
-    assert(testMapOption.nonEmpty, "no map present")
-    val convertedMap = testMapOption.get.map(rowsToStringInt)
-    assert(convertedMap === expectedMap)
+  def newStoreProvider(
+      opId: Long,
+      partition: Int,
+      dir: String = newDir(),
+      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
+    val sqlConf = new SQLConf()
+    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
+    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
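+    // The provider is created with a zero-arg constructor and then configured through init()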
+    val provider = new HDFSBackedStateStoreProvider()
+    provider.init(
+      StateStoreId(dir, opId, partition),
+      keySchema,
+      valueSchema,
+      indexOrdinal = None,
+      new StateStoreConf(sqlConf),
+      hadoopConf)
+    provider
   }
 
   def fileExists(
@@ -622,56 +485,150 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
     filePath.delete()
     filePath.createNewFile()
   }
+}
 
-  def storeLoaded(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod[mutable.HashMap[StateStoreId, StateStore]]('loadedStores)
-    val loadedStores = StateStore invokePrivate method()
-    loadedStores.contains(storeId)
-  }
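+/** Common tests that any StateStoreProvider implementation should pass. */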
+abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
+  extends SparkFunSuite {
+  import StateStoreTestsHelper._
 
-  def unloadStore(storeId: StateStoreId): Boolean = {
-    val method = PrivateMethod('remove)
-    StateStore invokePrivate method(storeId)
-  }
+  test("get, put, remove, commit, and all data iterator") {
+    val provider = newStoreProvider()
 
-  def newStoreProvider(
-      opId: Long = Random.nextLong,
-      partition: Int = 0,
-      minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
-      dir: String = Utils.createDirectory(tempDir, Random.nextString(5)).toString,
-      hadoopConf: Configuration = new Configuration()
-    ): HDFSBackedStateStoreProvider = {
-    val sqlConf = new SQLConf()
-    sqlConf.setConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT, minDeltasForSnapshot)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    new HDFSBackedStateStoreProvider(
-      StateStoreId(dir, opId, partition),
-      keySchema,
-      valueSchema,
-      new StateStoreConf(sqlConf),
-      hadoopConf)
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+
+    val store = provider.getStore(0)
+    assert(!store.hasCommitted)
+    assert(get(store, "a") === None)
+    assert(store.iterator().isEmpty)
+    assert(store.numKeys() === 0)
+
+    // Verify state after updating
+    put(store, "a", 1)
+    assert(get(store, "a") === Some(1))
+    assert(store.numKeys() === 1)
+
+    assert(store.iterator().nonEmpty)
+    assert(getLatestData(provider).isEmpty)
+
+    // Make updates, commit and then verify state
+    put(store, "b", 2)
+    put(store, "aa", 3)
+    assert(store.numKeys() === 3)
+    remove(store, _.startsWith("a"))
+    assert(store.numKeys() === 1)
+    assert(store.commit() === 1)
+
+    assert(store.hasCommitted)
+    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
+    assert(getLatestData(provider) === Set("b" -> 2))
+
+    // Trying to get newer versions should fail
+    intercept[Exception] {
+      provider.getStore(2)
+    }
+    intercept[Exception] {
+      getData(provider, 2)
+    }
+
+    // New updates to the reloaded store create a new version and do not change the old version
+    val reloadedProvider = newStoreProvider(store.id)
+    val reloadedStore = reloadedProvider.getStore(1)
+    assert(reloadedStore.numKeys() === 1)
+    put(reloadedStore, "c", 4)
+    assert(reloadedStore.numKeys() === 2)
+    assert(reloadedStore.commit() === 2)
+    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
+    assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
+    assert(getData(provider, version = 1) === Set("b" -> 2))
   }
 
-  def remove(store: StateStore, condition: String => Boolean): Unit = {
-    store.remove(row => condition(rowToString(row)))
+  test("removing while iterating") {
+    val provider = newStoreProvider()
+
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    put(store, "b", 2)
+
+    // Updates should work while iterating over filtered entries
+    val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" }
+    filtered.foreach { tuple =>
+      store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1))
+    }
+    assert(get(store, "a") === Some(2))
+
+    // Removes should work while iterating over filtered entries
+    val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" }
+    filtered2.foreach { tuple => store.remove(tuple.key) }
+    assert(get(store, "b") === None)
   }
 
-  private def put(store: StateStore, key: String, value: Int): Unit = {
-    store.put(stringToRow(key), intToRow(value))
+  test("abort") {
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    store.commit()
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    // abort should not change the data in the files
+    val store1 = provider.getStore(1)
+    put(store1, "b", 1)
+    store1.abort()
   }
 
-  private def get(store: StateStore, key: String): Option[Int] = {
-    store.get(stringToRow(key)).map(rowToInt)
+  test("getStore with invalid versions") {
+    val provider = newStoreProvider()
+
+    def checkInvalidVersion(version: Int): Unit = {
+      intercept[Exception] {
+        provider.getStore(version)
+      }
+    }
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(1)
+
+    val store = provider.getStore(0)
+    put(store, "a", 1)
+    assert(store.commit() === 1)
+    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+
+    val store1_ = provider.getStore(1)
+    assert(rowsToSet(store1_.iterator()) === Set("a" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(2)
+
+    // Update store version with some data
+    val store1 = provider.getStore(1)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+    put(store1, "b", 1)
+    assert(store1.commit() === 2)
+    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+
+    checkInvalidVersion(-1)
+    checkInvalidVersion(3)
   }
-}
 
-private[state] object StateStoreSuite {
+  /** Return a new provider with a random id */
+  def newStoreProvider(): ProviderClass
+
+  /** Return a new provider with the given id */
+  def newStoreProvider(storeId: StateStoreId): ProviderClass
+
+  /** Get the latest data referred to by the given provider, without using the provider itself */
+  def getLatestData(storeProvider: ProviderClass): Set[(String, Int)]
+
+  /**
+   * Get a specific version of the data referred to by the given provider, without using
+   * the provider itself.
+   */
+  def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)]
+}
 
-  /** Trait and classes mirroring [[StoreUpdate]] for testing store updates iterator */
-  trait TestUpdate
-  case class Added(key: String, value: Int) extends TestUpdate
-  case class Updated(key: String, value: Int) extends TestUpdate
-  case class Removed(key: String) extends TestUpdate
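+/** Helpers for converting between (String, Int) pairs and UnsafeRows in state store tests. */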
+object StateStoreTestsHelper {
 
   val strProj = UnsafeProjection.create(Array[DataType](StringType))
   val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
@@ -692,26 +649,29 @@ private[state] object StateStoreSuite {
     row.getInt(0)
   }
 
-  def rowsToIntInt(row: (UnsafeRow, UnsafeRow)): (Int, Int) = {
-    (rowToInt(row._1), rowToInt(row._2))
+  def rowsToStringInt(row: UnsafeRowPair): (String, Int) = {
+    (rowToString(row.key), rowToInt(row.value))
   }
 
+  def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = {
+    iterator.map(rowsToStringInt).toSet
+  }
 
-  def rowsToStringInt(row: (UnsafeRow, UnsafeRow)): (String, Int) = {
-    (rowToString(row._1), rowToInt(row._2))
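+  // Scan all pairs via getRange(None, None) and remove the keys matching the condition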
+  def remove(store: StateStore, condition: String => Boolean): Unit = {
+    store.getRange(None, None).foreach { rowPair =>
+      if (condition(rowToString(rowPair.key))) store.remove(rowPair.key)
+    }
   }
 
-  def rowsToSet(iterator: Iterator[(UnsafeRow, UnsafeRow)]): Set[(String, Int)] = {
-    iterator.map(rowsToStringInt).toSet
+  def put(store: StateStore, key: String, value: Int): Unit = {
+    store.put(stringToRow(key), intToRow(value))
   }
 
-  def updatesToSet(iterator: Iterator[StoreUpdate]): Set[TestUpdate] = {
-    iterator.map {
-      case ValueAdded(key, value) => Added(rowToString(key), rowToInt(value))
-      case ValueUpdated(key, value) => Updated(rowToString(key), rowToInt(value))
-      case ValueRemoved(key, _) => Removed(rowToString(key))
-    }.toSet
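+  // The new StateStore.get returns null for a missing key, so wrap it in Option for tests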
+  def get(store: StateStore, key: String): Option[Int] = {
+    Option(store.get(stringToRow(key))).map(rowToInt)
   }
+
+  def newDir(): String = Utils.createTempDir().toString
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6bb9408ce99ed28f0a6e3b7df4ad119393743258..0d9ca81349be5f26e1636b910ad79df74e63d914 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, UnsafeRowPair}
 import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -508,22 +508,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
     expectedState = Some(5),                                  // state should change
     expectedTimeoutTimestamp = 5000)                          // timestamp should change
 
-  test("StateStoreUpdater - rows are cloned before writing to StateStore") {
-    // function for running count
-    val func = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
-      state.update(state.getOption.getOrElse(0) + values.size)
-      Iterator.empty
-    }
-    val store = newStateStore()
-    val plan = newFlatMapGroupsWithStateExec(func)
-    val updater = new plan.StateStoreUpdater(store)
-    val data = Seq(1, 1, 2)
-    val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow))
-    returnIter.size // consume the iterator to force store updates
-    val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet
-    assert(storeData === Set((1, 2), (2, 1)))
-  }
-
   test("flatMapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count if state is defined, otherwise does not return anything
@@ -1016,11 +1000,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
       callFunction()
       val updatedStateRow = store.get(key)
       assert(
-        updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
+        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState,
         "final state not as expected")
-      if (updatedStateRow.nonEmpty) {
+      if (updatedStateRow != null) {
         assert(
-          updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
+          updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp,
           "final timeout timestamp not as expected")
       }
     }
@@ -1080,25 +1064,19 @@ object FlatMapGroupsWithStateSuite {
     import scala.collection.JavaConverters._
     private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-    override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
-      map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) }
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      map.entrySet.iterator.asScala.map { case e => new UnsafeRowPair(e.getKey, e.getValue) }
     }
 
-    override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
-      iterator.filter { case (k, v) => c(k, v) }
+    override def get(key: UnsafeRow): UnsafeRow = map.get(key)
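+    // Copy the rows on put, since callers may reuse the same UnsafeRow instances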
+    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = {
+      map.put(key.copy(), newValue.copy())
     }
-
-    override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key))
-    override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue)
     override def remove(key: UnsafeRow): Unit = { map.remove(key) }
-    override def remove(condition: (UnsafeRow) => Boolean): Unit = {
-      iterator.map(_._1).filter(condition).foreach(map.remove)
-    }
     override def commit(): Long = version + 1
     override def abort(): Unit = { }
     override def id: StateStoreId = null
     override def version: Long = 0
-    override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException }
     override def numKeys(): Long = map.size
     override def hasCommitted: Boolean = true
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 1fc062974e185a9518d53ec896abd27a151ee34f..280f2dc27b4a7d7df4d38a914558ec3a9f88dd52 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
 import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
@@ -31,6 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -614,6 +616,30 @@ class StreamSuite extends StreamTest {
     assertDescContainsQueryNameAnd(batch = 2)
     query.stop()
   }
+
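+  // Verify that spark.sql.streaming.stateStore.providerClass picks the provider implementation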
+  testQuietly("specify custom state store provider") {
+    val queryName = "memStream"
+    val providerClassName = classOf[TestStateStoreProvider].getCanonicalName
+    withSQLConf("spark.sql.streaming.stateStore.providerClass" -> providerClassName) {
+      val input = MemoryStream[Int]
+      val query = input
+        .toDS()
+        .groupBy()
+        .count()
+        .writeStream
+        .outputMode("complete")
+        .format("memory")
+        .queryName(queryName)
+        .start()
+      input.addData(1, 2, 3)
+      val e = intercept[Exception] {
+        query.awaitTermination()
+      }
+
+      assert(e.getMessage.contains(providerClassName))
+      assert(e.getMessage.contains("instantiated"))
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
@@ -719,3 +745,22 @@ object ThrowingInterruptedIOException {
    */
   @volatile var createSourceLatch: CountDownLatch = null
 }
+
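+/** A provider that fails in init(), used to check that a custom provider class is picked up. */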
+class TestStateStoreProvider extends StateStoreProvider {
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConfs: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    throw new Exception("Successfully instantiated")
+  }
+
+  override def id: StateStoreId = null
+
+  override def close(): Unit = { }
+
+  override def getStore(version: Long): StateStore = null
+}