diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e8bcdd76efd7ae00601f4d8fa0616a37faa6746a..b2717ec54e6935af4986b1370d2f3d84e6fd746a 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -51,6 +51,7 @@ private[hive] class SparkExecuteStatementOperation( private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ + private var iterHeader: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ private var statementId: String = _ @@ -110,6 +111,14 @@ private[hive] class SparkExecuteStatementOperation( assertState(OperationState.FINISHED) setHasResultSet(true) val resultRowSet: RowSet = RowSetFactory.create(getResultSetSchema, getProtocolVersion) + + // Reset iter to header when fetching start from first row + if (order.equals(FetchOrientation.FETCH_FIRST)) { + val (ita, itb) = iterHeader.duplicate + iter = ita + iterHeader = itb + } + if (!iter.hasNext) { resultRowSet } else { @@ -228,6 +237,9 @@ private[hive] class SparkExecuteStatementOperation( result.collect().iterator } } + val (itra, itrb) = iter.duplicate + iterHeader = itra + iter = itrb dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray } catch { case e: HiveSQLException => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index e388c2a082f1893f5810301a5d900595a799d47c..8f2c4fafa0b43eeba85667079ab391cb991d9943 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -36,6 +36,8 @@ import org.apache.hive.service.auth.PlainSaslHelper import org.apache.hive.service.cli.GetInfoType import org.apache.hive.service.cli.thrift.TCLIService.Client import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.hive.service.cli.FetchOrientation +import org.apache.hive.service.cli.FetchType import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll @@ -91,6 +93,52 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { } } + test("SPARK-16563 ThriftCLIService FetchResults repeat fetching result") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_16563", + "CREATE TABLE test_16563(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_16563") + + queries.foreach(statement.execute) + val confOverlay = new java.util.HashMap[java.lang.String, java.lang.String] + val operationHandle = client.executeStatement( + sessionHandle, + "SELECT * FROM test_16563", + confOverlay) + + // Fetch result first time + assertResult(5, "Fetching result first time from next row") { + + val rows_next = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_NEXT, + 1000, + FetchType.QUERY_OUTPUT) + + rows_next.numRows() + } + + // Fetch result second time from first row + assertResult(5, "Repeat fetching result from first row") { + + val rows_first = client.fetchResults( + operationHandle, + FetchOrientation.FETCH_FIRST, + 1000, + FetchType.QUERY_OUTPUT) + + rows_first.numRows() + } + statement.executeQuery("DROP TABLE IF EXISTS test_16563") + } + } + } + test("JDBC query execution") { withJdbcStatement { statement => val queries = Seq(