Commit 7215aa74 authored by Josh Rosen, committed by Andrew Or

[SPARK-6209] Clean up connections in ExecutorClassLoader after failing to load classes (master branch PR)

ExecutorClassLoader does not ensure proper cleanup of the network connections that it opens. If it fails to load a class, it can leak partially-consumed InputStreams that are connected to the REPL's HTTP class server, exhausting that server's thread pool and eventually hanging the entire job. See [SPARK-6209](https://issues.apache.org/jira/browse/SPARK-6209) for more details, including a bug reproduction.

This patch fixes the issue by ensuring that these streams and connections are cleaned up whether or not class loading succeeds. It also adds logging for unexpected error cases.

This PR is an extended version of #4935 and adds a regression test.
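
The fix follows the standard acquire-in-try, release-in-finally pattern. As a rough illustration only (the object and method names below are invented for this sketch and do not appear in the patch):

```scala
import java.io.InputStream

// Sketch: even if `read` throws (e.g. a missing or truncated class file),
// the stream is closed, so the class server's worker thread is released
// instead of waiting forever on an abandoned, partially-consumed response.
object StreamCleanupSketch {
  def readClassBytes(open: () => InputStream)(read: InputStream => Array[Byte]): Array[Byte] = {
    var in: InputStream = null
    try {
      in = open()
      read(in)
    } finally {
      if (in != null) {
        try {
          in.close()
        } catch {
          case e: Exception => println(s"Exception while closing stream: $e")
        }
      }
    }
  }
}
```

The diff below applies the same shape to `findClassLocally`, guarding the `close()` so that a failure to close is logged rather than masking the original exception.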

Author: Josh Rosen <joshrosen@databricks.com>

Closes #4944 from JoshRosen/executorclassloader-leak-master-branch and squashes the following commits:

e0e3c25 [Josh Rosen] Wrap try block around getResponseCode; re-enable keep-alive by closing error stream
961c284 [Josh Rosen] Roll back changes that were added to get the regression test to fail
7ee2261 [Josh Rosen] Add a failing regression test
e2d70a3 [Josh Rosen] Properly clean up after errors in ExecutorClassLoader
parent a8f51b82
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -84,6 +84,11 @@
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
     <!-- Explicit listing of transitive deps that are shaded. Otherwise, odd compiler crashes. -->
     <dependency>
--- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala
@@ -17,9 +17,10 @@
 package org.apache.spark.repl
 
-import java.io.{ByteArrayOutputStream, InputStream, FileNotFoundException}
-import java.net.{URI, URL, URLEncoder}
-import java.util.concurrent.{Executors, ExecutorService}
+import java.io.{IOException, ByteArrayOutputStream, InputStream}
+import java.net.{HttpURLConnection, URI, URL, URLEncoder}
+
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -43,6 +44,9 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
   val parentLoader = new ParentClassLoader(parent)
 
+  // Allows HTTP connect and read timeouts to be controlled for testing / debugging purposes
+  private[repl] var httpUrlConnectionTimeoutMillis: Int = -1
+
   // Hadoop FileSystem object for our URI, if it isn't using HTTP
   var fileSystem: FileSystem = {
     if (Set("http", "https", "ftp").contains(uri.getScheme)) {
@@ -71,30 +75,66 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
     }
   }
 
+  private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = {
+    val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
+      val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
+      val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
+      newuri.toURL
+    } else {
+      new URL(classUri + "/" + urlEncode(pathInDirectory))
+    }
+    val connection: HttpURLConnection = Utils.setupSecureURLConnection(url.openConnection(),
+      SparkEnv.get.securityManager).asInstanceOf[HttpURLConnection]
+    // Set the connection timeouts (for testing purposes)
+    if (httpUrlConnectionTimeoutMillis != -1) {
+      connection.setConnectTimeout(httpUrlConnectionTimeoutMillis)
+      connection.setReadTimeout(httpUrlConnectionTimeoutMillis)
+    }
+    connection.connect()
+    try {
+      if (connection.getResponseCode != 200) {
+        // Close the error stream so that the connection is eligible for re-use
+        try {
+          connection.getErrorStream.close()
+        } catch {
+          case ioe: IOException =>
+            logError("Exception while closing error stream", ioe)
+        }
+        throw new ClassNotFoundException(s"Class file not found at URL $url")
+      } else {
+        connection.getInputStream
+      }
+    } catch {
+      case NonFatal(e) if !e.isInstanceOf[ClassNotFoundException] =>
+        connection.disconnect()
+        throw e
+    }
+  }
+
+  private def getClassFileInputStreamFromFileSystem(pathInDirectory: String): InputStream = {
+    val path = new Path(directory, pathInDirectory)
+    if (fileSystem.exists(path)) {
+      fileSystem.open(path)
+    } else {
+      throw new ClassNotFoundException(s"Class file not found at path $path")
+    }
+  }
+
   def findClassLocally(name: String): Option[Class[_]] = {
+    val pathInDirectory = name.replace('.', '/') + ".class"
+    var inputStream: InputStream = null
     try {
-      val pathInDirectory = name.replace('.', '/') + ".class"
-      val inputStream = {
+      inputStream = {
         if (fileSystem != null) {
-          fileSystem.open(new Path(directory, pathInDirectory))
+          getClassFileInputStreamFromFileSystem(pathInDirectory)
         } else {
-          val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) {
-            val uri = new URI(classUri + "/" + urlEncode(pathInDirectory))
-            val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager)
-            newuri.toURL
-          } else {
-            new URL(classUri + "/" + urlEncode(pathInDirectory))
-          }
-          Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager)
-            .getInputStream
+          getClassFileInputStreamFromHttpServer(pathInDirectory)
         }
       }
       val bytes = readAndTransformClass(name, inputStream)
-      inputStream.close()
       Some(defineClass(name, bytes, 0, bytes.length))
     } catch {
-      case e: FileNotFoundException =>
+      case e: ClassNotFoundException =>
         // We did not find the class
         logDebug(s"Did not load class $name from REPL class server at $uri", e)
         None
@@ -102,6 +142,15 @@ class ExecutorClassLoader(conf: SparkConf, classUri: String, parent: ClassLoader
         // Something bad happened while checking if the class exists
         logError(s"Failed to check existence of class $name on REPL class server at $uri", e)
         None
+    } finally {
+      if (inputStream != null) {
+        try {
+          inputStream.close()
+        } catch {
+          case e: Exception =>
+            logError("Exception while closing inputStream", e)
+        }
+      }
     }
   }
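
A note on the non-200 branch in `getClassFileInputStreamFromHttpServer`: with HTTP keep-alive, an `HttpURLConnection` can only return its socket to the JVM's internal connection pool once the response (or error) body has been closed or fully drained, whereas `disconnect()` discards the socket entirely. A self-contained sketch of that policy, with an invented `openOrThrow` helper standing in for the real method:

```scala
import java.io.{IOException, InputStream}
import java.net.{HttpURLConnection, URL}

// Sketch of the connection-handling policy in the patch:
//  * non-200: close the error stream so the keep-alive socket can be pooled,
//    then report "not found" to the caller;
//  * unexpected failure: disconnect() so a half-read socket is never reused.
object HttpCleanupSketch {
  def openOrThrow(url: URL): InputStream = {
    val connection = url.openConnection().asInstanceOf[HttpURLConnection]
    connection.connect()
    try {
      if (connection.getResponseCode != 200) {
        try {
          connection.getErrorStream.close() // keeps the connection reusable
        } catch {
          case ioe: IOException => println(s"Exception while closing error stream: $ioe")
        }
        throw new ClassNotFoundException(s"Class file not found at URL $url")
      } else {
        connection.getInputStream
      }
    } catch {
      case e: Exception if !e.isInstanceOf[ClassNotFoundException] =>
        connection.disconnect() // drop the socket rather than pool it
        throw e
    }
  }
}
```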
--- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala
@@ -20,13 +20,25 @@ package org.apache.spark.repl
 import java.io.File
 import java.net.{URL, URLClassLoader}
 
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
+import org.scalatest.concurrent.Interruptor
+import org.scalatest.concurrent.Timeouts._
+import org.scalatest.mock.MockitoSugar
+import org.mockito.Mockito._
 
-import org.apache.spark.{SparkConf, TestUtils}
+import org.apache.spark._
 import org.apache.spark.util.Utils
 
-class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
+class ExecutorClassLoaderSuite
+  extends FunSuite
+  with BeforeAndAfterAll
+  with MockitoSugar
+  with Logging {
 
   val childClassNames = List("ReplFakeClass1", "ReplFakeClass2")
   val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3")
@@ -34,6 +46,7 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
   var tempDir2: File = _
   var url1: String = _
   var urls2: Array[URL] = _
+  var classServer: HttpServer = _
 
   override def beforeAll() {
     super.beforeAll()
@@ -47,8 +60,12 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
 
   override def afterAll() {
     super.afterAll()
+    if (classServer != null) {
+      classServer.stop()
+    }
     Utils.deleteRecursively(tempDir1)
     Utils.deleteRecursively(tempDir2)
+    SparkEnv.set(null)
   }
 
   test("child first") {
@@ -83,4 +100,53 @@ class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("failing to fetch classes from HTTP server should not leak resources (SPARK-6209)") {
+    // This is a regression test for SPARK-6209, a bug where each failed attempt to load a class
+    // from the driver's class server would leak a HTTP connection, causing the class server's
+    // thread / connection pool to be exhausted.
+    val conf = new SparkConf()
+    val securityManager = new SecurityManager(conf)
+    classServer = new HttpServer(conf, tempDir1, securityManager)
+    classServer.start()
+    // ExecutorClassLoader uses SparkEnv's SecurityManager, so we need to mock this
+    val mockEnv = mock[SparkEnv]
+    when(mockEnv.securityManager).thenReturn(securityManager)
+    SparkEnv.set(mockEnv)
+    // Create an ExecutorClassLoader that's configured to load classes from the HTTP server
+    val parentLoader = new URLClassLoader(Array.empty, null)
+    val classLoader = new ExecutorClassLoader(conf, classServer.uri, parentLoader, false)
+    classLoader.httpUrlConnectionTimeoutMillis = 500
+    // Check that this class loader can actually load classes that exist
+    val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance()
+    val fakeClassVersion = fakeClass.toString
+    assert(fakeClassVersion === "1")
+    // Try to perform a full GC now, since GC during the test might mask resource leaks
+    System.gc()
+    // When the original bug occurs, the test thread becomes blocked in a classloading call
+    // and does not respond to interrupts. Therefore, use a custom ScalaTest interruptor to
+    // shut down the HTTP server when the test times out
+    val interruptor: Interruptor = new Interruptor {
+      override def apply(thread: Thread): Unit = {
+        classServer.stop()
+        classServer = null
+        thread.interrupt()
+      }
+    }
+    def tryAndFailToLoadABunchOfClasses(): Unit = {
+      // The number of trials here should be much larger than Jetty's thread / connection limit
+      // in order to expose thread or connection leaks
+      for (i <- 1 to 1000) {
+        if (Thread.currentThread().isInterrupted) {
+          throw new InterruptedException()
+        }
+        // Incorporate the iteration number into the class name in order to avoid any response
+        // caching that might be added in the future
+        intercept[ClassNotFoundException] {
+          classLoader.loadClass(s"ReplFakeClassDoesNotExist$i").newInstance()
+        }
+      }
+    }
+    failAfter(10 seconds)(tryAndFailToLoadABunchOfClasses())(interruptor)
+  }
 }
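
A design note on the regression test: ScalaTest's default Interruptor only calls `Thread.interrupt()`, which does not unblock a thread stuck reading from a socket, so the test supplies a custom Interruptor that stops the class server first. A minimal sketch of the same wiring under invented names (the sleep merely stands in for a blocked classloading call):

```scala
import scala.language.postfixOps

import org.scalatest.FunSuite
import org.scalatest.concurrent.Interruptor
import org.scalatest.concurrent.Timeouts._
import org.scalatest.time.SpanSugar._

// Sketch: first release the external resource the test thread is blocked on,
// then interrupt the thread itself, so failAfter can unwind the test.
class InterruptorSketchSuite extends FunSuite {
  test("a stuck operation is broken by the custom interruptor") {
    val releaseThenInterrupt: Interruptor = new Interruptor {
      override def apply(testThread: Thread): Unit = {
        // hypothetically: stuckServer.stop()
        testThread.interrupt()
      }
    }
    // failAfter throws a timeout exception once the interruptor fires
    intercept[Exception] {
      failAfter(1 second) {
        Thread.sleep(60000) // stand-in for a blocked classloading call
      }(releaseThenInterrupt)
    }
  }
}
```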