diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index f192f490ac3d0607575be2e81cca3f2f191c8be6..9338e8121b0febc75a05b2089a7eb26edd3ef1ac 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage
             var curCol = 0
 
             while (curCol < sparkRow.length) {
-              dataTypes(curCol) match {
-                case StringType =>
-                  row.addString(sparkRow(curCol).asInstanceOf[String])
-                case IntegerType =>
-                  row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
-                case BooleanType =>
-                  row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
-                case DoubleType =>
-                  row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
-                case FloatType =>
-                  row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
-                case DecimalType =>
-                  val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
-                  row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
-                case LongType =>
-                  row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
-                case ByteType =>
-                  row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
-                case ShortType =>
-                  row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
-                case TimestampType =>
-                  row.addColumnValue(
-                    ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
-                case BinaryType | _: ArrayType | _: StructType | _: MapType =>
-                  val hiveString = result
-                    .queryExecution
-                    .asInstanceOf[HiveContext#QueryExecution]
-                    .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
-                  row.addColumnValue(ColumnValue.stringValue(hiveString))
+              if (sparkRow.isNullAt(curCol)) {
+                addNullColumnValue(sparkRow, row, curCol)
+              } else {
+                addNonNullColumnValue(sparkRow, row, curCol)
               }
               curCol += 1
             }
@@ -112,6 +87,74 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage
         }
       }
 
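+      // Converts the non-null value at `ordinal` of a Spark row into the
+      // HiveServer2 ColumnValue matching the column's Catalyst data type.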
+      def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+        dataTypes(ordinal) match {
+          case StringType =>
+            to.addString(from(ordinal).asInstanceOf[String])
+          case IntegerType =>
+            to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal)))
+          case BooleanType =>
+            to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal)))
+          case DoubleType =>
+            to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
+          case FloatType =>
+            to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
+          case DecimalType =>
+            val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
+            to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+          case LongType =>
+            to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal)))
+          case ByteType =>
+            to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal)))
+          case ShortType =>
+            to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal)))
+          case TimestampType =>
+            to.addColumnValue(
+              ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp]))
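+          // Binary and complex types are converted to Hive's string
+          // representation via the query execution's toHiveString.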
+          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+            val hiveString = result
+              .queryExecution
+              .asInstanceOf[HiveContext#QueryExecution]
+              .toHiveString((from.get(ordinal), dataTypes(ordinal)))
+            to.addColumnValue(ColumnValue.stringValue(hiveString))
+        }
+      }
+
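+      // Emits a typed null ColumnValue for the column at `ordinal`; null cells
+      // must not reach the primitive getters used above (SPARK-3004).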
+      def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+        dataTypes(ordinal) match {
+          case StringType =>
+            to.addString(null)
+          case IntegerType =>
+            to.addColumnValue(ColumnValue.intValue(null))
+          case BooleanType =>
+            to.addColumnValue(ColumnValue.booleanValue(null))
+          case DoubleType =>
+            to.addColumnValue(ColumnValue.doubleValue(null))
+          case FloatType =>
+            to.addColumnValue(ColumnValue.floatValue(null))
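+          // The `null: Type` ascriptions below disambiguate the overloads of
+          // ColumnValue.stringValue when passing a null literal.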
+          case DecimalType =>
+            to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
+          case LongType =>
+            to.addColumnValue(ColumnValue.longValue(null))
+          case ByteType =>
+            to.addColumnValue(ColumnValue.byteValue(null))
+          case ShortType =>
+            to.addColumnValue(ColumnValue.intValue(null))
+          case TimestampType =>
+            to.addColumnValue(ColumnValue.timestampValue(null))
+          case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+            to.addColumnValue(ColumnValue.stringValue(null: String))
+        }
+      }
+
       def getResultSetSchema: TableSchema = {
         logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}")
         if (result.queryExecution.analyzed.output.size == 0) {
diff --git a/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ae08c640e6c13e6e3ac91dcb5aaed44371f97774
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/resources/data/files/small_kv_with_null.txt
@@ -0,0 +1,10 @@
+238val_238
+
+311val_311
+val_27
+val_165
+val_409
+255val_255
+278val_278
+98val_98
+val_484
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
index 78bffa2607349142dd03fe6a241c03d2adb2b1cc..aedef6ce1f5f22fe10694838ea09d25d5345c87c 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
@@ -113,22 +113,43 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUt
     val stmt = createStatement()
     stmt.execute("DROP TABLE IF EXISTS test")
     stmt.execute("DROP TABLE IF EXISTS test_cached")
-    stmt.execute("CREATE TABLE test(key int, val string)")
+    stmt.execute("CREATE TABLE test(key INT, val STRING)")
     stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
-    stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+    stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4")
     stmt.execute("CACHE TABLE test_cached")
 
-    var rs = stmt.executeQuery("select count(*) from test")
+    var rs = stmt.executeQuery("SELECT COUNT(*) FROM test")
     rs.next()
     assert(rs.getInt(1) === 5)
 
-    rs = stmt.executeQuery("select count(*) from test_cached")
+    rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached")
     rs.next()
     assert(rs.getInt(1) === 4)
 
     stmt.close()
   }
 
+  test("SPARK-3004 regression: result set containing NULL") {
+    Thread.sleep(5 * 1000)
+    val dataFilePath = getDataFile("data/files/small_kv_with_null.txt")
+    val stmt = createStatement()
+    stmt.execute("DROP TABLE IF EXISTS test_null")
+    stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
+    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
+
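+    // small_kv_with_null.txt holds five rows whose key field is empty, so they
+    // load as NULL keys and should all be returned here.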
+    val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+    var count = 0
+    while (rs.next()) {
+      count += 1
+    }
+    assert(count === 5)
+
+    stmt.close()
+  }
+
   def getConnection: Connection = {
     val connectURI = s"jdbc:hive2://localhost:$PORT/"
     DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")