Commit 7850e0c7 authored by Josh Rosen, committed by Reynold Xin

[SPARK-4393] Fix memory leak in ConnectionManager ACK timeout TimerTasks; use HashedWheelTimer

This patch is intended to fix a subtle memory leak in ConnectionManager's ACK timeout TimerTasks: in the old code, each TimerTask held a reference to the message being sent, and a cancelled TimerTask isn't necessarily garbage-collected until its scheduled run time. As a result, messages that had already been acknowledged built up and weren't garbage collected until their timeouts expired, leading to OOMs.
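
As a minimal sketch of the leaky pattern (illustrative only, not the actual Spark code; Message, handleTimeout and scheduleAckTimeout are hypothetical stand-ins):

    import java.util.{Timer, TimerTask}

    val timer = new Timer("AckTimeoutMonitor", true)

    def scheduleAckTimeout(message: Message): TimerTask = {
      val task = new TimerTask {
        override def run(): Unit = {
          // The closure captures `message` itself; even after task.cancel(),
          // java.util.Timer keeps the task (and the captured message) reachable
          // until its scheduled fire time, so large messages pile up.
          handleTimeout(message.id)
        }
      }
      timer.schedule(task, 60 * 1000L)
      task
    }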

This patch addresses this problem by capturing only the message ID in the TimerTask instead of the whole message, and by keeping a WeakReference to the promise in the TimerTask.  I've also modified this code to use Netty's HashedWheelTimer, whose performance characteristics should be better for this use-case.
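
A minimal sketch of the fixed pattern, assuming a Message type with an id field (illustrative only; scheduleAckTimeout is a hypothetical helper, not the exact code in the patch below):

    import java.lang.ref.WeakReference
    import java.util.concurrent.TimeUnit
    import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}
    import scala.concurrent.Promise

    val timer = new HashedWheelTimer()

    def scheduleAckTimeout(message: Message, promise: Promise[Message]): Timeout = {
      val messageId = message.id                   // capture only the id, not the message
      val promiseRef = new WeakReference(promise)  // avoid pinning the promise in the timer
      val task: TimerTask = new TimerTask {
        override def run(timeout: Timeout): Unit = {
          val p = promiseRef.get
          if (p != null) {
            p.tryFailure(new java.io.IOException(s"no ack received for message $messageId"))
          }
        }
      }
      // newTimeout returns a Timeout handle; cancelling it releases the task promptly
      timer.newTimeout(task, 60, TimeUnit.SECONDS)
    }

The key point is that the timer task pins only an Int message id and a weak reference, so a completed send can be garbage-collected immediately even if its timeout hasn't fired yet.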

Thanks to cristianopris for narrowing down this issue!

Author: Josh Rosen <joshrosen@databricks.com>

Closes #3259 from JoshRosen/connection-manager-timeout-bugfix and squashes the following commits:

afcc8d6 [Josh Rosen] Address rxin's review feedback.
2a2e92d [Josh Rosen] Keep only WeakReference to promise in TimerTask;
0f0913b [Josh Rosen] Spelling fix: timout => timeout
3200c33 [Josh Rosen] Use Netty HashedWheelTimer
f847dd4 [Josh Rosen] Don't capture entire message in ACK timeout task.
parent 84468b2e

@@ -18,13 +18,13 @@
 package org.apache.spark.network.nio
 
 import java.io.IOException
+import java.lang.ref.WeakReference
 import java.net._
 import java.nio._
 import java.nio.channels._
 import java.nio.channels.spi._
 import java.util.concurrent.atomic.AtomicInteger
 import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
-import java.util.{Timer, TimerTask}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
 import scala.concurrent.duration._
@@ -32,6 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
 import scala.language.postfixOps
 
 import com.google.common.base.Charsets.UTF_8
+import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}
 
 import org.apache.spark._
 import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
@@ -77,7 +78,8 @@ private[nio] class ConnectionManager(
   }
 
   private val selector = SelectorProvider.provider.openSelector()
 
-  private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
+  private val ackTimeoutMonitor =
+    new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))
 
   private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
@@ -139,7 +141,10 @@
     new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
   private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
     with SynchronizedMap[ConnectionManagerId, SendingConnection]
-  private val messageStatuses = new HashMap[Int, MessageStatus]
+  // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
+  // map when messages are sent and are removed when acknowledgement messages are received or when
+  // acknowledgement timeouts expire
+  private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
   private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
   private val registerRequests = new SynchronizedQueue[SendingConnection]
@@ -899,22 +904,41 @@
       : Future[Message] = {
     val promise = Promise[Message]()
 
-    val timeoutTask = new TimerTask {
-      override def run(): Unit = {
+    // It's important that the TimerTask doesn't capture a reference to `message`, which can cause
+    // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
+    // at which they would originally be scheduled to run. Therefore, extract the message id
+    // from outside of the TimerTask closure (see SPARK-4393 for more context).
+    val messageId = message.id
+    // Keep a weak reference to the promise so that the completed promise may be garbage-collected
+    val promiseReference = new WeakReference(promise)
+    val timeoutTask: TimerTask = new TimerTask {
+      override def run(timeout: Timeout): Unit = {
         messageStatuses.synchronized {
-          messageStatuses.remove(message.id).foreach ( s => {
+          messageStatuses.remove(messageId).foreach { s =>
             val e = new IOException("sendMessageReliably failed because ack " +
               s"was not received within $ackTimeout sec")
-            if (!promise.tryFailure(e)) {
-              logWarning("Ignore error because promise is completed", e)
+            val p = promiseReference.get
+            if (p != null) {
+              // Attempt to fail the promise with a Timeout exception
+              if (!p.tryFailure(e)) {
+                // If we reach here, then someone else has already signalled success or failure
+                // on this promise, so log a warning:
+                logError("Ignore error because promise is completed", e)
+              }
+            } else {
+              // The WeakReference was empty, which should never happen because
+              // sendMessageReliably's caller should have a strong reference to promise.future;
+              logError("Promise was garbage collected; this should never happen!", e)
             }
-          })
+          }
         }
       }
     }
+    val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)
 
     val status = new MessageStatus(message, connectionManagerId, s => {
-      timeoutTask.cancel()
+      timeoutTaskHandle.cancel()
       s match {
         case scala.util.Failure(e) =>
           // Indicates a failure where we either never sent or never got ACK'd
@@ -943,7 +967,6 @@
       messageStatuses += ((message.id, status))
     }
-    ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
     sendMessage(connectionManagerId, message)
     promise.future
   }
@@ -953,7 +976,7 @@
   }
 
   def stop() {
-    ackTimeoutMonitor.cancel()
+    ackTimeoutMonitor.stop()
     selectorThread.interrupt()
     selectorThread.join()
     selector.close()