Commit 2f98ee67 authored by Shixiong Zhu, committed by Marcelo Vanzin

[SPARK-14169][CORE] Add UninterruptibleThread

## What changes were proposed in this pull request?

Extract the workaround for HADOOP-10622 introduced by #11940 into UninterruptibleThread so that we can test and reuse it.
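For context, here is a minimal usage sketch of the new class (illustrative only; the thread name and the body of the block are made up and are not part of this patch):

```scala
import org.apache.spark.util.UninterruptibleThread

// Hypothetical example: wrap code that must not observe Thread.interrupt(),
// such as the HDFS calls affected by HADOOP-10622, in runUninterruptibly.
val worker = new UninterruptibleThread("example-worker") {
  override def run(): Unit = {
    runUninterruptibly {
      // An interrupt() arriving while this block runs is recorded, not delivered.
    }
    // If interrupt() was called during the block, the interrupted status is set here.
  }
}
worker.start()
worker.interrupt() // deferred if the thread is currently inside runUninterruptibly
worker.join()
```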

## How was this patch tested?

Unit tests

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #11971 from zsxwing/uninterrupt.
parent b7836492
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import javax.annotation.concurrent.GuardedBy
/**
* A special Thread that provides "runUninterruptibly" to allow running code without being
* interrupted by `Thread.interrupt()`. If `Thread.interrupt()` is called while
* "runUninterruptibly" is running, it won't set the interrupted status. Instead, setting the
* interrupted status is deferred until the thread returns from "runUninterruptibly".
*
* Note: "runUninterruptibly" should be called only from `this` thread.
*/
private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
/** A monitor to protect "uninterruptible" and "shouldInterruptThread" */
private val uninterruptibleLock = new Object
/**
* Indicates if `this` thread is in the uninterruptible status. If so, interrupting
* "this" will be deferred until `this` enters the interruptible status.
*/
@GuardedBy("uninterruptibleLock")
private var uninterruptible = false
/**
* Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
*/
@GuardedBy("uninterruptibleLock")
private var shouldInterruptThread = false
/**
* Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
* from `f`.
*
* If this method finds that `interrupt` was called before calling `f`, and the call is not nested
* inside another `runUninterruptibly`, it will throw `InterruptedException`.
*
* Note: this method should be called only from `this` thread.
*/
def runUninterruptibly[T](f: => T): T = {
if (Thread.currentThread() != this) {
throw new IllegalStateException(s"Call runUninterruptibly from the wrong thread. " +
s"Expected: $this but was ${Thread.currentThread()}")
}
if (uninterruptibleLock.synchronized { uninterruptible }) {
// We are already in the uninterruptible status. So just run "f" and return
return f
}
uninterruptibleLock.synchronized {
// Clear the interrupted status if it's set.
if (Thread.interrupted() || shouldInterruptThread) {
shouldInterruptThread = false
// Since it's interrupted, we don't need to run `f` which may be a long computation.
// Throw InterruptedException as we don't have a T to return.
throw new InterruptedException()
}
uninterruptible = true
}
try {
f
} finally {
uninterruptibleLock.synchronized {
uninterruptible = false
if (shouldInterruptThread) {
// Recover the interrupted status
super.interrupt()
shouldInterruptThread = false
}
}
}
}
/**
* Tests whether `interrupt()` has been called.
*/
override def isInterrupted: Boolean = {
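// A deferred interrupt recorded in "shouldInterruptThread" also counts as interrupted,
// even though "super.interrupt()" has not been called yet.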
super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread }
}
/**
* Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be
* interrupted until it enters the interruptible status.
*/
override def interrupt(): Unit = {
uninterruptibleLock.synchronized {
if (uninterruptible) {
shouldInterruptThread = true
} else {
super.interrupt()
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.util
import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.util.Random
import com.google.common.util.concurrent.Uninterruptibles
import org.apache.spark.SparkFunSuite
class UninterruptibleThreadSuite extends SparkFunSuite {
/** Sleep for `millis` milliseconds and return true if the sleep is interrupted */
private def sleep(millis: Long): Boolean = {
try {
Thread.sleep(millis)
false
} catch {
case _: InterruptedException =>
true
}
}
test("interrupt when runUninterruptibly is running") {
val enterRunUninterruptibly = new CountDownLatch(1)
@volatile var hasInterruptedException = false
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("test") {
override def run(): Unit = {
runUninterruptibly {
enterRunUninterruptibly.countDown()
hasInterruptedException = sleep(1000)
}
interruptStatusBeforeExit = Thread.interrupted()
}
}
t.start()
assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
t.interrupt()
t.join()
assert(hasInterruptedException === false)
assert(interruptStatusBeforeExit === true)
}
test("interrupt before runUninterruptibly runs") {
val interruptLatch = new CountDownLatch(1)
@volatile var hasInterruptedException = false
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("test") {
override def run(): Unit = {
Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
try {
runUninterruptibly {
assert(false, "Should not reach here")
}
} catch {
case _: InterruptedException => hasInterruptedException = true
}
interruptStatusBeforeExit = Thread.interrupted()
}
}
t.start()
t.interrupt()
interruptLatch.countDown()
t.join()
assert(hasInterruptedException === true)
assert(interruptStatusBeforeExit === false)
}
test("nested runUninterruptibly") {
val enterRunUninterruptibly = new CountDownLatch(1)
val interruptLatch = new CountDownLatch(1)
@volatile var hasInterruptedException = false
@volatile var interruptStatusBeforeExit = false
val t = new UninterruptibleThread("test") {
override def run(): Unit = {
runUninterruptibly {
enterRunUninterruptibly.countDown()
Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
hasInterruptedException = sleep(1)
runUninterruptibly {
if (sleep(1)) {
hasInterruptedException = true
}
}
if (sleep(1)) {
hasInterruptedException = true
}
}
interruptStatusBeforeExit = Thread.interrupted()
}
}
t.start()
assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
t.interrupt()
interruptLatch.countDown()
t.join()
assert(hasInterruptedException === false)
assert(interruptStatusBeforeExit === true)
}
test("stress test") {
@volatile var hasInterruptedException = false
val t = new UninterruptibleThread("test") {
override def run(): Unit = {
for (i <- 0 until 100) {
try {
runUninterruptibly {
if (sleep(Random.nextInt(10))) {
hasInterruptedException = true
}
runUninterruptibly {
if (sleep(Random.nextInt(10))) {
hasInterruptedException = true
}
}
if (sleep(Random.nextInt(10))) {
hasInterruptedException = true
}
}
Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS)
// 50% chance to clear the interrupted status
if (Random.nextBoolean()) {
Thread.interrupted()
}
} catch {
case _: InterruptedException =>
// The first runUninterruptibly may throw InterruptedException if the interrupt status
// is set before running `f`.
}
}
}
}
t.start()
for (i <- 0 until 400) {
Thread.sleep(Random.nextInt(10))
t.interrupt()
}
t.join()
assert(hasInterruptedException === false)
}
}
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
-import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
@@ -34,6 +33,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.ContinuousQueryListener
 import org.apache.spark.sql.util.ContinuousQueryListener._
+import org.apache.spark.util.UninterruptibleThread
 
 /**
  * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@@ -89,9 +89,10 @@ class StreamExecution(
   private[sql] var streamDeathCause: ContinuousQueryException = null
 
   /** The thread that runs the micro-batches of this stream. */
-  private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") {
-    override def run(): Unit = { runBatches() }
-  }
+  private[sql] val microBatchThread =
+    new UninterruptibleThread(s"stream execution thread for $name") {
+      override def run(): Unit = { runBatches() }
+    }
 
   /**
    * A write-ahead-log that records the offsets that are present in each batch. In order to ensure
@@ -102,65 +103,6 @@ class StreamExecution(
   private val offsetLog =
     new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets"))
 
-  /** A monitor to protect "uninterruptible" and "interrupted" */
-  private val uninterruptibleLock = new Object
-
-  /**
-   * Indicates if "microBatchThread" are in the uninterruptible status. If so, interrupting
-   * "microBatchThread" will be deferred until "microBatchThread" enters into the interruptible
-   * status.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var uninterruptible = false
-
-  /**
-   * Indicates if we should interrupt "microBatchThread" when we are leaving the uninterruptible
-   * zone.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var shouldInterruptThread = false
-
-  /**
-   * Interrupt "microBatchThread" if possible. If "microBatchThread" is in the uninterruptible
-   * status, "microBatchThread" won't be interrupted until it enters into the interruptible status.
-   */
-  private def interruptMicroBatchThreadSafely(): Unit = {
-    uninterruptibleLock.synchronized {
-      if (uninterruptible) {
-        shouldInterruptThread = true
-      } else {
-        microBatchThread.interrupt()
-      }
-    }
-  }
-
-  /**
-   * Run `f` uninterruptibly in "microBatchThread". "microBatchThread" won't be interrupted before
-   * returning from `f`.
-   */
-  private def runUninterruptiblyInMicroBatchThread[T](f: => T): T = {
-    assert(Thread.currentThread() == microBatchThread)
-    uninterruptibleLock.synchronized {
-      uninterruptible = true
-      // Clear the interrupted status if it's set.
-      if (Thread.interrupted()) {
-        shouldInterruptThread = true
-      }
-    }
-    try {
-      f
-    } finally {
-      uninterruptibleLock.synchronized {
-        uninterruptible = false
-        if (shouldInterruptThread) {
-          // Recover the interrupted status
-          microBatchThread.interrupt()
-          shouldInterruptThread = false
-        }
-      }
-    }
-  }
-
   /** Whether the query is currently active or not */
   override def isActive: Boolean = state == ACTIVE
@@ -294,7 +236,7 @@ class StreamExecution(
     // method. See SPARK-14131.
     //
     // Check to see what new data is available.
-    val newData = runUninterruptiblyInMicroBatchThread {
+    val newData = microBatchThread.runUninterruptibly {
      uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
     }
     availableOffsets ++= newData
@@ -305,7 +247,7 @@ class StreamExecution(
     // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set
     // the file permission, we should not interrupt "microBatchThread" when running this method.
     // See SPARK-14131.
-    runUninterruptiblyInMicroBatchThread {
+    microBatchThread.runUninterruptibly {
      assert(
        offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)),
        s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
@@ -395,7 +337,7 @@ class StreamExecution(
     // intentionally
     state = TERMINATED
     if (microBatchThread.isAlive) {
-      interruptMicroBatchThreadSafely()
+      microBatchThread.interrupt()
      microBatchThread.join()
     }
     logInfo(s"Query $name was stopped")