Commit f816e739 authored by Cheng Lian

[SPARK-5751] [SQL] [WIP] Revamped HiveThriftServer2Suite for robustness

**NOTICE** Do NOT merge this, as we're waiting for #3881 to be merged.

`HiveThriftServer2Suite` has been notorious for its flakiness for a while. This was mostly due to spawning and communicating with external server processes. This PR revamps the test suite for better robustness:

1. Fixes a race condition that occurred while using `tail -f` to check the log file

   It's possible that the line we are looking for has already been printed into the log file before we start the `tail -f` process. This PR uses `tail -n +0 -f` instead, which replays the file from its first line, so every line is checked (see the first sketch after this list).

2. Retries up to 3 times if the server fails to start

   In most cases, the server fails to start because of a port conflict. This PR no longer asks the system to choose an available TCP port; instead, it picks a random port first and retries with a different port up to 3 times if the server fails to start (see the second sketch after this list).

3. A server instance is reused among all test cases within a single suite

   The original `HiveThriftServer2Suite` is split into two test suites, `HiveThriftBinaryServerSuite` and `HiveThriftHttpServerSuite`. Each suite starts a single `HiveThriftServer2` instance and reuses it for all of its test cases (see the third sketch after this list).
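
A minimal sketch of the log-watching fix from item 1 (not the suite's exact API; `awaitServerStarted` and `logPath` are illustrative names):

```scala
import java.io.File
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._
import scala.sys.process.{Process, ProcessLogger}

// Waits until the server's log file contains the "listening on" marker.
def awaitServerStarted(logPath: File): Unit = {
  val serverStarted = Promise[Unit]()

  // "tail -n +0 -f" replays the file from its first line before following it,
  // so a marker printed before the tail process spawned is still observed
  // (a plain "tail -f" would silently miss it).
  val tail = Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}")
    .run(ProcessLogger { line =>
      if (line.contains("ThriftBinaryCLIService listening on")) {
        serverStarted.trySuccess(())
      }
    })

  try Await.result(serverStarted.future, 2.minutes) finally tail.destroy()
}
```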
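
A hedged sketch of the retry strategy from item 2, mirroring the `beforeAll` logic in the new `HiveThriftServer2Test` base class below; `startThriftServer` here is a placeholder for the real start logic:

```scala
import scala.util.{Random, Try}

var listeningPort = 10000 + Random.nextInt(10000)  // random port in [10000, 19999]

def startThriftServer(port: Int, attempt: Int): Unit = {
  // Placeholder: spawn the server and throw if it does not come up.
  ???
}

// Try the random port first, then retry up to 3 times, bumping the port each
// time so a conflicting neighbor is not hit again; `.get` rethrows the last
// failure if every attempt failed.
(1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) =>
  started.orElse {
    listeningPort += 1
    Try(startThriftServer(listeningPort, attempt))
  }
}.get
```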
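
Item 3 boils down to the standard ScalaTest `BeforeAndAfterAll` pattern, sketched here under the assumption that the start/stop logic lives in a shared base class:

```scala
import org.scalatest.{BeforeAndAfterAll, FunSuite}

abstract class SharedServerSuite extends FunSuite with BeforeAndAfterAll {
  // Started once per suite; every test case in the suite reuses it.
  override protected def beforeAll(): Unit = { /* start HiveThriftServer2 */ }

  // Torn down once after the last test case has run.
  override protected def afterAll(): Unit = { /* stop HiveThriftServer2 */ }
}
```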

**TODO**

- [ ] Start the Thrift server in the foreground once #3881 (which adds a `--foreground` flag to `spark-daemon.sh`) is merged


Author: Cheng Lian <lian@databricks.com>

Closes #4720 from liancheng/revamp-thrift-server-tests and squashes the following commits:

d6c80eb [Cheng Lian] Relaxes server startup timeout
6f14eb1 [Cheng Lian] Revamped HiveThriftServer2Suite for robustness
parent 2a0fe348
@@ -18,217 +18,75 @@
 package org.apache.spark.sql.hive.thriftserver
 
 import java.io.File
-import java.net.ServerSocket
 import java.sql.{Date, DriverManager, Statement}
-import java.util.concurrent.TimeoutException
 
-import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration._
 import scala.concurrent.{Await, Promise}
 import scala.sys.process.{Process, ProcessLogger}
-import scala.util.Try
+import scala.util.{Random, Try}
 
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hive.jdbc.HiveDriver
 import org.apache.hive.service.auth.PlainSaslHelper
 import org.apache.hive.service.cli.GetInfoType
 import org.apache.hive.service.cli.thrift.TCLIService.Client
-import org.apache.hive.service.cli.thrift._
+import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
 import org.apache.thrift.protocol.TBinaryProtocol
 import org.apache.thrift.transport.TSocket
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.util.getTempFilePath
+import org.apache.spark.sql.catalyst.util
 import org.apache.spark.sql.hive.HiveShim
 
-/**
- * Tests for the HiveThriftServer2 using JDBC.
- *
- * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
- * rebuilt after changing HiveThriftServer2 related code.
- */
-class HiveThriftServer2Suite extends FunSuite with Logging {
-  Class.forName(classOf[HiveDriver].getCanonicalName)
-
-  object TestData {
-    def getTestDataFilePath(name: String) = {
-      Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
-    }
-
-    val smallKv = getTestDataFilePath("small_kv.txt")
-    val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
+object TestData {
+  def getTestDataFilePath(name: String) = {
+    Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
   }
 
-  def randomListeningPort = {
-    // Let the system to choose a random available port to avoid collision with other parallel
-    // builds.
-    val socket = new ServerSocket(0)
-    val port = socket.getLocalPort
-    socket.close()
-    port
-  }
+  val smallKv = getTestDataFilePath("small_kv.txt")
+  val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
+}
 
-  def withJdbcStatement(
-      serverStartTimeout: FiniteDuration = 1.minute,
-      httpMode: Boolean = false)(
-      f: Statement => Unit) {
-    val port = randomListeningPort
+class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
+  override def mode = ServerMode.binary
 
-    startThriftServer(port, serverStartTimeout, httpMode) {
-      val jdbcUri = if (httpMode) {
-        s"jdbc:hive2://${"localhost"}:$port/" +
-          "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice"
-      } else {
-        s"jdbc:hive2://${"localhost"}:$port/"
-      }
-
-      val user = System.getProperty("user.name")
-      val connection = DriverManager.getConnection(jdbcUri, user, "")
-      val statement = connection.createStatement()
-
-      try {
-        f(statement)
-      } finally {
-        statement.close()
-        connection.close()
-      }
-    }
-  }
-
-  def withCLIServiceClient(
-      serverStartTimeout: FiniteDuration = 1.minute)(
-      f: ThriftCLIServiceClient => Unit) {
-    val port = randomListeningPort
-
-    startThriftServer(port) {
-      // Transport creation logics below mimics HiveConnection.createBinaryTransport
-      val rawTransport = new TSocket("localhost", port)
-      val user = System.getProperty("user.name")
-      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
-      val protocol = new TBinaryProtocol(transport)
-      val client = new ThriftCLIServiceClient(new Client(protocol))
-
-      transport.open()
-
-      try {
-        f(client)
-      } finally {
-        transport.close()
-      }
-    }
-  }
+  private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
+    // Transport creation logics below mimics HiveConnection.createBinaryTransport
+    val rawTransport = new TSocket("localhost", serverPort)
+    val user = System.getProperty("user.name")
+    val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+    val protocol = new TBinaryProtocol(transport)
+    val client = new ThriftCLIServiceClient(new Client(protocol))
 
-  def startThriftServer(
-      port: Int,
-      serverStartTimeout: FiniteDuration = 1.minute,
-      httpMode: Boolean = false)(
-      f: => Unit) {
-    val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
-    val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
+    transport.open()
+    try f(client) finally transport.close()
+  }
 
-    val warehousePath = getTempFilePath("warehouse")
-    val metastorePath = getTempFilePath("metastore")
-    val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+  test("GetInfo Thrift API") {
+    withCLIServiceClient { client =>
+      val user = System.getProperty("user.name")
+      val sessionHandle = client.openSession(user, "")
 
-    val command =
-      if (httpMode) {
-        s"""$startScript
-          | --master local
-          | --hiveconf hive.root.logger=INFO,console
-          | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
-          | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-          | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
-          | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http
-          | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port
-          | --driver-class-path ${sys.props("java.class.path")}
-          | --conf spark.ui.enabled=false
-        """.stripMargin.split("\\s+").toSeq
-      } else {
-        s"""$startScript
-          | --master local
-          | --hiveconf hive.root.logger=INFO,console
-          | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
-          | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-          | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
-          | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
-          | --driver-class-path ${sys.props("java.class.path")}
-          | --conf spark.ui.enabled=false
-        """.stripMargin.split("\\s+").toSeq
-      }
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+      }
 
-    val serverRunning = Promise[Unit]()
-    val buffer = new ArrayBuffer[String]()
-    val LOGGING_MARK =
-      s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to "
-    var logTailingProcess: Process = null
-    var logFilePath: String = null
-
-    def captureLogOutput(line: String): Unit = {
-      buffer += line
-      if (line.contains("ThriftBinaryCLIService listening on") ||
-          line.contains("Started ThriftHttpCLIService in http")) {
-        serverRunning.success(())
-      }
-    }
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+      }
 
-    def captureThriftServerOutput(source: String)(line: String): Unit = {
-      if (line.startsWith(LOGGING_MARK)) {
-        logFilePath = line.drop(LOGGING_MARK.length).trim
-        // Ensure that the log file is created so that the `tail' command won't fail
-        Try(new File(logFilePath).createNewFile())
-        logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath")
-          .run(ProcessLogger(captureLogOutput, _ => ()))
-      }
-    }
-
-    val env = Seq(
-      // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
-      "SPARK_TESTING" -> "0")
-
-    Process(command, None, env: _*).run(ProcessLogger(
-      captureThriftServerOutput("stdout"),
-      captureThriftServerOutput("stderr")))
-
-    try {
-      Await.result(serverRunning.future, serverStartTimeout)
-      f
-    } catch {
-      case cause: Exception =>
-        cause match {
-          case _: TimeoutException =>
-            logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
-          case _ =>
-        }
-        logError(
-          s"""
-             |=====================================
-             |HiveThriftServer2Suite failure output
-             |=====================================
-             |HiveThriftServer2 command line: ${command.mkString(" ")}
-             |Binding port: $port
-             |System user: ${System.getProperty("user.name")}
-             |
-             |${buffer.mkString("\n")}
-             |=========================================
-             |End HiveThriftServer2Suite failure output
-             |=========================================
-           """.stripMargin, cause)
-        throw cause
-    } finally {
-      warehousePath.delete()
-      metastorePath.delete()
-      Process(stopScript, None, env: _*).run().exitValue()
-      // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
-      Thread.sleep(3.seconds.toMillis)
-      Option(logTailingProcess).map(_.destroy())
-      Option(logFilePath).map(new File(_).delete())
-    }
-  }
+      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+        logInfo(s"Spark version: $version")
+        version != "Unknown"
+      }
+    }
+  }
 
-  test("Test JDBC query execution") {
-    withJdbcStatement() { statement =>
+  test("JDBC query execution") {
+    withJdbcStatement { statement =>
       val queries = Seq(
         "SET spark.sql.shuffle.partitions=3",
         "DROP TABLE IF EXISTS test",
@@ -246,27 +104,16 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
     }
   }
 
-  test("Test JDBC query execution in Http Mode") {
-    withJdbcStatement(httpMode = true) { statement =>
-      val queries = Seq(
-        "SET spark.sql.shuffle.partitions=3",
-        "DROP TABLE IF EXISTS test",
-        "CREATE TABLE test(key INT, val STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
-        "CACHE TABLE test")
-
-      queries.foreach(statement.execute)
-
-      assertResult(5, "Row count mismatch") {
-        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
-        resultSet.next()
-        resultSet.getInt(1)
-      }
+  test("Checks Hive version") {
+    withJdbcStatement { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
     }
   }
 
   test("SPARK-3004 regression: result set containing NULL") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_null",
         "CREATE TABLE test_null(key INT, val STRING)",
@@ -286,45 +133,8 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
     }
   }
 
-  test("GetInfo Thrift API") {
-    withCLIServiceClient() { client =>
-      val user = System.getProperty("user.name")
-      val sessionHandle = client.openSession(user, "")
-
-      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
-        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
-      }
-
-      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
-        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
-      }
-
-      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
-        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
-        logInfo(s"Spark version: $version")
-        version != "Unknown"
-      }
-    }
-  }
-
-  test("Checks Hive version") {
-    withJdbcStatement() { statement =>
-      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
-      resultSet.next()
-      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
-    }
-  }
-
-  test("Checks Hive version in Http Mode") {
-    withJdbcStatement(httpMode = true) { statement =>
-      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
-      resultSet.next()
-      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
-    }
-  }
-
   test("SPARK-4292 regression: result set iterator issue") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_4292",
         "CREATE TABLE test_4292(key INT, val STRING)",
@@ -344,7 +154,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
   }
 
   test("SPARK-4309 regression: Date type support") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_date",
         "CREATE TABLE test_date(key INT, value STRING)",
@@ -362,7 +172,7 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
   }
 
   test("SPARK-4407 regression: Complex type support") {
-    withJdbcStatement() { statement =>
+    withJdbcStatement { statement =>
       val queries = Seq(
         "DROP TABLE IF EXISTS test_map",
         "CREATE TABLE test_map(key INT, value STRING)",
@@ -385,3 +195,209 @@ class HiveThriftServer2Suite extends FunSuite with Logging {
     }
   }
 }
+
+class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
+  override def mode = ServerMode.http
+
+  test("JDBC query execution") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "SET spark.sql.shuffle.partitions=3",
+        "DROP TABLE IF EXISTS test",
+        "CREATE TABLE test(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
+        "CACHE TABLE test")
+
+      queries.foreach(statement.execute)
+
+      assertResult(5, "Row count mismatch") {
+        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+        resultSet.next()
+        resultSet.getInt(1)
+      }
+    }
+  }
+
+  test("Checks Hive version") {
+    withJdbcStatement { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+    }
+  }
+}
+
+object ServerMode extends Enumeration {
+  val binary, http = Value
+}
+
+abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
+  Class.forName(classOf[HiveDriver].getCanonicalName)
+
+  private def jdbcUri = if (mode == ServerMode.http) {
+    s"""jdbc:hive2://localhost:$serverPort/
+       |default?
+       |hive.server2.transport.mode=http;
+       |hive.server2.thrift.http.path=cliservice
+     """.stripMargin.split("\n").mkString.trim
+  } else {
+    s"jdbc:hive2://localhost:$serverPort/"
+  }
+
+  protected def withJdbcStatement(f: Statement => Unit): Unit = {
+    val connection = DriverManager.getConnection(jdbcUri, user, "")
+    val statement = connection.createStatement()
+
+    try f(statement) finally {
+      statement.close()
+      connection.close()
+    }
+  }
+}
+
+abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {
+  def mode: ServerMode.Value
+
+  private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
+  private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to "
+
+  private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+  private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
+
+  private var listeningPort: Int = _
+  protected def serverPort: Int = listeningPort
+
+  protected def user = System.getProperty("user.name")
+
+  private var warehousePath: File = _
+  private var metastorePath: File = _
+  private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+
+  private var logPath: File = _
+  private var logTailingProcess: Process = _
+  private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String]
+
+  private def serverStartCommand(port: Int) = {
+    val portConf = if (mode == ServerMode.binary) {
+      ConfVars.HIVE_SERVER2_THRIFT_PORT
+    } else {
+      ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
+    }
+
+    s"""$startScript
+       | --master local
+       | --hiveconf hive.root.logger=INFO,console
+       | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
+       | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+       | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
+       | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
+       | --hiveconf $portConf=$port
+       | --driver-class-path ${sys.props("java.class.path")}
+       | --conf spark.ui.enabled=false
+     """.stripMargin.split("\\s+").toSeq
+  }
+
+  private def startThriftServer(port: Int, attempt: Int) = {
+    warehousePath = util.getTempFilePath("warehouse")
+    metastorePath = util.getTempFilePath("metastore")
+
+    logPath = null
+    logTailingProcess = null
+
+    val command = serverStartCommand(port)
+
+    diagnosisBuffer ++=
+      s"""
+         |### Attempt $attempt ###
+         |HiveThriftServer2 command line: $command
+         |Listening port: $port
+         |System user: $user
+       """.stripMargin.split("\n")
+
+    logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt")
+
+    logPath = Process(command, None, "SPARK_TESTING" -> "0").lines.collectFirst {
+      case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length))
+    }.getOrElse {
+      throw new RuntimeException("Failed to find HiveThriftServer2 log file.")
+    }
+
+    val serverStarted = Promise[Unit]()
+
+    // Ensures that the following "tail" command won't fail.
+    logPath.createNewFile()
+    logTailingProcess =
+      // Using "-n +0" to make sure all lines in the log file are checked.
+      Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger(
+        (line: String) => {
+          diagnosisBuffer += line
+
+          if (line.contains("ThriftBinaryCLIService listening on") ||
+              line.contains("Started ThriftHttpCLIService in http")) {
+            serverStarted.trySuccess(())
+          } else if (line.contains("HiveServer2 is stopped")) {
+            // This log line appears when the server fails to start and terminates gracefully (e.g.
+            // because of port contention).
+            serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2"))
+          }
+        }))
+
+    Await.result(serverStarted.future, 2.minute)
+  }
+
+  private def stopThriftServer(): Unit = {
+    // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
+    Process(stopScript, None).run().exitValue()
+    Thread.sleep(3.seconds.toMillis)
+
+    warehousePath.delete()
+    warehousePath = null
+
+    metastorePath.delete()
+    metastorePath = null
+
+    Option(logPath).foreach(_.delete())
+    logPath = null
+
+    Option(logTailingProcess).foreach(_.destroy())
+    logTailingProcess = null
+  }
+
+  private def dumpLogs(): Unit = {
+    logError(
+      s"""
+         |=====================================
+         |HiveThriftServer2Suite failure output
+         |=====================================
+         |${diagnosisBuffer.mkString("\n")}
+         |=========================================
+         |End HiveThriftServer2Suite failure output
+         |=========================================
+       """.stripMargin)
+  }
+
+  override protected def beforeAll(): Unit = {
+    // Chooses a random port between 10000 and 19999
+    listeningPort = 10000 + Random.nextInt(10000)
+    diagnosisBuffer.clear()
+
+    // Retries up to 3 times with different port numbers if the server fails to start
+    (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) =>
+      started.orElse {
+        listeningPort += 1
+        stopThriftServer()
+        Try(startThriftServer(listeningPort, attempt))
+      }
+    }.recover {
+      case cause: Throwable =>
+        dumpLogs()
+        throw cause
+    }.get
+
+    logInfo(s"HiveThriftServer2 started successfully")
+  }
+
+  override protected def afterAll(): Unit = {
+    stopThriftServer()
+    logInfo("HiveThriftServer2 stopped")
+  }
+}