Commit f816e739 authored by Cheng Lian

[SPARK-5751] [SQL] [WIP] Revamped HiveThriftServer2Suite for robustness

**NOTICE** Do NOT merge this, as we're waiting for #3881 to be merged.

`HiveThriftServer2Suite` has been notorious for its flakiness for a while. This was mostly due to spawning and communicating with external server processes. This PR revamps this test suite for better robustness:

1. Fixes a race condition that occurred while using `tail -f` to check the log file

   It's possible that the line we are looking for has already been printed into the log file before we start the `tail -f` process. This PR uses `tail -n +0 -f` to ensure all lines are checked (see the first sketch below).

2. Retries up to 3 times if the server fails to start

   In most cases the server fails to start because of a port conflict. Instead of asking the system to choose an available TCP port (which another process can claim before the server binds it), this PR picks a random port and retries up to 3 times if the server fails to start (see the second sketch below).

3. A server instance is reused among all test cases within a single suite

   The original `HiveThriftServer2Suite` is split into two test suites, `HiveThriftBinaryServerSuite` and `HiveThriftHttpServerSuite`. Each suite starts a `HiveThriftServer2` instance and reuses it for all of its test cases (see the third sketch below).
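
To make item 1 concrete, here is a minimal sketch of the difference, with a hypothetical `onLine` callback standing in for the suite's actual `ProcessLogger` wiring:

```scala
import scala.sys.process.{Process, ProcessLogger}

// Racy: `tail -f` only follows output appended *after* it attaches, so a
// "server started" line that was logged earlier is never seen.
def tailRacy(logPath: String, onLine: String => Unit) =
  Process(s"/usr/bin/env tail -f $logPath").run(ProcessLogger(onLine, _ => ()))

// Robust: `-n +0` replays the file from its first line before following,
// so lines printed before the tail process attached are also checked.
def tailAll(logPath: String, onLine: String => Unit) =
  Process(s"/usr/bin/env tail -n +0 -f $logPath").run(ProcessLogger(onLine, _ => ()))
```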
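
Item 2 boils down to the retry loop below, a condensed sketch in which `startServer` and `stopServer` are hypothetical stand-ins for the suite's `startThriftServer`/`stopThriftServer` helpers:

```scala
import scala.util.{Random, Try}

// Hypothetical stand-ins for the suite's startThriftServer/stopThriftServer.
def startServer(port: Int, attempt: Int): Unit = ???
def stopServer(): Unit = ()

var port = 10000 + Random.nextInt(10000)  // random port in [10000, 19999]

// Keep the first successful start; otherwise bump the port, clean up, and retry.
(1 to 3).foldLeft(Try(startServer(port, attempt = 0))) { (result, attempt) =>
  result.orElse {
    port += 1
    stopServer()  // tear down the failed attempt before retrying
    Try(startServer(port, attempt))
  }
}.get  // rethrows the last failure if every attempt failed
```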
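
Item 3 relies on ScalaTest's `BeforeAndAfterAll` hooks. Schematically (the names below are illustrative; the real base class in the diff is `HiveThriftServer2Test`):

```scala
import org.scalatest.{BeforeAndAfterAll, FunSuite}

abstract class SharedServerSuite extends FunSuite with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = {
    // Start one HiveThriftServer2 instance for the whole suite.
  }

  override protected def afterAll(): Unit = {
    // Stop it only after every test in the suite has run.
  }
}
```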

**TODO**

- [ ] Start the Thrift server in the foreground once #3881 is merged (adding a `--foreground` flag to `spark-daemon.sh`)

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/4720)
<!-- Reviewable:end -->

Author: Cheng Lian <lian@databricks.com>

Closes #4720 from liancheng/revamp-thrift-server-tests and squashes the following commits:

d6c80eb [Cheng Lian] Relaxes server startup timeout
6f14eb1 [Cheng Lian] Revamped HiveThriftServer2Suite for robustness
parent 2a0fe348
@@ -18,217 +18,75 @@
 package org.apache.spark.sql.hive.thriftserver

 import java.io.File
-import java.net.ServerSocket
 import java.sql.{Date, DriverManager, Statement}
-import java.util.concurrent.TimeoutException

-import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.concurrent.{Await, Promise}
 import scala.sys.process.{Process, ProcessLogger}
-import scala.util.Try
+import scala.util.{Random, Try}

 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hive.jdbc.HiveDriver
 import org.apache.hive.service.auth.PlainSaslHelper
 import org.apache.hive.service.cli.GetInfoType
 import org.apache.hive.service.cli.thrift.TCLIService.Client
-import org.apache.hive.service.cli.thrift._
+import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
 import org.apache.thrift.protocol.TBinaryProtocol
 import org.apache.thrift.transport.TSocket
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.catalyst.util
 import org.apache.spark.sql.hive.HiveShim

-/**
- * Tests for the HiveThriftServer2 using JDBC.
- *
- * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
- * rebuilt after changing HiveThriftServer2 related code.
- */
-class HiveThriftServer2Suite extends FunSuite with Logging {
-  Class.forName(classOf[HiveDriver].getCanonicalName)
-
-  object TestData {
-    def getTestDataFilePath(name: String) = {
-      Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
-    }
-
-    val smallKv = getTestDataFilePath("small_kv.txt")
-    val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
-  }
-
-  def randomListeningPort = {
-    // Let the system to choose a random available port to avoid collision with other parallel
-    // builds.
-    val socket = new ServerSocket(0)
-    val port = socket.getLocalPort
-    socket.close()
-    port
-  }
-
-  def withJdbcStatement(
-      serverStartTimeout: FiniteDuration = 1.minute,
-      httpMode: Boolean = false)(
-      f: Statement => Unit) {
-    val port = randomListeningPort
-
-    startThriftServer(port, serverStartTimeout, httpMode) {
-      val jdbcUri = if (httpMode) {
-        s"jdbc:hive2://${"localhost"}:$port/" +
-          "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice"
-      } else {
-        s"jdbc:hive2://${"localhost"}:$port/"
-      }
-
-      val user = System.getProperty("user.name")
-      val connection = DriverManager.getConnection(jdbcUri, user, "")
-      val statement = connection.createStatement()
-
-      try {
-        f(statement)
-      } finally {
-        statement.close()
-        connection.close()
-      }
-    }
-  }
-
-  def withCLIServiceClient(
-      serverStartTimeout: FiniteDuration = 1.minute)(
-      f: ThriftCLIServiceClient => Unit) {
-    val port = randomListeningPort
-
-    startThriftServer(port) {
-      // Transport creation logics below mimics HiveConnection.createBinaryTransport
-      val rawTransport = new TSocket("localhost", port)
-      val user = System.getProperty("user.name")
-      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
-      val protocol = new TBinaryProtocol(transport)
-      val client = new ThriftCLIServiceClient(new Client(protocol))
-
-      transport.open()
-
-      try {
-        f(client)
-      } finally {
-        transport.close()
-      }
-    }
-  }
-
-  def startThriftServer(
-      port: Int,
-      serverStartTimeout: FiniteDuration = 1.minute,
-      httpMode: Boolean = false)(
-      f: => Unit) {
-    val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
-    val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
-
-    val warehousePath = getTempFilePath("warehouse")
-    val metastorePath = getTempFilePath("metastore")
-    val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
-
-    val command =
-      if (httpMode) {
-        s"""$startScript
-           | --master local
-           | --hiveconf hive.root.logger=INFO,console
-           | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
-           | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-           | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
-           | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http
-           | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port
-           | --driver-class-path ${sys.props("java.class.path")}
-           | --conf spark.ui.enabled=false
-         """.stripMargin.split("\\s+").toSeq
-      } else {
-        s"""$startScript
-           | --master local
-           | --hiveconf hive.root.logger=INFO,console
-           | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
-           | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-           | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
-           | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
-           | --driver-class-path ${sys.props("java.class.path")}
-           | --conf spark.ui.enabled=false
-         """.stripMargin.split("\\s+").toSeq
-      }
-
-    val serverRunning = Promise[Unit]()
-    val buffer = new ArrayBuffer[String]()
-    val LOGGING_MARK =
-      s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to "
-    var logTailingProcess: Process = null
-    var logFilePath: String = null
-
-    def captureLogOutput(line: String): Unit = {
-      buffer += line
-      if (line.contains("ThriftBinaryCLIService listening on") ||
-          line.contains("Started ThriftHttpCLIService in http")) {
-        serverRunning.success(())
-      }
-    }
-
-    def captureThriftServerOutput(source: String)(line: String): Unit = {
-      if (line.startsWith(LOGGING_MARK)) {
-        logFilePath = line.drop(LOGGING_MARK.length).trim
-        // Ensure that the log file is created so that the `tail' command won't fail
-        Try(new File(logFilePath).createNewFile())
-        logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath")
-          .run(ProcessLogger(captureLogOutput, _ => ()))
-      }
-    }
-
-    val env = Seq(
-      // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
-      "SPARK_TESTING" -> "0")
-
-    Process(command, None, env: _*).run(ProcessLogger(
-      captureThriftServerOutput("stdout"),
-      captureThriftServerOutput("stderr")))
-
-    try {
-      Await.result(serverRunning.future, serverStartTimeout)
-      f
-    } catch {
-      case cause: Exception =>
-        cause match {
-          case _: TimeoutException =>
-            logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
-          case _ =>
-        }
-        logError(
-          s"""
-             |=====================================
-             |HiveThriftServer2Suite failure output
-             |=====================================
-             |HiveThriftServer2 command line: ${command.mkString(" ")}
-             |Binding port: $port
-             |System user: ${System.getProperty("user.name")}
-             |
-             |${buffer.mkString("\n")}
-             |=========================================
-             |End HiveThriftServer2Suite failure output
-             |=========================================
-           """.stripMargin, cause)
-        throw cause
-    } finally {
-      warehousePath.delete()
-      metastorePath.delete()
-      Process(stopScript, None, env: _*).run().exitValue()
-      // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
-      Thread.sleep(3.seconds.toMillis)
-      Option(logTailingProcess).map(_.destroy())
-      Option(logFilePath).map(new File(_).delete())
-    }
-  }
-
+object TestData {
+  def getTestDataFilePath(name: String) = {
+    Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
+  }
+
+  val smallKv = getTestDataFilePath("small_kv.txt")
+  val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
+}
+
+class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
+  override def mode = ServerMode.binary
+
+  private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
+    // Transport creation logics below mimics HiveConnection.createBinaryTransport
+    val rawTransport = new TSocket("localhost", serverPort)
+    val user = System.getProperty("user.name")
+    val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+    val protocol = new TBinaryProtocol(transport)
+    val client = new ThriftCLIServiceClient(new Client(protocol))
+
+    transport.open()
+    try f(client) finally transport.close()
+  }
+
+  test("GetInfo Thrift API") {
+    withCLIServiceClient { client =>
+      val user = System.getProperty("user.name")
+      val sessionHandle = client.openSession(user, "")
+
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+      }
+
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+      }
+
+      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+        logInfo(s"Spark version: $version")
+        version != "Unknown"
+      }
+    }
+  }
+
-  test("Test JDBC query execution") {
-    withJdbcStatement() { statement =>
+  test("JDBC query execution") {
+    withJdbcStatement { statement =>
       val queries = Seq(
         "SET spark.sql.shuffle.partitions=3",
         "DROP TABLE IF EXISTS test",
@@ -246,27 +104,16 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
     }
   }

-  test("Test JDBC query execution in Http Mode") {
-    withJdbcStatement(httpMode = true) { statement =>
-      val queries = Seq(
-        "SET spark.sql.shuffle.partitions=3",
-        "DROP TABLE IF EXISTS test",
-        "CREATE TABLE test(key INT, val STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
-        "CACHE TABLE test")
-
-      queries.foreach(statement.execute)
-
-      assertResult(5, "Row count mismatch") {
-        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
-        resultSet.next()
-        resultSet.getInt(1)
-      }
+  test("Checks Hive version") {
+    withJdbcStatement { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
     }
   }

   test("SPARK-3004 regression: result set containing NULL") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_null",
         "CREATE TABLE test_null(key INT, val STRING)",
@@ -286,45 +133,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
     }
   }

-  test("GetInfo Thrift API") {
-    withCLIServiceClient() { client =>
-      val user = System.getProperty("user.name")
-      val sessionHandle = client.openSession(user, "")
-
-      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
-        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
-      }
-
-      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
-        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
-      }
-
-      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
-        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
-        logInfo(s"Spark version: $version")
-        version != "Unknown"
-      }
-    }
-  }
-
-  test("Checks Hive version") {
-    withJdbcStatement() { statement =>
-      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
-      resultSet.next()
-      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
-    }
-  }
-
-  test("Checks Hive version in Http Mode") {
-    withJdbcStatement(httpMode = true) { statement =>
-      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
-      resultSet.next()
-      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
-    }
-  }
-
   test("SPARK-4292 regression: result set iterator issue") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_4292",
         "CREATE TABLE test_4292(key INT, val STRING)",
@@ -344,7 +154,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
   }

   test("SPARK-4309 regression: Date type support") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_date",
         "CREATE TABLE test_date(key INT, value STRING)",
@@ -362,7 +172,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
   }

   test("SPARK-4407 regression: Complex type support") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_map",
         "CREATE TABLE test_map(key INT, value STRING)",
@@ -385,3 +195,209 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
     }
   }
 }
+
+class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
+  override def mode = ServerMode.http
+
+  test("JDBC query execution") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "SET spark.sql.shuffle.partitions=3",
+        "DROP TABLE IF EXISTS test",
+        "CREATE TABLE test(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
+        "CACHE TABLE test")
+
+      queries.foreach(statement.execute)
+
+      assertResult(5, "Row count mismatch") {
+        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+        resultSet.next()
+        resultSet.getInt(1)
+      }
+    }
+  }
+
+  test("Checks Hive version") {
+    withJdbcStatement { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+    }
+  }
+}
+
+object ServerMode extends Enumeration {
+  val binary, http = Value
+}
+
+abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
+  Class.forName(classOf[HiveDriver].getCanonicalName)
+
+  private def jdbcUri = if (mode == ServerMode.http) {
+    s"""jdbc:hive2://localhost:$serverPort/
+       |default?
+       |hive.server2.transport.mode=http;
+       |hive.server2.thrift.http.path=cliservice
+     """.stripMargin.split("\n").mkString.trim
+  } else {
+    s"jdbc:hive2://localhost:$serverPort/"
+  }
+
+  protected def withJdbcStatement(f: Statement => Unit): Unit = {
+    val connection = DriverManager.getConnection(jdbcUri, user, "")
+    val statement = connection.createStatement()
+
+    try f(statement) finally {
+      statement.close()
+      connection.close()
+    }
+  }
+}
+
+abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {
+  def mode: ServerMode.Value
+
+  private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
+  private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to "
+
+  private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+  private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
+
+  private var listeningPort: Int = _
+  protected def serverPort: Int = listeningPort
+
+  protected def user = System.getProperty("user.name")
+
+  private var warehousePath: File = _
+  private var metastorePath: File = _
+  private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+
+  private var logPath: File = _
+  private var logTailingProcess: Process = _
+  private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String]
+
+  private def serverStartCommand(port: Int) = {
+    val portConf = if (mode == ServerMode.binary) {
+      ConfVars.HIVE_SERVER2_THRIFT_PORT
+    } else {
+      ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
+    }
+
+    s"""$startScript
+       | --master local
+       | --hiveconf hive.root.logger=INFO,console
+       | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
+       | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+       | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
+       | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
+       | --hiveconf $portConf=$port
+       | --driver-class-path ${sys.props("java.class.path")}
+       | --conf spark.ui.enabled=false
+     """.stripMargin.split("\\s+").toSeq
+  }
+
+  private def startThriftServer(port: Int, attempt: Int) = {
+    warehousePath = util.getTempFilePath("warehouse")
+    metastorePath = util.getTempFilePath("metastore")
+    logPath = null
+    logTailingProcess = null
+
+    val command = serverStartCommand(port)
+
+    diagnosisBuffer ++=
+      s"""
+         |### Attempt $attempt ###
+         |HiveThriftServer2 command line: $command
+         |Listening port: $port
+         |System user: $user
+       """.stripMargin.split("\n")
+
+    logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt")
+
+    logPath = Process(command, None, "SPARK_TESTING" -> "0").lines.collectFirst {
+      case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length))
+    }.getOrElse {
+      throw new RuntimeException("Failed to find HiveThriftServer2 log file.")
+    }
+
+    val serverStarted = Promise[Unit]()
+
+    // Ensures that the following "tail" command won't fail.
+    logPath.createNewFile()
+    logTailingProcess =
+      // Using "-n +0" to make sure all lines in the log file are checked.
+      Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger(
+        (line: String) => {
+          diagnosisBuffer += line
+
+          if (line.contains("ThriftBinaryCLIService listening on") ||
+              line.contains("Started ThriftHttpCLIService in http")) {
+            serverStarted.trySuccess(())
+          } else if (line.contains("HiveServer2 is stopped")) {
+            // This log line appears when the server fails to start and terminates gracefully (e.g.
+            // because of port contention).
+            serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2"))
+          }
+        }))
+
+    Await.result(serverStarted.future, 2.minute)
+  }
+
+  private def stopThriftServer(): Unit = {
+    // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
+    Process(stopScript, None).run().exitValue()
+    Thread.sleep(3.seconds.toMillis)
+
+    warehousePath.delete()
+    warehousePath = null
+
+    metastorePath.delete()
+    metastorePath = null
+
+    Option(logPath).foreach(_.delete())
+    logPath = null
+
+    Option(logTailingProcess).foreach(_.destroy())
+    logTailingProcess = null
+  }
+
+  private def dumpLogs(): Unit = {
+    logError(
+      s"""
+         |=====================================
+         |HiveThriftServer2Suite failure output
+         |=====================================
+         |${diagnosisBuffer.mkString("\n")}
+         |=========================================
+         |End HiveThriftServer2Suite failure output
+         |=========================================
+       """.stripMargin)
+  }
+
+  override protected def beforeAll(): Unit = {
+    // Chooses a random port between 10000 and 19999
+    listeningPort = 10000 + Random.nextInt(10000)
+    diagnosisBuffer.clear()
+
+    // Retries up to 3 times with different port numbers if the server fails to start
+    (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) =>
+      started.orElse {
+        listeningPort += 1
+        stopThriftServer()
+        Try(startThriftServer(listeningPort, attempt))
+      }
+    }.recover {
+      case cause: Throwable =>
+        dumpLogs()
+        throw cause
+    }.get
+
+    logInfo(s"HiveThriftServer2 started successfully")
+  }
+
+  override protected def afterAll(): Unit = {
+    stopThriftServer()
+    logInfo("HiveThriftServer2 stopped")
+  }
+}