diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
deleted file mode 100644
index b52a51d11e4ad3017d0300f631c24400d223c09a..0000000000000000000000000000000000000000
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ /dev/null
@@ -1,387 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import java.io.File
-import java.net.ServerSocket
-import java.sql.{Date, DriverManager, Statement}
-import java.util.concurrent.TimeoutException
-
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.duration._
-import scala.concurrent.{Await, Promise}
-import scala.sys.process.{Process, ProcessLogger}
-import scala.util.Try
-
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hive.jdbc.HiveDriver
-import org.apache.hive.service.auth.PlainSaslHelper
-import org.apache.hive.service.cli.GetInfoType
-import org.apache.hive.service.cli.thrift.TCLIService.Client
-import org.apache.hive.service.cli.thrift._
-import org.apache.thrift.protocol.TBinaryProtocol
-import org.apache.thrift.transport.TSocket
-import org.scalatest.FunSuite
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.hive.HiveShim
-
-/**
- * Tests for the HiveThriftServer2 using JDBC.
- *
- * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
- * rebuilt after changing HiveThriftServer2 related code.
- */
-class HiveThriftServer2Suite extends FunSuite with Logging {
-  Class.forName(classOf[HiveDriver].getCanonicalName)
-
-  object TestData {
-    def getTestDataFilePath(name: String) = {
-      Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
-    }
-
-    val smallKv = getTestDataFilePath("small_kv.txt")
-    val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
-  }
-
-  def randomListeningPort =  {
-    // Let the system to choose a random available port to avoid collision with other parallel
-    // builds.
-    val socket = new ServerSocket(0)
-    val port = socket.getLocalPort
-    socket.close()
-    port
-  }
-
-  def withJdbcStatement(
-      serverStartTimeout: FiniteDuration = 1.minute,
-      httpMode: Boolean = false)(
-      f: Statement => Unit) {
-    val port = randomListeningPort
-
-    startThriftServer(port, serverStartTimeout, httpMode) {
-      val jdbcUri = if (httpMode) {
-        s"jdbc:hive2://${"localhost"}:$port/" +
-          "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice"
-      } else {
-        s"jdbc:hive2://${"localhost"}:$port/"
-      }
-
-      val user = System.getProperty("user.name")
-      val connection = DriverManager.getConnection(jdbcUri, user, "")
-      val statement = connection.createStatement()
-
-      try {
-        f(statement)
-      } finally {
-        statement.close()
-        connection.close()
-      }
-    }
-  }
-
-  def withCLIServiceClient(
-      serverStartTimeout: FiniteDuration = 1.minute)(
-      f: ThriftCLIServiceClient => Unit) {
-    val port = randomListeningPort
-
-    startThriftServer(port) {
-      // Transport creation logics below mimics HiveConnection.createBinaryTransport
-      val rawTransport = new TSocket("localhost", port)
-      val user = System.getProperty("user.name")
-      val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
-      val protocol = new TBinaryProtocol(transport)
-      val client = new ThriftCLIServiceClient(new Client(protocol))
-
-      transport.open()
-
-      try {
-        f(client)
-      } finally {
-        transport.close()
-      }
-    }
-  }
-
-  def startThriftServer(
-      port: Int,
-      serverStartTimeout: FiniteDuration = 1.minute,
-      httpMode: Boolean = false)(
-      f: => Unit) {
-    val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
-    val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
-
-    val warehousePath = getTempFilePath("warehouse")
-    val metastorePath = getTempFilePath("metastore")
-    val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
-
-    val command =
-      if (httpMode) {
-          s"""$startScript
-             |  --master local
-             |  --hiveconf hive.root.logger=INFO,console
-             |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
-             |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-             |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
-             |  --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http
-             |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port
-             |  --driver-class-path ${sys.props("java.class.path")}
-             |  --conf spark.ui.enabled=false
-           """.stripMargin.split("\\s+").toSeq
-      } else {
-          s"""$startScript
-             |  --master local
-             |  --hiveconf hive.root.logger=INFO,console
-             |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
-             |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
-             |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
-             |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
-             |  --driver-class-path ${sys.props("java.class.path")}
-             |  --conf spark.ui.enabled=false
-           """.stripMargin.split("\\s+").toSeq
-      }
-
-    val serverRunning = Promise[Unit]()
-    val buffer = new ArrayBuffer[String]()
-    val LOGGING_MARK =
-      s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to "
-    var logTailingProcess: Process = null
-    var logFilePath: String = null
-
-    def captureLogOutput(line: String): Unit = {
-      buffer += line
-      if (line.contains("ThriftBinaryCLIService listening on") ||
-          line.contains("Started ThriftHttpCLIService in http")) {
-        serverRunning.success(())
-      }
-    }
-
-    def captureThriftServerOutput(source: String)(line: String): Unit = {
-      if (line.startsWith(LOGGING_MARK)) {
-        logFilePath = line.drop(LOGGING_MARK.length).trim
-        // Ensure that the log file is created so that the `tail' command won't fail
-        Try(new File(logFilePath).createNewFile())
-        logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath")
-          .run(ProcessLogger(captureLogOutput, _ => ()))
-      }
-    }
-
-    val env = Seq(
-      // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
-      "SPARK_TESTING" -> "0")
-
-    Process(command, None, env: _*).run(ProcessLogger(
-      captureThriftServerOutput("stdout"),
-      captureThriftServerOutput("stderr")))
-
-    try {
-      Await.result(serverRunning.future, serverStartTimeout)
-      f
-    } catch {
-      case cause: Exception =>
-        cause match {
-          case _: TimeoutException =>
-            logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
-          case _ =>
-        }
-        logError(
-          s"""
-             |=====================================
-             |HiveThriftServer2Suite failure output
-             |=====================================
-             |HiveThriftServer2 command line: ${command.mkString(" ")}
-             |Binding port: $port
-             |System user: ${System.getProperty("user.name")}
-             |
-             |${buffer.mkString("\n")}
-             |=========================================
-             |End HiveThriftServer2Suite failure output
-             |=========================================
-           """.stripMargin, cause)
-        throw cause
-    } finally {
-      warehousePath.delete()
-      metastorePath.delete()
-      Process(stopScript, None, env: _*).run().exitValue()
-      // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
-      Thread.sleep(3.seconds.toMillis)
-      Option(logTailingProcess).map(_.destroy())
-      Option(logFilePath).map(new File(_).delete())
-    }
-  }
-
-  test("Test JDBC query execution") {
-    withJdbcStatement() { statement =>
-      val queries = Seq(
-        "SET spark.sql.shuffle.partitions=3",
-        "DROP TABLE IF EXISTS test",
-        "CREATE TABLE test(key INT, val STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
-        "CACHE TABLE test")
-
-      queries.foreach(statement.execute)
-
-      assertResult(5, "Row count mismatch") {
-        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
-        resultSet.next()
-        resultSet.getInt(1)
-      }
-    }
-  }
-
-  test("Test JDBC query execution in Http Mode") {
-    withJdbcStatement(httpMode = true) { statement =>
-      val queries = Seq(
-        "SET spark.sql.shuffle.partitions=3",
-        "DROP TABLE IF EXISTS test",
-        "CREATE TABLE test(key INT, val STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
-        "CACHE TABLE test")
-
-      queries.foreach(statement.execute)
-
-      assertResult(5, "Row count mismatch") {
-        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
-        resultSet.next()
-        resultSet.getInt(1)
-      }
-    }
-  }
-
-  test("SPARK-3004 regression: result set containing NULL") {
-    withJdbcStatement() { statement =>
-      val queries = Seq(
-        "DROP TABLE IF EXISTS test_null",
-        "CREATE TABLE test_null(key INT, val STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null")
-
-      queries.foreach(statement.execute)
-
-      val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
-
-      (0 until 5).foreach { _ =>
-        resultSet.next()
-        assert(resultSet.getInt(1) === 0)
-        assert(resultSet.wasNull())
-      }
-
-      assert(!resultSet.next())
-    }
-  }
-
-  test("GetInfo Thrift API") {
-    withCLIServiceClient() { client =>
-      val user = System.getProperty("user.name")
-      val sessionHandle = client.openSession(user, "")
-
-      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
-        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
-      }
-
-      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
-        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
-      }
-
-      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
-        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
-        logInfo(s"Spark version: $version")
-        version != "Unknown"
-      }
-    }
-  }
-
-  test("Checks Hive version") {
-    withJdbcStatement() { statement =>
-      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
-      resultSet.next()
-      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
-    }
-  }
-
-  test("Checks Hive version in Http Mode") {
-    withJdbcStatement(httpMode = true) { statement =>
-      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
-      resultSet.next()
-      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
-    }
-  }
-
-  test("SPARK-4292 regression: result set iterator issue") {
-    withJdbcStatement() { statement =>
-      val queries = Seq(
-        "DROP TABLE IF EXISTS test_4292",
-        "CREATE TABLE test_4292(key INT, val STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292")
-
-      queries.foreach(statement.execute)
-
-      val resultSet = statement.executeQuery("SELECT key FROM test_4292")
-
-      Seq(238, 86, 311, 27, 165).foreach { key =>
-        resultSet.next()
-        assert(resultSet.getInt(1) === key)
-      }
-
-      statement.executeQuery("DROP TABLE IF EXISTS test_4292")
-    }
-  }
-
-  test("SPARK-4309 regression: Date type support") {
-    withJdbcStatement() { statement =>
-      val queries = Seq(
-        "DROP TABLE IF EXISTS test_date",
-        "CREATE TABLE test_date(key INT, value STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date")
-
-      queries.foreach(statement.execute)
-
-      assertResult(Date.valueOf("2011-01-01")) {
-        val resultSet = statement.executeQuery(
-          "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1")
-        resultSet.next()
-        resultSet.getDate(1)
-      }
-    }
-  }
-
-  test("SPARK-4407 regression: Complex type support") {
-    withJdbcStatement() { statement =>
-      val queries = Seq(
-        "DROP TABLE IF EXISTS test_map",
-        "CREATE TABLE test_map(key INT, value STRING)",
-        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
-
-      queries.foreach(statement.execute)
-
-      assertResult("""{238:"val_238"}""") {
-        val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1")
-        resultSet.next()
-        resultSet.getString(1)
-      }
-
-      assertResult("""["238","val_238"]""") {
-        val resultSet = statement.executeQuery(
-          "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1")
-        resultSet.next()
-        resultSet.getString(1)
-      }
-    }
-  }
-}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
new file mode 100644
index 0000000000000000000000000000000000000000..77ef37253e38f03b3b19a2b2228f12ea78d98708
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.File
+import java.sql.{Date, DriverManager, Statement}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.sys.process.{Process, ProcessLogger}
+import scala.util.{Random, Try}
+
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.jdbc.HiveDriver
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.GetInfoType
+import org.apache.hive.service.cli.thrift.TCLIService.Client
+import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.hive.HiveShim
+
+object TestData {
+  def getTestDataFilePath(name: String) = {
+    Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
+  }
+
+  val smallKv = getTestDataFilePath("small_kv.txt")
+  val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
+}
+
+class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
+  override def mode = ServerMode.binary
+
+  private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
+    // Transport creation logics below mimics HiveConnection.createBinaryTransport
+    val rawTransport = new TSocket("localhost", serverPort)
+    val user = System.getProperty("user.name")
+    val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+    val protocol = new TBinaryProtocol(transport)
+    val client = new ThriftCLIServiceClient(new Client(protocol))
+
+    transport.open()
+    try f(client) finally transport.close()
+  }
+
+  test("GetInfo Thrift API") {
+    withCLIServiceClient { client =>
+      val user = System.getProperty("user.name")
+      val sessionHandle = client.openSession(user, "")
+
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+      }
+
+      assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+        client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+      }
+
+      assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+        val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+        logInfo(s"Spark version: $version")
+        version != "Unknown"
+      }
+    }
+  }
+
+  test("JDBC query execution") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "SET spark.sql.shuffle.partitions=3",
+        "DROP TABLE IF EXISTS test",
+        "CREATE TABLE test(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
+        "CACHE TABLE test")
+
+      queries.foreach(statement.execute)
+
+      assertResult(5, "Row count mismatch") {
+        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+        resultSet.next()
+        resultSet.getInt(1)
+      }
+    }
+  }
+
+  test("Checks Hive version") {
+    withJdbcStatement { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+    }
+  }
+
+  test("SPARK-3004 regression: result set containing NULL") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "DROP TABLE IF EXISTS test_null",
+        "CREATE TABLE test_null(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null")
+
+      queries.foreach(statement.execute)
+
+      val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+
+      (0 until 5).foreach { _ =>
+        resultSet.next()
+        assert(resultSet.getInt(1) === 0)
+        assert(resultSet.wasNull())
+      }
+
+      assert(!resultSet.next())
+    }
+  }
+
+  test("SPARK-4292 regression: result set iterator issue") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "DROP TABLE IF EXISTS test_4292",
+        "CREATE TABLE test_4292(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292")
+
+      queries.foreach(statement.execute)
+
+      val resultSet = statement.executeQuery("SELECT key FROM test_4292")
+
+      Seq(238, 86, 311, 27, 165).foreach { key =>
+        resultSet.next()
+        assert(resultSet.getInt(1) === key)
+      }
+
+      statement.executeQuery("DROP TABLE IF EXISTS test_4292")
+    }
+  }
+
+  test("SPARK-4309 regression: Date type support") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "DROP TABLE IF EXISTS test_date",
+        "CREATE TABLE test_date(key INT, value STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date")
+
+      queries.foreach(statement.execute)
+
+      assertResult(Date.valueOf("2011-01-01")) {
+        val resultSet = statement.executeQuery(
+          "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1")
+        resultSet.next()
+        resultSet.getDate(1)
+      }
+    }
+  }
+
+  test("SPARK-4407 regression: Complex type support") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "DROP TABLE IF EXISTS test_map",
+        "CREATE TABLE test_map(key INT, value STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
+
+      queries.foreach(statement.execute)
+
+      assertResult("""{238:"val_238"}""") {
+        val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1")
+        resultSet.next()
+        resultSet.getString(1)
+      }
+
+      assertResult("""["238","val_238"]""") {
+        val resultSet = statement.executeQuery(
+          "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1")
+        resultSet.next()
+        resultSet.getString(1)
+      }
+    }
+  }
+}
+
+class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
+  override def mode = ServerMode.http
+
+  test("JDBC query execution") {
+    withJdbcStatement { statement =>
+      val queries = Seq(
+        "SET spark.sql.shuffle.partitions=3",
+        "DROP TABLE IF EXISTS test",
+        "CREATE TABLE test(key INT, val STRING)",
+        s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
+        "CACHE TABLE test")
+
+      queries.foreach(statement.execute)
+
+      assertResult(5, "Row count mismatch") {
+        val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+        resultSet.next()
+        resultSet.getInt(1)
+      }
+    }
+  }
+
+  test("Checks Hive version") {
+    withJdbcStatement { statement =>
+      val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+      resultSet.next()
+      assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+    }
+  }
+}
+
+object ServerMode extends Enumeration {
+  val binary, http = Value
+}
+
+abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
+  Class.forName(classOf[HiveDriver].getCanonicalName)
+
+  private def jdbcUri = if (mode == ServerMode.http) {
+    s"""jdbc:hive2://localhost:$serverPort/
+       |default?
+       |hive.server2.transport.mode=http;
+       |hive.server2.thrift.http.path=cliservice
+     """.stripMargin.split("\n").mkString.trim
+  } else {
+    s"jdbc:hive2://localhost:$serverPort/"
+  }
+
+  protected def withJdbcStatement(f: Statement => Unit): Unit = {
+    val connection = DriverManager.getConnection(jdbcUri, user, "")
+    val statement = connection.createStatement()
+
+    try f(statement) finally {
+      statement.close()
+      connection.close()
+    }
+  }
+}
+
+abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {
+  def mode: ServerMode.Value
+
+  private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
+  private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to "
+
+  private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+  private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
+
+  private var listeningPort: Int = _
+  protected def serverPort: Int = listeningPort
+
+  protected def user = System.getProperty("user.name")
+
+  private var warehousePath: File = _
+  private var metastorePath: File = _
+  private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+
+  private var logPath: File = _
+  private var logTailingProcess: Process = _
+  private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String]
+
+  private def serverStartCommand(port: Int) = {
+    val portConf = if (mode == ServerMode.binary) {
+      ConfVars.HIVE_SERVER2_THRIFT_PORT
+    } else {
+      ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
+    }
+
+    s"""$startScript
+       |  --master local
+       |  --hiveconf hive.root.logger=INFO,console
+       |  --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
+       |  --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+       |  --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
+       |  --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
+       |  --hiveconf $portConf=$port
+       |  --driver-class-path ${sys.props("java.class.path")}
+       |  --conf spark.ui.enabled=false
+     """.stripMargin.split("\\s+").toSeq
+  }
+
+  private def startThriftServer(port: Int, attempt: Int) = {
+    warehousePath = util.getTempFilePath("warehouse")
+    metastorePath = util.getTempFilePath("metastore")
+    logPath = null
+    logTailingProcess = null
+
+    val command = serverStartCommand(port)
+
+    diagnosisBuffer ++=
+      s"""
+         |### Attempt $attempt ###
+         |HiveThriftServer2 command line: $command
+         |Listening port: $port
+         |System user: $user
+       """.stripMargin.split("\n")
+
+    logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt")
+
+    logPath = Process(command, None, "SPARK_TESTING" -> "0").lines.collectFirst {
+      case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length))
+    }.getOrElse {
+      throw new RuntimeException("Failed to find HiveThriftServer2 log file.")
+    }
+
+    val serverStarted = Promise[Unit]()
+
+    // Ensures that the following "tail" command won't fail.
+    logPath.createNewFile()
+    logTailingProcess =
+      // Using "-n +0" to make sure all lines in the log file are checked.
+      Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger(
+        (line: String) => {
+          diagnosisBuffer += line
+
+          if (line.contains("ThriftBinaryCLIService listening on") ||
+              line.contains("Started ThriftHttpCLIService in http")) {
+            serverStarted.trySuccess(())
+          } else if (line.contains("HiveServer2 is stopped")) {
+            // This log line appears when the server fails to start and terminates gracefully (e.g.
+            // because of port contention).
+            serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2"))
+          }
+        }))
+
+    Await.result(serverStarted.future, 2.minute)
+  }
+
+  private def stopThriftServer(): Unit = {
+    // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
+    Process(stopScript, None).run().exitValue()
+    Thread.sleep(3.seconds.toMillis)
+
+    warehousePath.delete()
+    warehousePath = null
+
+    metastorePath.delete()
+    metastorePath = null
+
+    Option(logPath).foreach(_.delete())
+    logPath = null
+
+    Option(logTailingProcess).foreach(_.destroy())
+    logTailingProcess = null
+  }
+
+  private def dumpLogs(): Unit = {
+    logError(
+      s"""
+         |=====================================
+         |HiveThriftServer2Suite failure output
+         |=====================================
+         |${diagnosisBuffer.mkString("\n")}
+         |=========================================
+         |End HiveThriftServer2Suite failure output
+         |=========================================
+       """.stripMargin)
+  }
+
+  override protected def beforeAll(): Unit = {
+    // Chooses a random port between 10000 and 19999
+    listeningPort = 10000 + Random.nextInt(10000)
+    diagnosisBuffer.clear()
+
+    // Retries up to 3 times with different port numbers if the server fails to start
+    (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) =>
+      started.orElse {
+        listeningPort += 1
+        stopThriftServer()
+        Try(startThriftServer(listeningPort, attempt))
+      }
+    }.recover {
+      case cause: Throwable =>
+        dumpLogs()
+        throw cause
+    }.get
+
+    logInfo(s"HiveThriftServer2 started successfully")
+  }
+
+  override protected def afterAll(): Unit = {
+    stopThriftServer()
+    logInfo("HiveThriftServer2 stopped")
+  }
+}