diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index c3ccdb012fb1d74f9c1b4d1c7baf36c82440531e..5cdc4eeeccbc61dca0d8efbb37de38851893430f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -27,6 +27,7 @@ import java.util.Arrays import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} import javax.net.ssl._ +import javax.servlet.http.HttpServletResponse import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import scala.collection.JavaConverters._ @@ -186,12 +187,12 @@ private[spark] object TestUtils { } /** - * Returns the response code from an HTTP(S) URL. + * Returns the response code and url (if redirected) from an HTTP(S) URL. */ - def httpResponseCode( + def httpResponseCodeAndURL( url: URL, method: String = "GET", - headers: Seq[(String, String)] = Nil): Int = { + headers: Seq[(String, String)] = Nil): (Int, Option[String]) = { val connection = url.openConnection().asInstanceOf[HttpURLConnection] connection.setRequestMethod(method) headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } @@ -210,16 +211,30 @@ private[spark] object TestUtils { sslCtx.init(null, Array(trustManager), new SecureRandom()) connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory()) connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier) + connection.setInstanceFollowRedirects(false) } try { connection.connect() - connection.getResponseCode() + if (connection.getResponseCode == HttpServletResponse.SC_FOUND) { + (connection.getResponseCode, Option(connection.getHeaderField("Location"))) + } else { + (connection.getResponseCode(), None) + } } finally { connection.disconnect() } } + /** + * Returns the response code from an HTTP(S) URL. + */ + def httpResponseCode( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): Int = { + httpResponseCodeAndURL(url, method, headers)._1 + } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index fbe8012ea2dae458b43cfef7f2e3418f2a057614..639b8577617f6e05d3a4f9402a4dcc2be074e34a 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -330,7 +330,7 @@ private[spark] object JettyUtils extends Logging { // redirect the HTTP requests to HTTPS port httpConnector.setName(REDIRECT_CONNECTOR_NAME) - collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) + collection.addHandler(createRedirectHttpsHandler(connector, scheme)) Some(connector) case None => @@ -378,7 +378,9 @@ private[spark] object JettyUtils extends Logging { server.getHandler().asInstanceOf[ContextHandlerCollection]) } - private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { + private def createRedirectHttpsHandler( + httpsConnector: ServerConnector, + scheme: String): ContextHandler = { val redirectHandler: ContextHandler = new ContextHandler redirectHandler.setContextPath("/") redirectHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME)) @@ -391,8 +393,8 @@ private[spark] object JettyUtils extends Logging { if (baseRequest.isSecure) { return } - val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, - baseRequest.getRequestURI, baseRequest.getQueryString) + val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, + httpsConnector.getLocalPort, baseRequest.getRequestURI, baseRequest.getQueryString) response.setContentLength(0) response.encodeRedirectURL(httpsURI) response.sendRedirect(httpsURI) diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 7c3d891047dec82a77cf898d5b5d40c398498c4a..16fb4666f36211a2cf1be15cc90bde1ef690d924 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -267,8 +267,12 @@ class UISuite extends SparkFunSuite { s"$scheme://localhost:$port/test1/root", s"$scheme://localhost:$port/test2/root") urls.foreach { url => - val rc = TestUtils.httpResponseCode(new URL(url)) + val (rc, redirectUrl) = TestUtils.httpResponseCodeAndURL(new URL(url)) assert(rc === expected, s"Unexpected status $rc for $url") + if (rc == HttpServletResponse.SC_FOUND) { + assert( + TestUtils.httpResponseCode(new URL(redirectUrl.get)) === HttpServletResponse.SC_OK) + } } } } finally {