Commit bdc7a1a4 authored by Cheng Lian, committed by Michael Armbrust

[SPARK-3004][SQL] Added null checking when retrieving row set

JIRA issue: [SPARK-3004](https://issues.apache.org/jira/browse/SPARK-3004)

HiveThriftServer2 throws an exception when the result set contains `NULL`. We should check `isNullAt` in `SparkSQLOperationManager.getNextRowSet` before converting each column value.
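
For context, a minimal sketch of how the bug surfaces over JDBC, mirroring the regression test added below (the driver registration, connection URL, and user are illustrative assumptions, not part of this patch):

```scala
// Repro sketch, assuming a running HiveThriftServer2 on localhost:10000 and the
// Hive JDBC driver on the classpath. Before this fix, fetching a row set that
// contains NULL columns threw inside SparkSQLOperationManager.getNextRowSet.
import java.sql.DriverManager

Class.forName("org.apache.hive.jdbc.HiveDriver")
val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000/", "user", "")
val stmt = conn.createStatement()
stmt.execute("DROP TABLE IF EXISTS test_null")
stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
stmt.execute(
  "LOAD DATA LOCAL INPATH 'data/files/small_kv_with_null.txt' OVERWRITE INTO TABLE test_null")

val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
while (rs.next()) {  // iterating over the NULL-keyed rows is what used to fail
  println(rs.getString("val"))
}
stmt.close()
conn.close()
```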

Note that simply using `row.addColumnValue(null)` doesn't work, since Hive sets the column type of a null `ColumnValue` to String by default.
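
Instead, the null has to go through the `ColumnValue` factory method for the column's actual type. A minimal sketch of the idea (the helper name is illustrative, and the imports assume the Hive 0.12-era API and Spark 1.1-era type package this patch targets):

```scala
// Illustrative helper, not part of the commit: build a NULL ColumnValue that
// still carries the column's real type instead of Hive's String default.
import org.apache.hadoop.hive.common.`type`.HiveDecimal
import org.apache.hive.service.cli.ColumnValue
import org.apache.spark.sql.catalyst.types._

def typedNull(dataType: DataType): ColumnValue = dataType match {
  case IntegerType   => ColumnValue.intValue(null)     // NULL, column stays INT
  case DoubleType    => ColumnValue.doubleValue(null)  // NULL, column stays DOUBLE
  case TimestampType => ColumnValue.timestampValue(null)
  // stringValue is overloaded (String and HiveDecimal), so a bare `null` is
  // ambiguous in Scala; the type ascription picks the overload explicitly.
  case DecimalType   => ColumnValue.stringValue(null: HiveDecimal)
  case _             => ColumnValue.stringValue(null: String)
}
```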

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #1920 from liancheng/spark-3004 and squashes the following commits:

1b1db1c [Cheng Lian] Adding NULL column values in the Hive way
2217722 [Cheng Lian] Fixed SPARK-3004: added null checking when retrieving row set
parent 7ecb867c
@@ -73,35 +73,10 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager
           var curCol = 0
 
           while (curCol < sparkRow.length) {
-            dataTypes(curCol) match {
-              case StringType =>
-                row.addString(sparkRow(curCol).asInstanceOf[String])
-              case IntegerType =>
-                row.addColumnValue(ColumnValue.intValue(sparkRow.getInt(curCol)))
-              case BooleanType =>
-                row.addColumnValue(ColumnValue.booleanValue(sparkRow.getBoolean(curCol)))
-              case DoubleType =>
-                row.addColumnValue(ColumnValue.doubleValue(sparkRow.getDouble(curCol)))
-              case FloatType =>
-                row.addColumnValue(ColumnValue.floatValue(sparkRow.getFloat(curCol)))
-              case DecimalType =>
-                val hiveDecimal = sparkRow.get(curCol).asInstanceOf[BigDecimal].bigDecimal
-                row.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
-              case LongType =>
-                row.addColumnValue(ColumnValue.longValue(sparkRow.getLong(curCol)))
-              case ByteType =>
-                row.addColumnValue(ColumnValue.byteValue(sparkRow.getByte(curCol)))
-              case ShortType =>
-                row.addColumnValue(ColumnValue.intValue(sparkRow.getShort(curCol)))
-              case TimestampType =>
-                row.addColumnValue(
-                  ColumnValue.timestampValue(sparkRow.get(curCol).asInstanceOf[Timestamp]))
-              case BinaryType | _: ArrayType | _: StructType | _: MapType =>
-                val hiveString = result
-                  .queryExecution
-                  .asInstanceOf[HiveContext#QueryExecution]
-                  .toHiveString((sparkRow.get(curCol), dataTypes(curCol)))
-                row.addColumnValue(ColumnValue.stringValue(hiveString))
-            }
+            if (sparkRow.isNullAt(curCol)) {
+              addNullColumnValue(sparkRow, row, curCol)
+            } else {
+              addNonNullColumnValue(sparkRow, row, curCol)
+            }
             curCol += 1
           }
@@ -112,6 +87,66 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager
           }
         }
 
+        def addNonNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+          dataTypes(ordinal) match {
+            case StringType =>
+              to.addString(from(ordinal).asInstanceOf[String])
+            case IntegerType =>
+              to.addColumnValue(ColumnValue.intValue(from.getInt(ordinal)))
+            case BooleanType =>
+              to.addColumnValue(ColumnValue.booleanValue(from.getBoolean(ordinal)))
+            case DoubleType =>
+              to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
+            case FloatType =>
+              to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
+            case DecimalType =>
+              val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
+              to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
+            case LongType =>
+              to.addColumnValue(ColumnValue.longValue(from.getLong(ordinal)))
+            case ByteType =>
+              to.addColumnValue(ColumnValue.byteValue(from.getByte(ordinal)))
+            case ShortType =>
+              to.addColumnValue(ColumnValue.intValue(from.getShort(ordinal)))
+            case TimestampType =>
+              to.addColumnValue(
+                ColumnValue.timestampValue(from.get(ordinal).asInstanceOf[Timestamp]))
+            case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+              val hiveString = result
+                .queryExecution
+                .asInstanceOf[HiveContext#QueryExecution]
+                .toHiveString((from.get(ordinal), dataTypes(ordinal)))
+              to.addColumnValue(ColumnValue.stringValue(hiveString))
+          }
+        }
+
+        def addNullColumnValue(from: SparkRow, to: Row, ordinal: Int) {
+          dataTypes(ordinal) match {
+            case StringType =>
+              to.addString(null)
+            case IntegerType =>
+              to.addColumnValue(ColumnValue.intValue(null))
+            case BooleanType =>
+              to.addColumnValue(ColumnValue.booleanValue(null))
+            case DoubleType =>
+              to.addColumnValue(ColumnValue.doubleValue(null))
+            case FloatType =>
+              to.addColumnValue(ColumnValue.floatValue(null))
+            case DecimalType =>
+              to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
+            case LongType =>
+              to.addColumnValue(ColumnValue.longValue(null))
+            case ByteType =>
+              to.addColumnValue(ColumnValue.byteValue(null))
+            case ShortType =>
+              to.addColumnValue(ColumnValue.intValue(null))
+            case TimestampType =>
+              to.addColumnValue(ColumnValue.timestampValue(null))
+            case BinaryType | _: ArrayType | _: StructType | _: MapType =>
+              to.addColumnValue(ColumnValue.stringValue(null: String))
+          }
+        }
+
         def getResultSetSchema: TableSchema = {
           logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}")
           if (result.queryExecution.analyzed.output.size == 0) {
data/files/small_kv_with_null.txt (new file, tab-delimited; rows with an empty key field load as NULL keys):
+238	val_238
+
+311	val_311
+	val_27
+	val_165
+	val_409
+255	val_255
+278	val_278
+98	val_98
+	val_484
@@ -113,22 +113,40 @@ class HiveThriftServer2Suite extends FunSuite with BeforeAndAfterAll with TestUtils
     val stmt = createStatement()
     stmt.execute("DROP TABLE IF EXISTS test")
     stmt.execute("DROP TABLE IF EXISTS test_cached")
-    stmt.execute("CREATE TABLE test(key int, val string)")
+    stmt.execute("CREATE TABLE test(key INT, val STRING)")
     stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test")
-    stmt.execute("CREATE TABLE test_cached as select * from test limit 4")
+    stmt.execute("CREATE TABLE test_cached AS SELECT * FROM test LIMIT 4")
     stmt.execute("CACHE TABLE test_cached")
 
-    var rs = stmt.executeQuery("select count(*) from test")
+    var rs = stmt.executeQuery("SELECT COUNT(*) FROM test")
     rs.next()
     assert(rs.getInt(1) === 5)
 
-    rs = stmt.executeQuery("select count(*) from test_cached")
+    rs = stmt.executeQuery("SELECT COUNT(*) FROM test_cached")
     rs.next()
     assert(rs.getInt(1) === 4)
 
     stmt.close()
   }
 
+  test("SPARK-3004 regression: result set containing NULL") {
+    Thread.sleep(5 * 1000)
+    val dataFilePath = getDataFile("data/files/small_kv_with_null.txt")
+    val stmt = createStatement()
+    stmt.execute("DROP TABLE IF EXISTS test_null")
+    stmt.execute("CREATE TABLE test_null(key INT, val STRING)")
+    stmt.execute(s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE test_null")
+
+    val rs = stmt.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+    var count = 0
+    while (rs.next()) {
+      count += 1
+    }
+
+    assert(count === 5)
+
+    stmt.close()
+  }
+
   def getConnection: Connection = {
     val connectURI = s"jdbc:hive2://localhost:$PORT/"
     DriverManager.getConnection(connectURI, System.getProperty("user.name"), "")