diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index aa789af6f812fb10b09ac940ddf860a16e6d7e42..12f8cffb6774ad1ecdd38d30702072fc9dcb8451 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -311,7 +311,7 @@ object AggUtils {
     val saved =
       StateStoreSaveExec(
         groupingAttributes,
-        stateId = None,
+        stateInfo = None,
         outputMode = None,
         eventTimeWatermark = None,
         partialMerged2)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 2d82fcf4da6e9222064683d28d0464d8a9da92e2..81bc93e7ebcf405894d77f2552ebc5cae36f0887 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.command
 
+import java.util.UUID
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -117,7 +119,8 @@ case class ExplainCommand(
         // This is used only by explaining `Dataset/DataFrame` created by `spark.readStream`, so the
         // output mode does not matter since there is no `Sink`.
         new IncrementalExecution(
-          sparkSession, logicalPlan, OutputMode.Append(), "<unknown>", 0, OffsetSeqMetadata(0, 0))
+          sparkSession, logicalPlan, OutputMode.Append(), "<unknown>",
+          UUID.randomUUID, 0, OffsetSeqMetadata(0, 0))
       } else {
         sparkSession.sessionState.executePlan(logicalPlan)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 2aad8701a4eca7a799a0507f058bcf2f510381b7..9dcac33b4107c43034c7321ce7b23eec80ceadc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -50,7 +50,7 @@ case class FlatMapGroupsWithStateExec(
     groupingAttributes: Seq[Attribute],
     dataAttributes: Seq[Attribute],
     outputObjAttr: Attribute,
-    stateId: Option[OperatorStateId],
+    stateInfo: Option[StatefulOperatorStateInfo],
     stateEncoder: ExpressionEncoder[Any],
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
@@ -107,10 +107,7 @@ case class FlatMapGroupsWithStateExec(
     }
 
     child.execute().mapPartitionsWithStateStore[InternalRow](
-      getStateId.checkpointLocation,
-      getStateId.operatorId,
-      storeName = "default",
-      getStateId.batchId,
+      getStateInfo,
       groupingAttributes.toStructType,
       stateAttributes.toStructType,
       indexOrdinal = None,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 622e049630db26d809b20070c043b4bffdfda919..ab89dc6b705d53be1c0f4091f95d4d88ed6d5c5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.internal.Logging
@@ -36,6 +37,7 @@ class IncrementalExecution(
     logicalPlan: LogicalPlan,
     val outputMode: OutputMode,
     val checkpointLocation: String,
+    val runId: UUID,
     val currentBatchId: Long,
     offsetSeqMetadata: OffsetSeqMetadata)
   extends QueryExecution(sparkSession, logicalPlan) with Logging {
@@ -69,7 +71,13 @@ class IncrementalExecution(
    * Records the current id for a given stateful operator in the query plan as the `state`
    * preparation walks the query plan.
    */
-  private val operatorId = new AtomicInteger(0)
+  private val statefulOperatorId = new AtomicInteger(0)
+
+  /** Get the state info of the next stateful operator */
+  private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = {
+    StatefulOperatorStateInfo(
+      checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId)
+  }
 
   /** Locates save/restore pairs surrounding aggregation. */
   val state = new Rule[SparkPlan] {
@@ -78,35 +86,28 @@ class IncrementalExecution(
       case StateStoreSaveExec(keys, None, None, None,
              UnaryExecNode(agg,
                StateStoreRestoreExec(keys2, None, child))) =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
-
+        val aggStateInfo = nextStatefulOperationStateInfo
         StateStoreSaveExec(
           keys,
-          Some(stateId),
+          Some(aggStateInfo),
           Some(outputMode),
           Some(offsetSeqMetadata.batchWatermarkMs),
           agg.withNewChildren(
             StateStoreRestoreExec(
               keys,
-              Some(stateId),
+              Some(aggStateInfo),
               child) :: Nil))
 
       case StreamingDeduplicateExec(keys, child, None, None) =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
-
         StreamingDeduplicateExec(
           keys,
           child,
-          Some(stateId),
+          Some(nextStatefulOperationStateInfo),
           Some(offsetSeqMetadata.batchWatermarkMs))
 
       case m: FlatMapGroupsWithStateExec =>
-        val stateId =
-          OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId)
         m.copy(
-          stateId = Some(stateId),
+          stateInfo = Some(nextStatefulOperationStateInfo),
           batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs),
           eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs))
     }
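
The `state` rule above hands every stateful operator in the plan the same (checkpointLocation, runId, batchId) triple but a fresh operator id from an `AtomicInteger`. A minimal, self-contained sketch of that assignment pattern follows; the names are illustrative stand-ins, not the Spark internals:

```scala
// Sketch of per-operator state-info assignment, using illustrative names.
import java.util.UUID
import java.util.concurrent.atomic.AtomicInteger

case class OperatorStateInfoSketch(
    checkpointLocation: String,
    queryRunId: UUID,
    operatorId: Long,
    storeVersion: Long)

class StateInfoAssigner(checkpointLocation: String, runId: UUID, batchId: Long) {
  private val nextOperatorId = new AtomicInteger(0)

  // Every call shares (location, runId, batchId) but gets a distinct operator id.
  def next(): OperatorStateInfoSketch =
    OperatorStateInfoSketch(checkpointLocation, runId, nextOperatorId.getAndIncrement(), batchId)
}

object StateInfoAssignerDemo extends App {
  val assigner = new StateInfoAssigner("/tmp/ckpt/state", UUID.randomUUID(), batchId = 5)
  println(assigner.next()) // operatorId = 0
  println(assigner.next()) // operatorId = 1
}
```
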
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 74f0f509bbf856d673cd2af8b760e70c2a8d2aaa..06bdec8b064079ce19b20f25d4a41ccd8f4ab65e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -652,6 +652,7 @@ class StreamExecution(
         triggerLogicalPlan,
         outputMode,
         checkpointFile("state"),
+        runId,
         currentBatchId,
         offsetSeqMetadata)
       lastExecution.executedPlan // Force the lazy generation of execution plan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 67d86daf10812e668d0163b65d3af75110305962..bae7a15165e4349023ba7574323611425f527bcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -92,7 +92,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     @volatile private var state: STATE = UPDATING
     @volatile private var finalDeltaFile: Path = null
 
-    override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id
+    override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId
 
     override def get(key: UnsafeRow): UnsafeRow = {
       mapToUpdate.get(key)
@@ -177,7 +177,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     /**
      * Whether all updates have been committed
      */
-    override private[streaming] def hasCommitted: Boolean = {
+    override def hasCommitted: Boolean = {
       state == COMMITTED
     }
 
@@ -205,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       indexOrdinal: Option[Int], // for sorting the data
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
-    this.stateStoreId = stateStoreId
+    this.stateStoreId_ = stateStoreId
     this.keySchema = keySchema
     this.valueSchema = valueSchema
     this.storeConf = storeConf
@@ -213,7 +213,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     fs.mkdirs(baseDir)
   }
 
-  override def id: StateStoreId = stateStoreId
+  override def stateStoreId: StateStoreId = stateStoreId_
 
   /** Do maintenance backing data files, including creating snapshots and cleaning up old files */
   override def doMaintenance(): Unit = {
@@ -231,20 +231,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   }
 
   override def toString(): String = {
-    s"HDFSStateStoreProvider[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]"
+    s"HDFSStateStoreProvider[" +
+      s"id = (op=${stateStoreId.operatorId},part=${stateStoreId.partitionId}),dir = $baseDir]"
   }
 
   /* Internal fields and methods */
 
-  @volatile private var stateStoreId: StateStoreId = _
+  @volatile private var stateStoreId_ : StateStoreId = _
   @volatile private var keySchema: StructType = _
   @volatile private var valueSchema: StructType = _
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
 
   private lazy val loadedMaps = new mutable.HashMap[Long, MapType]
-  private lazy val baseDir =
-    new Path(id.checkpointLocation, s"${id.operatorId}/${id.partitionId.toString}")
+  private lazy val baseDir = stateStoreId.storeCheckpointLocation()
   private lazy val fs = baseDir.getFileSystem(hadoopConf)
   private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 29c456f86e1edcfd65236092bd4b3b592910d26f..a94ff8a7ebd1e67e3f60445655c64f0d0e5c708a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
@@ -24,14 +25,14 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.{ThreadUtils, Utils}
 
-
 /**
  * Base trait for a versioned key-value store. Each instance of a `StateStore` represents a specific
  * version of state data, and such instances are created through a [[StateStoreProvider]].
@@ -99,7 +100,7 @@ trait StateStore {
   /**
    * Whether all updates have been committed
    */
-  private[streaming] def hasCommitted: Boolean
+  def hasCommitted: Boolean
 }
 
 
@@ -147,7 +148,7 @@ trait StateStoreProvider {
    * Return the id of the StateStores this provider will generate.
    * Should be the same as the one passed in init().
    */
-  def id: StateStoreId
+  def stateStoreId: StateStoreId
 
   /** Called when the provider instance is unloaded from the executor */
   def close(): Unit
@@ -179,13 +180,46 @@ object StateStoreProvider {
   }
 }
 
+/**
+ * Unique identifier for a provider, used to identify when providers can be reused.
+ * Note that `queryRunId` is used to uniquely identify a provider, so that the same provider
+ * instance is not reused across query restarts.
+ */
+case class StateStoreProviderId(storeId: StateStoreId, queryRunId: UUID)
 
-/** Unique identifier for a bunch of keyed state data. */
+/**
+ * Unique identifier for a bunch of keyed state data.
+ * @param checkpointRootLocation Root directory where all the state data of a query is stored
+ * @param operatorId Unique id of a stateful operator
+ * @param partitionId Index of the partition of an operator's state data
+ * @param storeName Optional, name of the store. Each partition can optionally use multiple state
+ *                  stores, but they have to be identified by distinct names.
+ */
 case class StateStoreId(
-    checkpointLocation: String,
+    checkpointRootLocation: String,
     operatorId: Long,
     partitionId: Int,
-    name: String = "")
+    storeName: String = StateStoreId.DEFAULT_STORE_NAME) {
+
+  /**
+   * Checkpoint directory to be used by a single state store, identified uniquely by the tuple
+   * (operatorId, partitionId, storeName). All implementations of [[StateStoreProvider]] should
+   * use this path for saving state data, as this ensures that distinct stores will write to
+   * different locations.
+   */
+  def storeCheckpointLocation(): Path = {
+    if (storeName == StateStoreId.DEFAULT_STORE_NAME) {
+      // For reading state store data that was generated before store names were used (Spark <= 2.2)
+      new Path(checkpointRootLocation, s"$operatorId/$partitionId")
+    } else {
+      new Path(checkpointRootLocation, s"$operatorId/$partitionId/$storeName")
+    }
+  }
+}
+
+object StateStoreId {
+  val DEFAULT_STORE_NAME = "default"
+}
 
 /** Mutable, and reusable class for representing a pair of UnsafeRows. */
 class UnsafeRowPair(var key: UnsafeRow = null, var value: UnsafeRow = null) {
@@ -211,7 +245,7 @@ object StateStore extends Logging {
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
 
   @GuardedBy("loadedProviders")
-  private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
+  private val loadedProviders = new mutable.HashMap[StateStoreProviderId, StateStoreProvider]()
 
   /**
    * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
@@ -253,7 +287,7 @@ object StateStore extends Logging {
 
   /** Get or create a store associated with the id. */
   def get(
-      storeId: StateStoreId,
+      storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
       indexOrdinal: Option[Int],
@@ -264,24 +298,24 @@ object StateStore extends Logging {
     val storeProvider = loadedProviders.synchronized {
       startMaintenanceIfNeeded()
       val provider = loadedProviders.getOrElseUpdate(
-        storeId,
+        storeProviderId,
         StateStoreProvider.instantiate(
-          storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+          storeProviderId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
       )
-      reportActiveStoreInstance(storeId)
+      reportActiveStoreInstance(storeProviderId)
       provider
     }
     storeProvider.getStore(version)
   }
 
   /** Unload a state store provider */
-  def unload(storeId: StateStoreId): Unit = loadedProviders.synchronized {
-    loadedProviders.remove(storeId).foreach(_.close())
+  def unload(storeProviderId: StateStoreProviderId): Unit = loadedProviders.synchronized {
+    loadedProviders.remove(storeProviderId).foreach(_.close())
   }
 
   /** Whether a state store provider is loaded or not */
-  def isLoaded(storeId: StateStoreId): Boolean = loadedProviders.synchronized {
-    loadedProviders.contains(storeId)
+  def isLoaded(storeProviderId: StateStoreProviderId): Boolean = loadedProviders.synchronized {
+    loadedProviders.contains(storeProviderId)
   }
 
   def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
@@ -340,21 +374,21 @@ object StateStore extends Logging {
     }
   }
 
-  private def reportActiveStoreInstance(storeId: StateStoreId): Unit = {
+  private def reportActiveStoreInstance(storeProviderId: StateStoreProviderId): Unit = {
     if (SparkEnv.get != null) {
       val host = SparkEnv.get.blockManager.blockManagerId.host
       val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
-      coordinatorRef.foreach(_.reportActiveInstance(storeId, host, executorId))
-      logDebug(s"Reported that the loaded instance $storeId is active")
+      coordinatorRef.foreach(_.reportActiveInstance(storeProviderId, host, executorId))
+      logInfo(s"Reported that the loaded instance $storeProviderId is active")
     }
   }
 
-  private def verifyIfStoreInstanceActive(storeId: StateStoreId): Boolean = {
+  private def verifyIfStoreInstanceActive(storeProviderId: StateStoreProviderId): Boolean = {
     if (SparkEnv.get != null) {
       val executorId = SparkEnv.get.blockManager.blockManagerId.executorId
       val verified =
-        coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false)
-      logDebug(s"Verified whether the loaded instance $storeId is active: $verified")
+        coordinatorRef.map(_.verifyIfInstanceActive(storeProviderId, executorId)).getOrElse(false)
+      logDebug(s"Verified whether the loaded instance $storeProviderId is active: $verified")
       verified
     } else {
       false
@@ -364,12 +398,21 @@ object StateStore extends Logging {
   private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
     val env = SparkEnv.get
     if (env != null) {
-      if (_coordRef == null) {
+      logInfo("Env is not null")
+      val isDriver =
+        env.executorId == SparkContext.DRIVER_IDENTIFIER ||
+          env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
+      // If running locally, then the coordinator reference in _coordRef may have become inactive
+      // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver,
+      // always recreate the reference.
+      if (isDriver || _coordRef == null) {
+        logInfo("Getting StateStoreCoordinatorRef")
         _coordRef = StateStoreCoordinatorRef.forExecutor(env)
       }
-      logDebug(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
+      logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}")
       Some(_coordRef)
     } else {
+      logInfo("Env is null")
       _coordRef = null
       None
     }
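
The new `StateStoreId.storeCheckpointLocation()` keeps the pre-store-name `<root>/<operatorId>/<partitionId>` layout for the default store and appends the store name otherwise, while `StateStoreProviderId` adds the query run id so a provider is never shared across restarts. A small sketch of that resolution logic, using plain strings instead of Hadoop `Path` and illustrative class names:

```scala
// Illustrative sketch of the checkpoint-path layout; not the actual Spark classes.
import java.util.UUID

object StorePathSketch {
  val DefaultStoreName = "default"

  case class StoreIdSketch(
      root: String, operatorId: Long, partitionId: Int, storeName: String = DefaultStoreName)
  case class ProviderIdSketch(storeId: StoreIdSketch, queryRunId: UUID)

  def storeCheckpointLocation(id: StoreIdSketch): String =
    if (id.storeName == DefaultStoreName) {
      // Layout used before store names existed, kept for backward compatibility.
      s"${id.root}/${id.operatorId}/${id.partitionId}"
    } else {
      s"${id.root}/${id.operatorId}/${id.partitionId}/${id.storeName}"
    }

  def main(args: Array[String]): Unit = {
    val store = StoreIdSketch("/ckpt/state", operatorId = 0, partitionId = 3)
    println(storeCheckpointLocation(store))                          // /ckpt/state/0/3
    println(storeCheckpointLocation(store.copy(storeName = "left"))) // /ckpt/state/0/3/left
    // Same store id, different runs => distinct provider ids, so no provider reuse.
    val runA = ProviderIdSketch(store, UUID.randomUUID())
    val runB = ProviderIdSketch(store, UUID.randomUUID())
    println(runA == runB) // false
  }
}
```
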
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index d0f81887e62d16eca89f0a809f9f05ab35f437a0..3884f5e6ce766af84d4be5fdb426daf1c2db5209 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
+
 import scala.collection.mutable
 
 import org.apache.spark.SparkEnv
@@ -29,16 +31,19 @@ import org.apache.spark.util.RpcUtils
 private sealed trait StateStoreCoordinatorMessage extends Serializable
 
 /** Classes representing messages */
-private case class ReportActiveInstance(storeId: StateStoreId, host: String, executorId: String)
+private case class ReportActiveInstance(
+    storeId: StateStoreProviderId,
+    host: String,
+    executorId: String)
   extends StateStoreCoordinatorMessage
 
-private case class VerifyIfInstanceActive(storeId: StateStoreId, executorId: String)
+private case class VerifyIfInstanceActive(storeId: StateStoreProviderId, executorId: String)
   extends StateStoreCoordinatorMessage
 
-private case class GetLocation(storeId: StateStoreId)
+private case class GetLocation(storeId: StateStoreProviderId)
   extends StateStoreCoordinatorMessage
 
-private case class DeactivateInstances(checkpointLocation: String)
+private case class DeactivateInstances(runId: UUID)
   extends StateStoreCoordinatorMessage
 
 private object StopCoordinator
@@ -80,25 +85,27 @@ object StateStoreCoordinatorRef extends Logging {
 class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
 
   private[state] def reportActiveInstance(
-      storeId: StateStoreId,
+      stateStoreProviderId: StateStoreProviderId,
       host: String,
       executorId: String): Unit = {
-    rpcEndpointRef.send(ReportActiveInstance(storeId, host, executorId))
+    rpcEndpointRef.send(ReportActiveInstance(stateStoreProviderId, host, executorId))
   }
 
   /** Verify whether the given executor has the active instance of a state store */
-  private[state] def verifyIfInstanceActive(storeId: StateStoreId, executorId: String): Boolean = {
-    rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(storeId, executorId))
+  private[state] def verifyIfInstanceActive(
+      stateStoreProviderId: StateStoreProviderId,
+      executorId: String): Boolean = {
+    rpcEndpointRef.askSync[Boolean](VerifyIfInstanceActive(stateStoreProviderId, executorId))
   }
 
   /** Get the location of the state store */
-  private[state] def getLocation(storeId: StateStoreId): Option[String] = {
-    rpcEndpointRef.askSync[Option[String]](GetLocation(storeId))
+  private[state] def getLocation(stateStoreProviderId: StateStoreProviderId): Option[String] = {
+    rpcEndpointRef.askSync[Option[String]](GetLocation(stateStoreProviderId))
   }
 
-  /** Deactivate instances related to a set of operator */
-  private[state] def deactivateInstances(storeRootLocation: String): Unit = {
-    rpcEndpointRef.askSync[Boolean](DeactivateInstances(storeRootLocation))
+  /** Deactivate instances related to a query */
+  private[sql] def deactivateInstances(runId: UUID): Unit = {
+    rpcEndpointRef.askSync[Boolean](DeactivateInstances(runId))
   }
 
   private[state] def stop(): Unit = {
@@ -113,7 +120,7 @@ class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) {
  */
 private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
     extends ThreadSafeRpcEndpoint with Logging {
-  private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]
+  private val instances = new mutable.HashMap[StateStoreProviderId, ExecutorCacheTaskLocation]
 
   override def receive: PartialFunction[Any, Unit] = {
     case ReportActiveInstance(id, host, executorId) =>
@@ -135,11 +142,11 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv)
       logDebug(s"Got location of the state store $id: $executorId")
       context.reply(executorId)
 
-    case DeactivateInstances(checkpointLocation) =>
+    case DeactivateInstances(runId) =>
       val storeIdsToRemove =
-        instances.keys.filter(_.checkpointLocation == checkpointLocation).toSeq
+        instances.keys.filter(_.queryRunId == runId).toSeq
       instances --= storeIdsToRemove
-      logDebug(s"Deactivating instances related to checkpoint location $checkpointLocation: " +
+      logDebug(s"Deactivating instances related to checkpoint location $runId: " +
         storeIdsToRemove.mkString(", "))
       context.reply(true)
 
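
With `DeactivateInstances` keyed by run id, the coordinator drops every provider registered under the terminated run instead of matching on a checkpoint-path string. A hedged, self-contained sketch of that bookkeeping, with the RPC endpoint reduced to a plain class and hypothetical names:

```scala
// Simplified model of run-id based deactivation; not the actual coordinator.
import java.util.UUID
import scala.collection.mutable

case class ProviderKey(checkpointRoot: String, operatorId: Long, partitionId: Int, runId: UUID)

class CoordinatorSketch {
  private val instances = new mutable.HashMap[ProviderKey, String] // provider -> "host/executor"

  def reportActiveInstance(key: ProviderKey, location: String): Unit = instances.put(key, location)

  def getLocation(key: ProviderKey): Option[String] = instances.get(key)

  // Remove every provider belonging to the terminated run; other runs are untouched.
  def deactivateInstances(runId: UUID): Unit = {
    val toRemove = instances.keys.filter(_.runId == runId).toSeq
    instances --= toRemove
  }
}

object CoordinatorSketchDemo extends App {
  val coord = new CoordinatorSketch
  val run1 = UUID.randomUUID()
  val run2 = UUID.randomUUID()
  coord.reportActiveInstance(ProviderKey("/ckpt", 0, 0, run1), "hostA/exec1")
  coord.reportActiveInstance(ProviderKey("/ckpt", 0, 0, run2), "hostB/exec2")
  coord.deactivateInstances(run1)
  println(coord.getLocation(ProviderKey("/ckpt", 0, 0, run1))) // None
  println(coord.getLocation(ProviderKey("/ckpt", 0, 0, run2))) // Some(hostB/exec2)
}
```
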
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index b744c25dc97a81e4ac741447cffb41c5ebd5e54a..01d8e759809937b8e372b16364f684fb3f71034c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, TaskContext}
@@ -34,8 +36,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     dataRDD: RDD[T],
     storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U],
     checkpointLocation: String,
+    queryRunId: UUID,
     operatorId: Long,
-    storeName: String,
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
@@ -52,16 +54,25 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
 
   override protected def getPartitions: Array[Partition] = dataRDD.partitions
 
+  /**
+   * Set the preferred location of each partition using the executor that has the related
+   * [[StateStoreProvider]] already loaded.
+   */
   override def getPreferredLocations(partition: Partition): Seq[String] = {
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName)
-    storeCoordinator.flatMap(_.getLocation(storeId)).toSeq
+    val stateStoreProviderId = StateStoreProviderId(
+      StateStoreId(checkpointLocation, operatorId, partition.index),
+      queryRunId)
+    storeCoordinator.flatMap(_.getLocation(stateStoreProviderId)).toSeq
   }
 
   override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = {
     var store: StateStore = null
-    val storeId = StateStoreId(checkpointLocation, operatorId, partition.index, storeName)
+    val storeProviderId = StateStoreProviderId(
+      StateStoreId(checkpointLocation, operatorId, partition.index),
+      queryRunId)
+
     store = StateStore.get(
-      storeId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 228fe86d59940416509f41e342a08fb36aeb0220..a0086e251f9c682231dc71f25d8347dc6ebe350e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
+
 import scala.reflect.ClassTag
 
 import org.apache.spark.TaskContext
@@ -32,20 +34,14 @@ package object state {
     /** Map each partition of an RDD along with data in a [[StateStore]]. */
     def mapPartitionsWithStateStore[U: ClassTag](
         sqlContext: SQLContext,
-        checkpointLocation: String,
-        operatorId: Long,
-        storeName: String,
-        storeVersion: Long,
+        stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
         indexOrdinal: Option[Int])(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
-        checkpointLocation,
-        operatorId,
-        storeName,
-        storeVersion,
+        stateInfo,
         keySchema,
         valueSchema,
         indexOrdinal,
@@ -56,10 +52,7 @@ package object state {
 
     /** Map each partition of an RDD along with data in a [[StateStore]]. */
     private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
-        checkpointLocation: String,
-        operatorId: Long,
-        storeName: String,
-        storeVersion: Long,
+        stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
         indexOrdinal: Option[Int],
@@ -79,10 +72,10 @@ package object state {
       new StateStoreRDD(
         dataRDD,
         wrappedF,
-        checkpointLocation,
-        operatorId,
-        storeName,
-        storeVersion,
+        stateInfo.checkpointLocation,
+        stateInfo.queryRunId,
+        stateInfo.operatorId,
+        stateInfo.storeVersion,
         keySchema,
         valueSchema,
         indexOrdinal,
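
Callers of `mapPartitionsWithStateStore` now pass a single `StatefulOperatorStateInfo` and the helper fans its fields back out into the `StateStoreRDD` constructor. A brief sketch of that parameter-object pattern, with the RDD reduced to a plain case class so it runs without Spark:

```scala
// Sketch of bundling loose parameters into one state-info object (illustrative names).
import java.util.UUID

case class StateInfoSketch(
    checkpointLocation: String, queryRunId: UUID, operatorId: Long, storeVersion: Long)

// Stand-in for the StateStoreRDD constructor, which still takes individual fields.
case class StoreRddArgs(
    checkpointLocation: String, queryRunId: UUID, operatorId: Long, storeVersion: Long)

object MapWithStateSketch extends App {
  // Before: helper(path, operatorId, storeName, storeVersion, ...)
  // After:  helper(stateInfo, ...) and the fields are unpacked in one place.
  def toRddArgs(stateInfo: StateInfoSketch): StoreRddArgs =
    StoreRddArgs(
      stateInfo.checkpointLocation,
      stateInfo.queryRunId,
      stateInfo.operatorId,
      stateInfo.storeVersion)

  val info = StateInfoSketch("/ckpt/state", UUID.randomUUID(), operatorId = 0, storeVersion = 2)
  println(toRddArgs(info))
}
```
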
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3e57f3fbada32ecd1cd93efe323a043aef38573b..c5722466a33af3f8b019181be9d0f8ad345008ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit._
 
 import org.apache.spark.rdd.RDD
@@ -36,20 +37,22 @@ import org.apache.spark.util.{CompletionIterator, NextIterator}
 
 
 /** Used to identify the state store for a given operator. */
-case class OperatorStateId(
+case class StatefulOperatorStateInfo(
     checkpointLocation: String,
+    queryRunId: UUID,
     operatorId: Long,
-    batchId: Long)
+    storeVersion: Long)
 
 /**
- * An operator that reads or writes state from the [[StateStore]].  The [[OperatorStateId]] should
- * be filled in by `prepareForExecution` in [[IncrementalExecution]].
+ * An operator that reads or writes state from the [[StateStore]].
+ * The [[StatefulOperatorStateInfo]] should be filled in by `prepareForExecution` in
+ * [[IncrementalExecution]].
  */
 trait StatefulOperator extends SparkPlan {
-  def stateId: Option[OperatorStateId]
+  def stateInfo: Option[StatefulOperatorStateInfo]
 
-  protected def getStateId: OperatorStateId = attachTree(this) {
-    stateId.getOrElse {
+  protected def getStateInfo: StatefulOperatorStateInfo = attachTree(this) {
+    stateInfo.getOrElse {
       throw new IllegalStateException("State location not present for execution")
     }
   }
@@ -140,7 +143,7 @@ trait WatermarkSupport extends UnaryExecNode {
  */
 case class StateStoreRestoreExec(
     keyExpressions: Seq[Attribute],
-    stateId: Option[OperatorStateId],
+    stateInfo: Option[StatefulOperatorStateInfo],
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader {
 
@@ -148,10 +151,7 @@ case class StateStoreRestoreExec(
     val numOutputRows = longMetric("numOutputRows")
 
     child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      operatorId = getStateId.operatorId,
-      storeName = "default",
-      storeVersion = getStateId.batchId,
+      getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,
@@ -177,7 +177,7 @@ case class StateStoreRestoreExec(
  */
 case class StateStoreSaveExec(
     keyExpressions: Seq[Attribute],
-    stateId: Option[OperatorStateId] = None,
+    stateInfo: Option[StatefulOperatorStateInfo] = None,
     outputMode: Option[OutputMode] = None,
     eventTimeWatermark: Option[Long] = None,
     child: SparkPlan)
@@ -189,10 +189,7 @@ case class StateStoreSaveExec(
       "Incorrect planning in IncrementalExecution, outputMode has not been set")
 
     child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      getStateId.operatorId,
-      storeName = "default",
-      getStateId.batchId,
+      getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,
@@ -319,7 +316,7 @@ case class StateStoreSaveExec(
 case class StreamingDeduplicateExec(
     keyExpressions: Seq[Attribute],
     child: SparkPlan,
-    stateId: Option[OperatorStateId] = None,
+    stateInfo: Option[StatefulOperatorStateInfo] = None,
     eventTimeWatermark: Option[Long] = None)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {
 
@@ -331,10 +328,7 @@ case class StreamingDeduplicateExec(
     metrics // force lazy init at driver
 
     child.execute().mapPartitionsWithStateStore(
-      getStateId.checkpointLocation,
-      getStateId.operatorId,
-      storeName = "default",
-      getStateId.batchId,
+      getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
       indexOrdinal = None,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 002c45413b4c2949b38235fcb50257d9f94c389f..48b0ea20e5da136d4114542ba4dce21e0ff656fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -332,5 +332,6 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
       }
       awaitTerminationLock.notifyAll()
     }
+    stateStoreCoordinator.deactivateInstances(terminatedQuery.runId)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
index a7e32626264cc954c5229ec176d354006fbdb785..9a7595eee7bd02172c21850b8e729fabf4e8627f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala
@@ -17,11 +17,17 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.util.UUID
+
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
+import org.apache.spark.sql.functions.count
+import org.apache.spark.util.Utils
 
 class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
 
@@ -29,7 +35,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
 
   test("report, verify, getLocation") {
     withCoordinatorRef(sc) { coordinatorRef =>
-      val id = StateStoreId("x", 0, 0)
+      val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID)
 
       assert(coordinatorRef.verifyIfInstanceActive(id, "exec1") === false)
       assert(coordinatorRef.getLocation(id) === None)
@@ -57,9 +63,11 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
 
   test("make inactive") {
     withCoordinatorRef(sc) { coordinatorRef =>
-      val id1 = StateStoreId("x", 0, 0)
-      val id2 = StateStoreId("y", 1, 0)
-      val id3 = StateStoreId("x", 0, 1)
+      val runId1 = UUID.randomUUID
+      val runId2 = UUID.randomUUID
+      val id1 = StateStoreProviderId(StateStoreId("x", 0, 0), runId1)
+      val id2 = StateStoreProviderId(StateStoreId("y", 1, 0), runId2)
+      val id3 = StateStoreProviderId(StateStoreId("x", 0, 1), runId1)
       val host = "hostX"
       val exec = "exec1"
 
@@ -73,7 +81,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
         assert(coordinatorRef.verifyIfInstanceActive(id3, exec) === true)
       }
 
-      coordinatorRef.deactivateInstances("x")
+      coordinatorRef.deactivateInstances(runId1)
 
       assert(coordinatorRef.verifyIfInstanceActive(id1, exec) === false)
       assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === true)
@@ -85,7 +93,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
           Some(ExecutorCacheTaskLocation(host, exec).toString))
       assert(coordinatorRef.getLocation(id3) === None)
 
-      coordinatorRef.deactivateInstances("y")
+      coordinatorRef.deactivateInstances(runId2)
       assert(coordinatorRef.verifyIfInstanceActive(id2, exec) === false)
       assert(coordinatorRef.getLocation(id2) === None)
     }
@@ -95,7 +103,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
     withCoordinatorRef(sc) { coordRef1 =>
       val coordRef2 = StateStoreCoordinatorRef.forDriver(sc.env)
 
-      val id = StateStoreId("x", 0, 0)
+      val id = StateStoreProviderId(StateStoreId("x", 0, 0), UUID.randomUUID)
 
       coordRef1.reportActiveInstance(id, "hostX", "exec1")
 
@@ -107,6 +115,45 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext {
       }
     }
   }
+
+  test("query stop deactivates related store providers") {
+    var coordRef: StateStoreCoordinatorRef = null
+    try {
+      val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
+      SparkSession.setActiveSession(spark)
+      import spark.implicits._
+      coordRef = spark.streams.stateStoreCoordinator
+      implicit val sqlContext = spark.sqlContext
+      spark.conf.set("spark.sql.shuffle.partitions", "1")
+
+      // Start a query and run a batch to load state stores
+      val inputData = MemoryStream[Int]
+      val aggregated = inputData.toDF().groupBy("value").agg(count("*")) // stateful query
+      val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+      val query = aggregated.writeStream
+        .format("memory")
+        .outputMode("update")
+        .queryName("query")
+        .option("checkpointLocation", checkpointLocation.toString)
+        .start()
+      inputData.addData(1, 2, 3)
+      query.processAllAvailable()
+
+      // Verify state store has been loaded
+      val stateCheckpointDir =
+        query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution.checkpointLocation
+      val providerId = StateStoreProviderId(StateStoreId(stateCheckpointDir, 0, 0), query.runId)
+      assert(coordRef.getLocation(providerId).nonEmpty)
+
+      // Stop and verify whether the stores are deactivated in the coordinator
+      query.stop()
+      assert(coordRef.getLocation(providerId).isEmpty)
+    } finally {
+      SparkSession.getActiveSession.foreach(_.streams.active.foreach(_.stop()))
+      if (coordRef != null) coordRef.stop()
+      StateStore.stop()
+    }
+  }
 }
 
 object StateStoreCoordinatorSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 4a1a089af54c2178c233fb9f0f413bfd59c05be2..defb9ed63a881cb145582225de3c78b7417bb4b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -19,20 +19,19 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
 import java.nio.file.Files
+import java.util.UUID
 
 import scala.util.Random
 
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.scalatest.concurrent.Eventually._
-import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-import org.apache.spark.sql.LocalSparkSession._
-import org.apache.spark.LocalSparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.LocalSparkSession._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{CompletionIterator, Utils}
 
@@ -57,16 +56,14 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   test("versioning and immutability") {
     withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
-      val opId = 0
-      val rdd1 =
-        makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-            spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+      val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
+            spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(
             increment)
       assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(
+        spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(
         increment)
       assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
@@ -76,7 +73,6 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   }
 
   test("recovering from files") {
-    val opId = 0
     val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
 
     def makeStoreRDD(
@@ -85,7 +81,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         storeVersion: Int): RDD[(String, Int)] = {
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, "name", storeVersion, keySchema, valueSchema, None)(increment)
+        sqlContext, operatorStateInfo(path, version = storeVersion),
+        keySchema, valueSchema, None)(increment)
     }
 
     // Generate RDDs and state store data
@@ -132,17 +129,17 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       }
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+        spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(
         iteratorOfGets)
       assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
+        sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(
         iteratorOfPuts)
       assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(
+        sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(
         iteratorOfGets)
       assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
     }
@@ -150,22 +147,25 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
   test("preferred locations using StateStoreCoordinator") {
     quietly {
+      val queryRunId = UUID.randomUUID
       val opId = 0
       val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
 
       withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
         implicit val sqlContext = spark.sqlContext
         val coordinatorRef = sqlContext.streams.stateStoreCoordinator
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0, "name"), "host1", "exec1")
-        coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1, "name"), "host2", "exec2")
+        val storeProviderId1 = StateStoreProviderId(StateStoreId(path, opId, 0), queryRunId)
+        val storeProviderId2 = StateStoreProviderId(StateStoreId(path, opId, 1), queryRunId)
+        coordinatorRef.reportActiveInstance(storeProviderId1, "host1", "exec1")
+        coordinatorRef.reportActiveInstance(storeProviderId2, "host2", "exec2")
 
-        assert(
-          coordinatorRef.getLocation(StateStoreId(path, opId, 0, "name")) ===
+        require(
+          coordinatorRef.getLocation(storeProviderId1) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
         val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(
-          increment)
+          sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
+          keySchema, valueSchema, None)(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -192,12 +192,12 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, "name", storeVersion = 0, keySchema, valueSchema, None)(increment)
+          sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment)
         assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-          sqlContext, path, opId, "name", storeVersion = 1, keySchema, valueSchema, None)(increment)
+          sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment)
         assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
 
         // Make sure the previous RDD still has the same data.
@@ -210,6 +210,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
     sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
   }
 
+  private def operatorStateInfo(
+      path: String,
+      queryRunId: UUID = UUID.randomUUID,
+      version: Int = 0): StatefulOperatorStateInfo = {
+    StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version)
+  }
+
   private val increment = (store: StateStore, iter: Iterator[String]) => {
     iter.foreach { s =>
       val key = stringToRow(s)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index af2b9f1c11fb6a6bd09f70cf0e6fe877fe2c3057..c2087ec219e57d4fd199127611768f210d072e35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.{File, IOException}
 import java.net.URI
+import java.util.UUID
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -33,8 +34,11 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
 import org.apache.spark.LocalSparkContext._
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.quietly
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
+import org.apache.spark.sql.functions.count
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -143,7 +147,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     provider.getStore(0).commit()
 
     // Verify we don't leak temp files
-    val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation),
+    val tempFiles = FileUtils.listFiles(new File(provider.stateStoreId.checkpointRootLocation),
       null, true).asScala.filter(_.getName.startsWith("temp-"))
     assert(tempFiles.isEmpty)
   }
@@ -183,7 +187,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   test("StateStore.get") {
     quietly {
       val dir = newDir()
-      val storeId = StateStoreId(dir, 0, 0)
+      val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID)
       val storeConf = StateStoreConf.empty
       val hadoopConf = new Configuration()
 
@@ -243,18 +247,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       .set("spark.rpc.numRetries", "1")
     val opId = 0
     val dir = newDir()
-    val storeId = StateStoreId(dir, opId, 0)
+    val storeProviderId = StateStoreProviderId(StateStoreId(dir, opId, 0), UUID.randomUUID)
     val sqlConf = new SQLConf()
     sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
     val storeConf = StateStoreConf(sqlConf)
     val hadoopConf = new Configuration()
-    val provider = newStoreProvider(storeId)
+    val provider = newStoreProvider(storeProviderId.storeId)
 
     var latestStoreVersion = 0
 
     def generateStoreVersions() {
       for (i <- 1 to 20) {
-        val store = StateStore.get(storeId, keySchema, valueSchema, None,
+        val store = StateStore.get(storeProviderId, keySchema, valueSchema, None,
           latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", i)
         store.commit()
@@ -274,7 +278,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
           eventually(timeout(timeoutDuration)) {
             // Store should have been reported to the coordinator
-            assert(coordinatorRef.getLocation(storeId).nonEmpty, "active instance was not reported")
+            assert(coordinatorRef.getLocation(storeProviderId).nonEmpty,
+              "active instance was not reported")
 
             // Background maintenance should clean up and generate snapshots
             assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
@@ -295,35 +300,35 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
             assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
           }
 
-          // If driver decides to deactivate all instances of the store, then this instance
-          // should be unloaded
-          coordinatorRef.deactivateInstances(dir)
+          // If the driver decides to deactivate all stores related to a query run,
+          // then this instance should be unloaded
+          coordinatorRef.deactivateInstances(storeProviderId.queryRunId)
           eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeId))
+            assert(!StateStore.isLoaded(storeProviderId))
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None,
             latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeId))
+          assert(StateStore.isLoaded(storeProviderId))
 
           // If some other executor loads the store, then this instance should be unloaded
-          coordinatorRef.reportActiveInstance(storeId, "other-host", "other-exec")
+          coordinatorRef.reportActiveInstance(storeProviderId, "other-host", "other-exec")
           eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeId))
+            assert(!StateStore.isLoaded(storeProviderId))
           }
 
           // Reload the store and verify
-          StateStore.get(storeId, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId, keySchema, valueSchema, indexOrdinal = None,
             latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeId))
+          assert(StateStore.isLoaded(storeProviderId))
         }
       }
 
       // Verify if instance is unloaded if SparkContext is stopped
       eventually(timeout(timeoutDuration)) {
         require(SparkEnv.get === null)
-        assert(!StateStore.isLoaded(storeId))
+        assert(!StateStore.isLoaded(storeProviderId))
         assert(!StateStore.isMaintenanceRunning)
       }
     }
@@ -344,7 +349,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
   test("SPARK-18416: do not create temp delta file until the store is updated") {
     val dir = newDir()
-    val storeId = StateStoreId(dir, 0, 0)
+    val storeId = StateStoreProviderId(StateStoreId(dir, 0, 0), UUID.randomUUID)
     val storeConf = StateStoreConf.empty
     val hadoopConf = new Configuration()
     val deltaFileDir = new File(s"$dir/0/0/")
@@ -408,12 +413,60 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(numDeltaFiles === 3)
   }
 
+  test("SPARK-21145: Restarted queries create new provider instances") {
+    try {
+      val checkpointLocation = Utils.createTempDir().getAbsoluteFile
+      val spark = SparkSession.builder().master("local[2]").getOrCreate()
+      SparkSession.setActiveSession(spark)
+      implicit val sqlContext = spark.sqlContext
+      spark.conf.set("spark.sql.shuffle.partitions", "1")
+      import spark.implicits._
+      val inputData = MemoryStream[Int]
+
+      def runQueryAndGetLoadedProviders(): Seq[StateStoreProvider] = {
+        val aggregated = inputData.toDF().groupBy("value").agg(count("*"))
+        // stateful query
+        val query = aggregated.writeStream
+          .format("memory")
+          .outputMode("complete")
+          .queryName("query")
+          .option("checkpointLocation", checkpointLocation.toString)
+          .start()
+        inputData.addData(1, 2, 3)
+        query.processAllAvailable()
+        require(query.lastProgress != null) // at least one batch processed after start
+        val loadedProvidersMethod =
+          PrivateMethod[mutable.HashMap[StateStoreProviderId, StateStoreProvider]]('loadedProviders)
+        val loadedProvidersMap = StateStore invokePrivate loadedProvidersMethod()
+        val loadedProviders = loadedProvidersMap.synchronized { loadedProvidersMap.values.toSeq }
+        query.stop()
+        loadedProviders
+      }
+
+      val loadedProvidersAfterRun1 = runQueryAndGetLoadedProviders()
+      require(loadedProvidersAfterRun1.length === 1)
+
+      val loadedProvidersAfterRun2 = runQueryAndGetLoadedProviders()
+      assert(loadedProvidersAfterRun2.length === 2)   // two providers loaded for 2 runs
+
+      // Both providers should have the same StateStoreId, but they should be different objects
+      assert(loadedProvidersAfterRun2(0).stateStoreId === loadedProvidersAfterRun2(1).stateStoreId)
+      assert(loadedProvidersAfterRun2(0) ne loadedProvidersAfterRun2(1))
+
+    } finally {
+      SparkSession.getActiveSession.foreach { spark =>
+        spark.streams.active.foreach(_.stop())
+        spark.stop()
+      }
+    }
+  }
+
   override def newStoreProvider(): HDFSBackedStateStoreProvider = {
     newStoreProvider(opId = Random.nextInt(), partition = 0)
   }
 
   override def newStoreProvider(storeId: StateStoreId): HDFSBackedStateStoreProvider = {
-    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointLocation)
+    newStoreProvider(storeId.operatorId, storeId.partitionId, dir = storeId.checkpointRootLocation)
   }
 
   override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
@@ -423,7 +476,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   override def getData(
     provider: HDFSBackedStateStoreProvider,
     version: Int = -1): Set[(String, Int)] = {
-    val reloadedProvider = newStoreProvider(provider.id)
+    val reloadedProvider = newStoreProvider(provider.stateStoreId)
     if (version < 0) {
       reloadedProvider.latestIterator().map(rowsToStringInt).toSet
     } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 4ede4fd9a035e11b5b481697aaa6c3d2976660e7..86c3a35a59c13f54683a59d958f63716748c8db4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -777,7 +777,7 @@ class TestStateStoreProvider extends StateStoreProvider {
     throw new Exception("Successfully instantiated")
   }
 
-  override def id: StateStoreId = null
+  override def stateStoreId: StateStoreId = null
 
   override def close(): Unit = { }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 2a4039cc5831ab3fd4868952e4e5260ab35a1ec2..b2c42eef88f6d83de5d11cfdc15cc2ffd34658f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -26,9 +26,8 @@ import scala.reflect.ClassTag
 import scala.util.Random
 import scala.util.control.NonFatal
 
-import org.scalatest.Assertions
+import org.scalatest.{Assertions, BeforeAndAfterAll}
 import org.scalatest.concurrent.{Eventually, Timeouts}
-import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.scalatest.time.Span
@@ -39,9 +38,10 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, Ro
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
+import org.apache.spark.util.{Clock, SystemClock, Utils}
 
 /**
  * A framework for implementing tests for streaming queries and sources.
@@ -67,7 +67,12 @@ import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
  * avoid hanging forever in the case of failures. However, individual suites can change this
  * by overriding `streamingTimeout`.
  */
-trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
+trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll {
+
+  override def afterAll(): Unit = {
+    super.afterAll()
+    StateStore.stop() // stop the state store maintenance thread and unload store providers
+  }
 
   /** How long to wait for an active stream to catch up when checking a result. */
   val streamingTimeout = 10.seconds