Commit ea31f92b authored by Shixiong Zhu, committed by Tathagata Das

[SPARK-19267][SS] Fix a race condition when stopping StateStore

## What changes were proposed in this pull request?

There is a race condition when stopping StateStore that makes `StateStoreSuite.maintenance` flaky: `StateStore.stop` doesn't wait for the running maintenance task to finish, so an out-of-date task may fail in `doMaintenance` and cancel the new task. Here is a reproducer: https://github.com/zsxwing/spark/commit/dde1b5b106ba034861cf19e16883cfe181faa6f3

This PR adds MaintenanceTask to eliminate the race condition.
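
To make the interleaving concrete, here is a small self-contained sketch of the pre-fix design, not the Spark code itself: names such as `MaintenanceRaceSketch` and `storeActive` are invented for illustration, and the latches exist only to force deterministically the unlucky timing that the linked reproducer hits intermittently. The essential ingredients are one shared scheduler, one shared mutable future, and a `stop()` that never waits for a run already in flight.

```scala
import java.util.concurrent.{CountDownLatch, Executors, ScheduledFuture, TimeUnit}

// Sketch of the pre-fix design: a shared scheduler, a shared mutable future,
// and a stop() that does not wait for an in-flight run to finish.
object MaintenanceRaceSketch {
  private val scheduler = Executors.newSingleThreadScheduledExecutor()
  @volatile private var task: ScheduledFuture[_] = null
  @volatile private var storeActive = false

  // Latches only force the unlucky interleaving; they model no real behaviour.
  private val runStarted   = new CountDownLatch(1)  // first run is executing
  private val tornDown     = new CountDownLatch(1)  // stop() has completed
  private val staleChecked = new CountDownLatch(1)  // stale run observed the failure
  private val restarted    = new CountDownLatch(1)  // a new task has been scheduled

  private def doMaintenance(): Unit = {
    runStarted.countDown()
    tornDown.await()                 // this run is still executing when stop() returns
    val failed = !storeActive        // it observes the torn-down state ...
    staleChecked.countDown()
    if (failed) {
      restarted.await()              // ... but a newer start() has installed a fresh task
      stop()                         // the error path cancels that NEW task
    }
  }

  def start(): Unit = {
    storeActive = true
    task = scheduler.scheduleAtFixedRate(
      new Runnable { override def run(): Unit = doMaintenance() },
      0, 60, TimeUnit.SECONDS)
  }

  def stop(): Unit = {
    storeActive = false
    if (task != null) {
      task.cancel(false)             // cancel, but never wait for the in-flight run
      task = null
    }
  }

  def main(args: Array[String]): Unit = {
    start()                          // schedules task #1; its run blocks in doMaintenance
    runStarted.await()
    stop()                           // tears down without waiting for that run
    tornDown.countDown()
    staleChecked.await()
    start()                          // restart: schedules task #2
    restarted.countDown()            // the stale run resumes and cancels task #2
    scheduler.shutdown()
    scheduler.awaitTermination(5, TimeUnit.SECONDS)
    println(s"maintenance scheduled after restart: ${task != null}")  // false
  }
}
```

The `MaintenanceTask` added by this patch avoids the problem by giving each start its own executor and `ScheduledFuture`, so an old run's error handling can only cancel its own schedule, never work created by a newer start; a failed `doMaintenance` now throws instead of calling the global `stop()`.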

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>
Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #16627 from zsxwing/SPARK-19267.
parent 9b7a03f1
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.util.control.NonFatal
@@ -124,12 +125,46 @@ object StateStore extends Logging {
   val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
 
+  @GuardedBy("loadedProviders")
   private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
-  private val maintenanceTaskExecutor =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
 
-  @volatile private var maintenanceTask: ScheduledFuture[_] = null
-  @volatile private var _coordRef: StateStoreCoordinatorRef = null
+  /**
+   * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
+   * will be called when an exception happens.
+   */
+  class MaintenanceTask(periodMs: Long, task: => Unit, onError: => Unit) {
+    private val executor =
+      ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
+
+    private val runnable = new Runnable {
+      override def run(): Unit = {
+        try {
+          task
+        } catch {
+          case NonFatal(e) =>
+            logWarning("Error running maintenance thread", e)
+            onError
+            throw e
+        }
+      }
+    }
+
+    private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate(
+      runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+
+    def stop(): Unit = {
+      future.cancel(false)
+      executor.shutdown()
+    }
+
+    def isRunning: Boolean = !future.isDone
+  }
+
+  @GuardedBy("loadedProviders")
+  private var maintenanceTask: MaintenanceTask = null
+
+  @GuardedBy("loadedProviders")
+  private var _coordRef: StateStoreCoordinatorRef = null
 
   /** Get or create a store associated with the id. */
   def get(
@@ -162,7 +197,7 @@ object StateStore extends Logging {
   }
 
   def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
-    maintenanceTask != null
+    maintenanceTask != null && maintenanceTask.isRunning
   }
 
   /** Unload and stop all state store providers */
@@ -170,7 +205,7 @@
     loadedProviders.clear()
     _coordRef = null
     if (maintenanceTask != null) {
-      maintenanceTask.cancel(false)
+      maintenanceTask.stop()
       maintenanceTask = null
     }
     logInfo("StateStore stopped")
@@ -179,14 +214,14 @@
   /** Start the periodic maintenance task if not already started and if Spark active */
   private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
     val env = SparkEnv.get
-    if (maintenanceTask == null && env != null) {
+    if (env != null && !isMaintenanceRunning) {
       val periodMs = env.conf.getTimeAsMs(
         MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
-      val runnable = new Runnable {
-        override def run(): Unit = { doMaintenance() }
-      }
-      maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
-        runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+      maintenanceTask = new MaintenanceTask(
+        periodMs,
+        task = { doMaintenance() },
+        onError = { loadedProviders.synchronized { loadedProviders.clear() } }
+      )
       logInfo("State Store maintenance task started")
     }
   }
@@ -198,21 +233,20 @@
   private def doMaintenance(): Unit = {
     logDebug("Doing maintenance")
     if (SparkEnv.get == null) {
-      stop()
-    } else {
-      loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
-        try {
-          if (verifyIfStoreInstanceActive(id)) {
-            provider.doMaintenance()
-          } else {
-            unload(id)
-            logInfo(s"Unloaded $provider")
-          }
-        } catch {
-          case NonFatal(e) =>
-            logWarning(s"Error managing $provider, stopping management thread")
-            stop()
-        }
-      }
+      throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
+    }
+    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+      try {
+        if (verifyIfStoreInstanceActive(id)) {
+          provider.doMaintenance()
+        } else {
+          unload(id)
+          logInfo(s"Unloaded $provider")
+        }
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error managing $provider, stopping management thread")
+          throw e
       }
     }
   }
@@ -238,7 +272,7 @@
     }
   }
 
-  private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+  private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
     val env = SparkEnv.get
     if (env != null) {
       if (_coordRef == null) {