Skip to content
Snippets Groups Projects
Commit 1d5663e9 authored by Andrew Or's avatar Andrew Or
Browse files

[SPARK-5760][SPARK-5761] Fix standalone rest protocol corner cases + revamp tests

The changes are summarized in the commit message. Test or test-related code accounts for 90% of the lines changed.

Author: Andrew Or <andrew@databricks.com>

Closes #4557 from andrewor14/rest-tests and squashes the following commits:

b4dc980 [Andrew Or] Merge branch 'master' of github.com:apache/spark into rest-tests
b55e40f [Andrew Or] Add test for unknown fields
cc96993 [Andrew Or] private[spark] -> private[rest]
578cf45 [Andrew Or] Clean up test code a little
d82d971 [Andrew Or] v1 -> serverVersion
ea48f65 [Andrew Or] Merge branch 'master' of github.com:apache/spark into rest-tests
00999a8 [Andrew Or] Revamp tests + fix a few corner cases
parent 47c73d41
No related branches found
No related tags found
No related merge requests found
......@@ -19,10 +19,11 @@ package org.apache.spark.deploy.rest
import java.io.{DataOutputStream, FileNotFoundException}
import java.net.{HttpURLConnection, SocketException, URL}
import javax.servlet.http.HttpServletResponse
import scala.io.Source
import com.fasterxml.jackson.databind.JsonMappingException
import com.fasterxml.jackson.core.JsonProcessingException
import com.google.common.base.Charsets
import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
......@@ -155,10 +156,21 @@ private[spark] class StandaloneRestClient extends Logging {
/**
* Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]].
* If the response represents an error, report the embedded message to the user.
* Exposed for testing.
*/
private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
try {
val responseJson = Source.fromInputStream(connection.getInputStream).mkString
val dataStream =
if (connection.getResponseCode == HttpServletResponse.SC_OK) {
connection.getInputStream
} else {
connection.getErrorStream
}
// If the server threw an exception while writing a response, it will not have a body
if (dataStream == null) {
throw new SubmitRestProtocolException("Server returned empty body")
}
val responseJson = Source.fromInputStream(dataStream).mkString
logDebug(s"Response from the server:\n$responseJson")
val response = SubmitRestProtocolMessage.fromJson(responseJson)
response.validate()
......@@ -177,7 +189,7 @@ private[spark] class StandaloneRestClient extends Logging {
case unreachable @ (_: FileNotFoundException | _: SocketException) =>
throw new SubmitRestConnectionException(
s"Unable to connect to server ${connection.getURL}", unreachable)
case malformed @ (_: SubmitRestProtocolException | _: JsonMappingException) =>
case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
throw new SubmitRestProtocolException(
"Malformed response received from server", malformed)
}
......@@ -284,7 +296,27 @@ private[spark] object StandaloneRestClient {
val REPORT_DRIVER_STATUS_MAX_TRIES = 10
val PROTOCOL_VERSION = "v1"
/** Submit an application, assuming Spark parameters are specified through system properties. */
/**
* Submit an application, assuming Spark parameters are specified through the given config.
* This is abstracted to its own method for testing purposes.
*/
private[rest] def run(
appResource: String,
mainClass: String,
appArgs: Array[String],
conf: SparkConf,
env: Map[String, String] = sys.env): SubmitRestProtocolResponse = {
val master = conf.getOption("spark.master").getOrElse {
throw new IllegalArgumentException("'spark.master' must be set.")
}
val sparkProperties = conf.getAll.toMap
val environmentVariables = env.filter { case (k, _) => k.startsWith("SPARK_") }
val client = new StandaloneRestClient
val submitRequest = client.constructSubmitRequest(
appResource, mainClass, appArgs, sparkProperties, environmentVariables)
client.createSubmission(master, submitRequest)
}
def main(args: Array[String]): Unit = {
if (args.size < 2) {
sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]")
......@@ -294,14 +326,6 @@ private[spark] object StandaloneRestClient {
val mainClass = args(1)
val appArgs = args.slice(2, args.size)
val conf = new SparkConf
val master = conf.getOption("spark.master").getOrElse {
throw new IllegalArgumentException("'spark.master' must be set.")
}
val sparkProperties = conf.getAll.toMap
val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") }
val client = new StandaloneRestClient
val submitRequest = client.constructSubmitRequest(
appResource, mainClass, appArgs, sparkProperties, environmentVariables)
client.createSubmission(master, submitRequest)
run(appResource, mainClass, appArgs, conf)
}
}
......@@ -17,15 +17,14 @@
package org.apache.spark.deploy.rest
import java.io.{DataOutputStream, File}
import java.io.File
import java.net.InetSocketAddress
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
import scala.io.Source
import akka.actor.ActorRef
import com.fasterxml.jackson.databind.JsonMappingException
import com.google.common.base.Charsets
import com.fasterxml.jackson.core.JsonProcessingException
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
import org.eclipse.jetty.util.thread.QueuedThreadPool
......@@ -70,14 +69,14 @@ private[spark] class StandaloneRestServer(
import StandaloneRestServer._
private var _server: Option[Server] = None
private val baseContext = s"/$PROTOCOL_VERSION/submissions"
// A mapping from servlets to the URL prefixes they are responsible for
private val servletToContext = Map[StandaloneRestServlet, String](
new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*",
new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*",
new ErrorServlet -> "/*" // default handler
// A mapping from URL prefixes to servlets that serve them. Exposed for testing.
protected val baseContext = s"/$PROTOCOL_VERSION/submissions"
protected val contextToServlet = Map[String, StandaloneRestServlet](
s"$baseContext/create/*" -> new SubmitRequestServlet(masterActor, masterUrl, masterConf),
s"$baseContext/kill/*" -> new KillRequestServlet(masterActor, masterConf),
s"$baseContext/status/*" -> new StatusRequestServlet(masterActor, masterConf),
"/*" -> new ErrorServlet // default handler
)
/** Start the server and return the bound port. */
......@@ -99,7 +98,7 @@ private[spark] class StandaloneRestServer(
server.setThreadPool(threadPool)
val mainHandler = new ServletContextHandler
mainHandler.setContextPath("/")
servletToContext.foreach { case (servlet, prefix) =>
contextToServlet.foreach { case (prefix, servlet) =>
mainHandler.addServlet(new ServletHolder(servlet), prefix)
}
server.setHandler(mainHandler)
......@@ -113,7 +112,7 @@ private[spark] class StandaloneRestServer(
}
}
private object StandaloneRestServer {
private[rest] object StandaloneRestServer {
val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
val SC_UNKNOWN_PROTOCOL_VERSION = 468
}
......@@ -121,20 +120,7 @@ private object StandaloneRestServer {
/**
* An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
*/
private abstract class StandaloneRestServlet extends HttpServlet with Logging {
/** Service a request. If an exception is thrown in the process, indicate server error. */
protected override def service(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
try {
super.service(request, response)
} catch {
case e: Exception =>
logError("Exception while handling request", e)
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
}
}
private[rest] abstract class StandaloneRestServlet extends HttpServlet with Logging {
/**
* Serialize the given response message to JSON and send it through the response servlet.
......@@ -146,11 +132,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
val message = validateResponse(responseMessage, responseServlet)
responseServlet.setContentType("application/json")
responseServlet.setCharacterEncoding("utf-8")
responseServlet.setStatus(HttpServletResponse.SC_OK)
val content = message.toJson.getBytes(Charsets.UTF_8)
val out = new DataOutputStream(responseServlet.getOutputStream)
out.write(content)
out.close()
responseServlet.getWriter.write(message.toJson)
}
/**
......@@ -186,6 +168,19 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
e
}
/**
* Parse a submission ID from the relative path, assuming it is the first part of the path.
* For instance, we expect the path to take the form /[submission ID]/maybe/something/else.
* The returned submission ID cannot be empty. If the path is unexpected, return None.
*/
protected def parseSubmissionId(path: String): Option[String] = {
if (path == null || path.isEmpty) {
None
} else {
path.stripPrefix("/").split("/").headOption.filter(_.nonEmpty)
}
}
/**
* Validate the response to ensure that it is correctly constructed.
*
......@@ -209,7 +204,7 @@ private abstract class StandaloneRestServlet extends HttpServlet with Logging {
/**
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
*/
private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
private[rest] class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
extends StandaloneRestServlet {
/**
......@@ -219,18 +214,15 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
protected override def doPost(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleKill(submissionId)
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in kill request.")
}
val submissionId = parseSubmissionId(request.getPathInfo)
val responseMessage = submissionId.map(handleKill).getOrElse {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in kill request.")
}
sendResponse(responseMessage, response)
}
private def handleKill(submissionId: String): KillSubmissionResponse = {
protected def handleKill(submissionId: String): KillSubmissionResponse = {
val askTimeout = AkkaUtils.askTimeout(conf)
val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
......@@ -246,7 +238,7 @@ private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
/**
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
*/
private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
private[rest] class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
extends StandaloneRestServlet {
/**
......@@ -256,18 +248,15 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
protected override def doGet(
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val submissionId = request.getPathInfo.stripPrefix("/")
val responseMessage =
if (submissionId.nonEmpty) {
handleStatus(submissionId)
} else {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in status request.")
}
val submissionId = parseSubmissionId(request.getPathInfo)
val responseMessage = submissionId.map(handleStatus).getOrElse {
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Submission ID is missing in status request.")
}
sendResponse(responseMessage, response)
}
private def handleStatus(submissionId: String): SubmissionStatusResponse = {
protected def handleStatus(submissionId: String): SubmissionStatusResponse = {
val askTimeout = AkkaUtils.askTimeout(conf)
val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
......@@ -287,7 +276,7 @@ private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
/**
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
*/
private class SubmitRequestServlet(
private[rest] class SubmitRequestServlet(
masterActor: ActorRef,
masterUrl: String,
conf: SparkConf)
......@@ -313,7 +302,7 @@ private class SubmitRequestServlet(
handleSubmit(requestMessageJson, requestMessage, responseServlet)
} catch {
// The client failed to provide a valid JSON, so this is not our fault
case e @ (_: JsonMappingException | _: SubmitRestProtocolException) =>
case e @ (_: JsonProcessingException | _: SubmitRestProtocolException) =>
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
handleError("Malformed request: " + formatException(e))
}
......@@ -413,7 +402,7 @@ private class ErrorServlet extends StandaloneRestServlet {
request: HttpServletRequest,
response: HttpServletResponse): Unit = {
val path = request.getPathInfo
val parts = path.stripPrefix("/").split("/").toSeq
val parts = path.stripPrefix("/").split("/").filter(_.nonEmpty).toList
var versionMismatch = false
var msg =
parts match {
......@@ -423,10 +412,10 @@ private class ErrorServlet extends StandaloneRestServlet {
case `serverVersion` :: Nil =>
// http://host:port/correct-version
"Missing the /submissions prefix."
case `serverVersion` :: "submissions" :: Nil =>
// http://host:port/correct-version/submissions
case `serverVersion` :: "submissions" :: tail =>
// http://host:port/correct-version/submissions/*
"Missing an action: please specify one of /create, /kill, or /status."
case unknownVersion :: _ =>
case unknownVersion :: tail =>
// http://host:port/unknown-version/*
versionMismatch = true
s"Unknown protocol version '$unknownVersion'."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment