From bdc7a1a4749301f8d18617c130c7766684aa8789 Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian.cs.zju@gmail.com>
Date: Wed, 13 Aug 2014 16:27:50 -0700
Subject: [PATCH] [SPARK-3004][SQL] Added null checking when retrieving row set

JIRA issue: [SPARK-3004](https://issues.apache.org/jira/browse/SPARK-3004)

HiveThriftServer2 throws exception when the result set contains `NULL`. Should check `isNullAt` in `SparkSQLOperationManager.getNextRowSet`.

Note that simply using `row.addColumnValue(null)` doesn't work, since Hive set the column type of a null `ColumnValue` to String by default.

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #1920 from liancheng/spark-3004 and squashes the following commits:

1b1db1c [Cheng Lian] Adding NULL column values in the Hive way
2217722 [Cheng Lian] Fixed SPARK-3004: added null checking when retrieving row set
---
 .../server/SparkSQLOperationManager.scala     | 93 +++++++++++++------
 .../data/files/small_kv_with_null.txt         | 10 ++
 .../thriftserver/HiveThriftServer2Suite.scala | 26 +++++-
 3 files changed, 96 insertions(+), 33 deletions(-)
 create mode 100644 sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt

diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index f192f490ac..9338e8121b 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage
             var curCol = 0
 
             while (curCol < sparkRow.length) {
-              dataTypes(curCol) match {
-                case StringType =>
-                  row.addString(sparkRow(curCol).asInstanceOf[String])
-                case IntegerType =>
-                  row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
-                case BooleanType =>
-                  row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
-                case DoubleType =>
-                  row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
-                case FloatType =>
-                  row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
-                case DecimalType =>
-                  val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
-                  row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
-                case LongType =>
-                  row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
-                case ByteType =>
-                  row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
-                case ShortType =>
-                  row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
-                case TimestampType =>
-                  row.addColumnValue(
-                    ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
-                case BinaryType | _: ArrayType | _: StructType | _: MapType =>
-                  val hiveString = result
-                    .queryExecution
-                    .asInstanceOf[HiveContext#QueryExecution]
-                    .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
-                  row.addColumnValue(ColumnValue.stringValue(hiveString))
+              if (sparkRow.isNullAt(curCol)) {
+                addNullColumnValue(sparkRow, row, curCol)
+              } else {
+                addNonNullColumnValue(sparkRow, row, curCol)
               }
               curCol += 1
             }
@@ -112,6 +87,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage
         }
       }
 
+      def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+        dataTypes(ordinal) match {
+          case StringType =>
+            to.addString(from(ordinal).asInstanceOf[String])
+          case IntegerType =>
+            to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal)))
+          case BooleanType =>
+            to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal)))
+          case DoubleType =>
+            to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
+          case FloatType =>
+            to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
+          case DecimalType =>
+            val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
+            to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+          case LongType =>
+            to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal)))
+          case ByteType =>
+            to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal)))
+          case ShortType =>
+            to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal)))
+          case TimestampType =>
+            to.addColumnValue(
+              ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp]))
+          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+            val hiveString = result
+              .queryExecution
+              .asInstanceOf[HiveContext#QueryExecution]
+              .toHiveString((from.get(ordinal), dataTypes(ordinal)))
+            to.addColumnValue(ColumnValue.stringValue(hiveString))
+        }
+      }
+
+      def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+        dataTypes(ordinal) match {
+          case StringType =>
+            to.addString(null)
+          case IntegerType =>
+            to.addColumnValue(ColumnValue.intValue(null))
+          case BooleanType =>
+            to.addColumnValue(ColumnValue.booleanValue(null))
+          case DoubleType =>
+            to.addColumnValue(ColumnValue.doubleValue(null))
+          case FloatType =>
+            to.addColumnValue(ColumnValue.floatValue(null))
+          case DecimalType =>
+            to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
+          case LongType =>
+            to.addColumnValue(ColumnValue.longValue(null))
+          case ByteType =>
+            to.addColumnValue(ColumnValue.byteValue(null))
+          case ShortType =>
+            to.addColumnValue(ColumnValue.intValue(null))
+          case TimestampType =>
+            to.addColumnValue(ColumnValue.timestampValue(null))
+          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+            to.addColumnValue(ColumnValue.stringValue(null: String))
+        }
+      }
+
       def getResultSetSchema: TableSchema = {
         logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}")
         if (result.queryExecution.analyzed.output.size == 0) {
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt
new file mode 100644
index 0000000000..ae08c640e6
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt
@@ -0,0 +1,10 @@
+238val_238
+
+311val_311
+val_27
+val_165
+val_409
+255val_255
+278val_278
+98val_98
+val_484
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index 78bffa2607..aedef6ce1f 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -113,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt
     val stmt = createStatement()
     stmt.execute("DROP TABLE IF EXISTS test")
     stmt.execute("DROP TABLE IF EXISTS test_cached")
-    stmt.execute("CREATE TABLE test(key int, val string)")
+    stmt.execute("CREATE TABLE test(key INT, val STRING)")
     stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
-    stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+    stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4")
     stmt.execute("CACHE TABLE test_cached")
 
-    var rs = stmt.executeQuery("select count(*) from test")
+    var rs = stmt.executeQuery("SELECT COUNT(*) FROM test")
     rs.next()
     assert(rs.getInt(1) === 5)
 
-    rs = stmt.executeQuery("select count(*) from test_cached")
+    rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached")
     rs.next()
     assert(rs.getInt(1) === 4)
 
     stmt.close()
   }
 
+  test("SPARK-3004 regression: result set containing NULL") {
+    Thread.sleep(5 * 1000)
+    val dataFilePath = getDataFile("data/files/small_kv_with_null.txt")
+    val stmt = createStatement()
+    stmt.execute("DROP TABLE IF EXISTS test_null")
+    stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
+    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
+
+    val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+    var count = 0
+    while (rs.next()) {
+      count += 1
+    }
+    assert(count === 5)
+
+    stmt.close()
+  }
+
   def getConnection: Connection = {
     val connectURI = s"jdbc:hive2://localhost:$PORT/"
     DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")
-- 
GitLab