From eb3ea3a0831b26d3dc35a97566716b92868a7beb Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Sun, 11 Jun 2017 09:54:57 +0100
Subject: [PATCH] [SPARK-20935][STREAMING] Always close WriteAheadLog and make
 it idempotent

## What changes were proposed in this pull request?

This PR proposes to always close the `WriteAheadLog` when `ReceiverTracker` is stopped, regardless of whether the tracker was started, and to make closing `WriteAheadLog` and its implementations idempotent.

## How was this patch tested?

Added a test in `WriteAheadLogSuite`. Note that the added test appears to pass when the log is closed twice even without the changes in `FileBasedWriteAheadLog` and `BatchedWriteAheadLog`. Both appear to be idempotent already, so this is more of a sanity check.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18224 from HyukjinKwon/streaming-closing.
---
 .../spark/streaming/util/WriteAheadLog.java   |  2 +-
 .../streaming/scheduler/ReceiverTracker.scala | 27 +++++++------------
 .../streaming/util/BatchedWriteAheadLog.scala | 13 +++++----
 .../util/FileBasedWriteAheadLog.scala         |  8 +++---
 .../scheduler/ReceiverTrackerSuite.scala      |  2 ++
 .../streaming/util/WriteAheadLogSuite.scala   |  2 ++
 6 files changed, 26 insertions(+), 28 deletions(-)

diff --git a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java
index 2803cad809..00c5972874 100644
--- a/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java
+++ b/streaming/src/main/java/org/apache/spark/streaming/util/WriteAheadLog.java
@@ -56,7 +56,7 @@ public abstract class WriteAheadLog {
   public abstract void clean(long threshTime, boolean waitForCompletion);
 
   /**
-   * Close this log and release any resources.
+   * Close this log and release any resources. It must be idempotent.
    */
   public abstract void close();
 }
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index bd7ab0b9bf..6f130c803f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -165,11 +165,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
 
   /** Stop the receiver execution thread. */
   def stop(graceful: Boolean): Unit = synchronized {
-    if (isTrackerStarted) {
-      // First, stop the receivers
-      trackerState = Stopping
+    val isStarted: Boolean = isTrackerStarted
+    trackerState = Stopping
+    if (isStarted) {
       if (!skipReceiverLaunch) {
-        // Send the stop signal to all the receivers
+        // First, stop the receivers. Send the stop signal to all the receivers
         endpoint.askSync[Boolean](StopAllReceivers)
 
         // Wait for the Spark job that runs the receivers to be over
@@ -194,17 +194,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
       // Finally, stop the endpoint
       ssc.env.rpcEnv.stop(endpoint)
       endpoint = null
-      receivedBlockTracker.stop()
-      logInfo("ReceiverTracker stopped")
-      trackerState = Stopped
-    } else if (isTrackerInitialized) {
-      trackerState = Stopping
-      // `ReceivedBlockTracker` is open when this instance is created. We should
-      // close this even if this `ReceiverTracker` is not started.
-      receivedBlockTracker.stop()
-      logInfo("ReceiverTracker stopped")
-      trackerState = Stopped
     }
+
+    // `ReceivedBlockTracker` is open when this instance is created. We should
+    // close this even if this `ReceiverTracker` is not started.
+    receivedBlockTracker.stop()
+    logInfo("ReceiverTracker stopped")
+    trackerState = Stopped
   }
 
   /** Allocate all unallocated blocks to the given batch. */
@@ -453,9 +449,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
     endpoint.send(StartAllReceivers(receivers))
   }
 
-  /** Check if tracker has been marked for initiated */
-  private def isTrackerInitialized: Boolean = trackerState == Initialized
-
   /** Check if tracker has been marked for starting */
   private def isTrackerStarted: Boolean = trackerState == Started
 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
index 35f0166ed0..e522bc62d5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -19,6 +19,7 @@ package org.apache.spark.streaming.util
 
 import java.nio.ByteBuffer
 import java.util.{Iterator => JIterator}
+import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.LinkedBlockingQueue
 
 import scala.collection.JavaConverters._
@@ -60,7 +61,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
   private val walWriteQueue = new LinkedBlockingQueue[Record]()
 
   // Whether the writer thread is active
-  @volatile private var active: Boolean = true
+  private val active: AtomicBoolean = new AtomicBoolean(true)
   private val buffer = new ArrayBuffer[Record]()
 
   private val batchedWriterThread = startBatchedWriterThread()
@@ -72,7 +73,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
   override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = {
     val promise = Promise[WriteAheadLogRecordHandle]()
     val putSuccessfully = synchronized {
-      if (active) {
+      if (active.get()) {
         walWriteQueue.offer(Record(byteBuffer, time, promise))
         true
       } else {
@@ -121,9 +122,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
    */
   override def close(): Unit = {
     logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.")
-    synchronized {
-      active = false
-    }
+    if (!active.getAndSet(false)) return
     batchedWriterThread.interrupt()
     batchedWriterThread.join()
     while (!walWriteQueue.isEmpty) {
@@ -138,7 +137,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
   private def startBatchedWriterThread(): Thread = {
     val thread = new Thread(new Runnable {
       override def run(): Unit = {
-        while (active) {
+        while (active.get()) {
           try {
             flushRecords()
           } catch {
@@ -166,7 +165,7 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
     }
     try {
       var segment: WriteAheadLogRecordHandle = null
-      if (buffer.length > 0) {
+      if (buffer.nonEmpty) {
         logDebug(s"Batched ${buffer.length} records for Write Ahead Log write")
         // threads may not be able to add items in order by time
         val sortedByTime = buffer.sortBy(_.time)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index 1e5f18797e..d6e15cfdd2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -205,10 +205,12 @@ private[streaming] class FileBasedWriteAheadLog(
 
   /** Stop the manager, close any open log writer */
   def close(): Unit = synchronized {
-    if (currentLogWriter != null) {
-      currentLogWriter.close()
+    if (!executionContext.isShutdown) {
+      if (currentLogWriter != null) {
+        currentLogWriter.close()
+      }
+      executionContext.shutdown()
     }
-    executionContext.shutdown()
     logInfo("Stopped write ahead log manager")
   }
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index df122ac090..c206d3169d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -57,6 +57,8 @@ class ReceiverTrackerSuite extends TestSuiteBase {
         }
       } finally {
         tracker.stop(false)
+        // Make sure it is idempotent.
+        tracker.stop(false)
       }
     }
   }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 4bec52b9fe..ede15399f0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -140,6 +140,8 @@ abstract class CommonWriteAheadLogTests(
       }
     }
     writeAheadLog.close()
+    // Make sure it is idempotent.
+    writeAheadLog.close()
   }
 
   test(testPrefix + "handling file errors while reading rotating logs") {
-- 
GitLab