diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
index 5aa0636850255d1d3d212a3831023e9aacb8267a..812e1b0a3957049c79eefa6c7b145604608e24a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.util.RpcUtils
 
@@ -112,7 +112,7 @@ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointR
  * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster,
  * and for getting their locations for job scheduling.
  */
-private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint {
+private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
   private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation]
 
   override def receive: PartialFunction[Any, Unit] = {
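
For context: `ThreadSafeRpcEndpoint` in Spark's RPC layer is a marker trait whose contract is that the dispatcher processes at most one message at a time for the endpoint, and that processing of one message happens-before processing of the next. That is what lets `StateStoreCoordinator` keep using a plain `mutable.HashMap` without extra synchronization, and it makes the effect of an earlier `reportActiveInstance` visible to a later `getLocation` request.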
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index df50cbde56087179933ca32bff5ad3187c33a7e9..85db05157c359ba22834f53fcab3719e8fe2ac5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -124,11 +124,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1")
         coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2")
 
-        eventually(timeout(10 seconds)) {
-          assert(
-            coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
-              Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
-        }
+        assert(
+          coordinatorRef.getLocation(StateStoreId(path, opId, 0)) ===
+            Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
         val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore(
           increment, path, opId, storeVersion = 0, keySchema, valueSchema)
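
Below is a minimal, self-contained sketch (plain Scala, no Spark dependencies; all names such as `SerializedEndpointSketch`, `report`, and `getLocation` are illustrative, not Spark APIs) of why serialized message processing makes the `eventually` retry loop unnecessary: if a "report" and a "get" are handled on one thread in arrival order, the "get" always observes the preceding "report".

```scala
import java.util.concurrent.Executors

import scala.collection.mutable
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.concurrent.duration._

object SerializedEndpointSketch {
  // Single-threaded executor: stands in for ThreadSafeRpcEndpoint's guarantee
  // that at most one message is processed at a time, in arrival order.
  private val inbox =
    ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

  // Safe without locks or @volatile: only the inbox thread ever touches it.
  private val locations = mutable.HashMap[Int, String]()

  // Fire-and-forget, like the coordinator handling a reportActiveInstance.
  def report(partition: Int, host: String): Unit =
    Future { locations(partition) = host }(inbox)

  // Blocking request-reply, like a getLocation ask.
  def getLocation(partition: Int): Option[String] = {
    val reply = Promise[Option[String]]()
    Future { reply.success(locations.get(partition)) }(inbox)
    Await.result(reply.future, 10.seconds)
  }

  def main(args: Array[String]): Unit = {
    report(0, "host1")
    report(1, "host2")
    // No eventually-style retry needed: both reports were queued on the same
    // single thread before this ask, so it is processed strictly after them.
    assert(getLocation(0).contains("host1"))
    assert(getLocation(1).contains("host2"))
    inbox.shutdown()
  }
}
```

Once the real coordinator extends `ThreadSafeRpcEndpoint`, Spark's dispatcher provides the same ordering, which is why the test can assert the location immediately instead of polling inside `eventually(timeout(10 seconds))`.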