From ea31f92bb8554a901ff5b48986097a2642c64399 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Fri, 20 Jan 2017 17:49:26 -0800
Subject: [PATCH] [SPARK-19267][SS] Fix a race condition when stopping
 StateStore

## What changes were proposed in this pull request?

There is a race condition when stopping `StateStore` that makes `StateStoreSuite.maintenance` flaky: `StateStore.stop` does not wait for the currently running maintenance task to finish, so a stale task can fail inside `doMaintenance` and cancel the newly scheduled task. Here is a reproducer: https://github.com/zsxwing/spark/commit/dde1b5b106ba034861cf19e16883cfe181faa6f3
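
To make the timing window concrete, here is a minimal standalone sketch (not Spark code; the object name `CancelDoesNotWait` is made up for illustration) showing that `ScheduledFuture.cancel(false)` returns immediately while an in-flight run keeps executing, which is exactly the window in which a stale run can interfere with what the caller does next:

```scala
import java.util.concurrent.{Executors, TimeUnit}

object CancelDoesNotWait {
  def main(args: Array[String]): Unit = {
    val executor = Executors.newSingleThreadScheduledExecutor()
    val future = executor.scheduleAtFixedRate(new Runnable {
      override def run(): Unit = {
        Thread.sleep(500)                 // stands in for a slow maintenance run
        println("old task still running after cancel() returned")
      }
    }, 0, 100, TimeUnit.MILLISECONDS)

    Thread.sleep(100)                     // let the first run start
    future.cancel(false)                  // returns immediately, does not wait
    println("cancel() returned")
    executor.shutdown()
    executor.awaitTermination(2, TimeUnit.SECONDS)
  }
}
```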

This PR adds a `MaintenanceTask` class that owns the scheduling executor and eliminates the race condition: a failing maintenance run now clears `loadedProviders` via `onError` instead of calling `stop()`, so a stale task can no longer cancel a newly started one, and `isMaintenanceRunning` checks whether the scheduled future is actually still live.
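
For reference, a minimal standalone check (again not Spark code; `FailedTaskIsDone` is a made-up name) of the JDK behavior the new `isRunning` check relies on: once a run of a fixed-rate task throws, subsequent executions are suppressed and the returned `ScheduledFuture` reports `isDone == true`, so `startMaintenanceIfNeeded` can tell that a fresh `MaintenanceTask` is needed:

```scala
import java.util.concurrent.{Executors, TimeUnit}

object FailedTaskIsDone {
  def main(args: Array[String]): Unit = {
    val executor = Executors.newSingleThreadScheduledExecutor()
    val future = executor.scheduleAtFixedRate(new Runnable {
      // Simulate a maintenance run that fails; further executions are suppressed.
      override def run(): Unit = throw new RuntimeException("simulated maintenance failure")
    }, 0, 100, TimeUnit.MILLISECONDS)

    Thread.sleep(300)
    // The failed run completed the future exceptionally, so isDone is true;
    // this mirrors MaintenanceTask.isRunning (= !future.isDone) returning false.
    println(s"isRunning = ${!future.isDone}")  // prints: isRunning = false
    executor.shutdown()
  }
}
```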

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <shixiong@databricks.com>
Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #16627 from zsxwing/SPARK-19267.
---
 .../streaming/state/StateStore.scala          | 88 +++++++++++++------
 1 file changed, 61 insertions(+), 27 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index d59746f947..e61d95a1b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.util.control.NonFatal
@@ -124,12 +125,46 @@ object StateStore extends Logging {
   val MAINTENANCE_INTERVAL_CONFIG = "spark.sql.streaming.stateStore.maintenanceInterval"
   val MAINTENANCE_INTERVAL_DEFAULT_SECS = 60
 
+  @GuardedBy("loadedProviders")
   private val loadedProviders = new mutable.HashMap[StateStoreId, StateStoreProvider]()
-  private val maintenanceTaskExecutor =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
 
-  @volatile private var maintenanceTask: ScheduledFuture[_] = null
-  @volatile private var _coordRef: StateStoreCoordinatorRef = null
+  /**
+   * Runs the `task` periodically and automatically cancels it if there is an exception. `onError`
+   * will be called when an exception happens.
+   */
+  class MaintenanceTask(periodMs: Long, task: => Unit, onError: => Unit) {
+    private val executor =
+      ThreadUtils.newDaemonSingleThreadScheduledExecutor("state-store-maintenance-task")
+
+    private val runnable = new Runnable {
+      override def run(): Unit = {
+        try {
+          task
+        } catch {
+          case NonFatal(e) =>
+            logWarning("Error running maintenance thread", e)
+            onError
+            throw e
+        }
+      }
+    }
+
+    private val future: ScheduledFuture[_] = executor.scheduleAtFixedRate(
+      runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+
+    def stop(): Unit = {
+      future.cancel(false)
+      executor.shutdown()
+    }
+
+    def isRunning: Boolean = !future.isDone
+  }
+
+  @GuardedBy("loadedProviders")
+  private var maintenanceTask: MaintenanceTask = null
+
+  @GuardedBy("loadedProviders")
+  private var _coordRef: StateStoreCoordinatorRef = null
 
   /** Get or create a store associated with the id. */
   def get(
@@ -162,7 +197,7 @@ object StateStore extends Logging {
   }
 
   def isMaintenanceRunning: Boolean = loadedProviders.synchronized {
-    maintenanceTask != null
+    maintenanceTask != null && maintenanceTask.isRunning
   }
 
   /** Unload and stop all state store providers */
@@ -170,7 +205,7 @@ object StateStore extends Logging {
     loadedProviders.clear()
     _coordRef = null
     if (maintenanceTask != null) {
-      maintenanceTask.cancel(false)
+      maintenanceTask.stop()
       maintenanceTask = null
     }
     logInfo("StateStore stopped")
@@ -179,14 +214,14 @@ object StateStore extends Logging {
   /** Start the periodic maintenance task if not already started and if Spark active */
   private def startMaintenanceIfNeeded(): Unit = loadedProviders.synchronized {
     val env = SparkEnv.get
-    if (maintenanceTask == null && env != null) {
+    if (env != null && !isMaintenanceRunning) {
       val periodMs = env.conf.getTimeAsMs(
         MAINTENANCE_INTERVAL_CONFIG, s"${MAINTENANCE_INTERVAL_DEFAULT_SECS}s")
-      val runnable = new Runnable {
-        override def run(): Unit = { doMaintenance() }
-      }
-      maintenanceTask = maintenanceTaskExecutor.scheduleAtFixedRate(
-        runnable, periodMs, periodMs, TimeUnit.MILLISECONDS)
+      maintenanceTask = new MaintenanceTask(
+        periodMs,
+        task = { doMaintenance() },
+        onError = { loadedProviders.synchronized { loadedProviders.clear() } }
+      )
       logInfo("State Store maintenance task started")
     }
   }
@@ -198,21 +233,20 @@ object StateStore extends Logging {
   private def doMaintenance(): Unit = {
     logDebug("Doing maintenance")
     if (SparkEnv.get == null) {
-      stop()
-    } else {
-      loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
-        try {
-          if (verifyIfStoreInstanceActive(id)) {
-            provider.doMaintenance()
-          } else {
-            unload(id)
-            logInfo(s"Unloaded $provider")
-          }
-        } catch {
-          case NonFatal(e) =>
-            logWarning(s"Error managing $provider, stopping management thread")
-            stop()
+      throw new IllegalStateException("SparkEnv not active, cannot do maintenance on StateStores")
+    }
+    loadedProviders.synchronized { loadedProviders.toSeq }.foreach { case (id, provider) =>
+      try {
+        if (verifyIfStoreInstanceActive(id)) {
+          provider.doMaintenance()
+        } else {
+          unload(id)
+          logInfo(s"Unloaded $provider")
         }
+      } catch {
+        case NonFatal(e) =>
+          logWarning(s"Error managing $provider, stopping management thread")
+          throw e
       }
     }
   }
@@ -238,7 +272,7 @@ object StateStore extends Logging {
     }
   }
 
-  private def coordinatorRef: Option[StateStoreCoordinatorRef] = synchronized {
+  private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized {
     val env = SparkEnv.get
     if (env != null) {
       if (_coordRef == null) {
-- 
GitLab