Commit cbdcd4ed authored by Sameer Agarwal, committed by Davies Liu

[SPARK-14870] [SQL] Fix NPE in TPCDS q14a

## What changes were proposed in this pull request?

This PR fixes a bug in `TungstenAggregate` that manifests while aggregating by keys over nullable `BigDecimal` columns; the bug caused a null pointer exception while executing TPCDS q14a.
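
For illustration, a minimal sketch of the failing query shape, modeled on the test fixtures added below (not taken verbatim from the PR; assumes a `SQLContext` in scope as `sqlContext`):

```scala
import sqlContext.implicits._
import org.apache.spark.sql.functions.sum

// BigDecimal is a reference type, so both columns are nullable.
case class Dec(a: BigDecimal, b: BigDecimal)

val df = sqlContext.sparkContext.parallelize(
  Dec(1, 1) :: Dec(1, null) :: Dec(null, 2) :: Nil).toDF()

// Grouping by a nullable decimal key while summing a nullable decimal
// value exercises the aggregation path that threw the NPE before this fix.
df.groupBy("a").agg(sum("b")).show()
```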

## How was this patch tested?

1. Added a regression test in `DataFrameAggregateSuite`.
2. Verified that TPCDS q14a runs successfully.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12651 from sameeragarwal/tpcds-fix.
Parent: c752b6c5
@@ -239,16 +239,19 @@ class CodegenContext {
   /**
    * Update a column in MutableRow from ExprCode.
+   *
+   * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
    */
   def updateColumn(
       row: String,
       dataType: DataType,
       ordinal: Int,
       ev: ExprCode,
-      nullable: Boolean): String = {
+      nullable: Boolean,
+      isVectorized: Boolean = false): String = {
     if (nullable) {
       // Can't call setNullAt on DecimalType, because we need to keep the offset
-      if (dataType.isInstanceOf[DecimalType]) {
+      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
         s"""
            if (!${ev.isNull}) {
              ${setColumn(row, dataType, ordinal, ev.value)};
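
For context, here is a sketch of the full method after this change. The `else` branches are reconstructed from the surrounding code rather than shown in the truncated hunk above, so treat them as an assumption. The key idea: a regular `MutableRow` backed by `UnsafeRow` must not call `setNullAt` on a decimal column because the decimal's offset has to be preserved, while a vectorized `ColumnarBatch.Row` can record the null explicitly.

```scala
// Reconstructed sketch, not verbatim source.
def updateColumn(
    row: String,
    dataType: DataType,
    ordinal: Int,
    ev: ExprCode,
    nullable: Boolean,
    isVectorized: Boolean = false): String = {
  if (nullable) {
    // Only the non-vectorized DecimalType case must skip setNullAt.
    if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
      s"""
         if (!${ev.isNull}) {
           ${setColumn(row, dataType, ordinal, ev.value)};
         }
       """
    } else {
      s"""
         if (!${ev.isNull}) {
           ${setColumn(row, dataType, ordinal, ev.value)};
         } else {
           $row.setNullAt($ordinal);
         }
       """
    }
  } else {
    s"${setColumn(row, dataType, ordinal, ev.value)};"
  }
}
```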
@@ -633,7 +633,8 @@ case class TungstenAggregate(
       updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
     val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
       val dt = updateExpr(i).dataType
-      ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
+      ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
+        isVectorized = true)
     }
     Option(
       s"""
@@ -61,6 +61,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df1.groupBy("key").min("value2"),
       Seq(Row("a", 0), Row("b", 4))
     )
+
+    checkAnswer(
+      decimalData.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
+    )
+
+    checkAnswer(
+      decimalDataWithNulls.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)),
+        Row(null, new java.math.BigDecimal(2.0)))
+    )
   }

   test("rollup") {
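
A note on the expected rows in the second `checkAnswer` above: `sum` skips null inputs and `groupBy` places null keys in their own group, so the `(1, null)` and `(2, null)` rows contribute nothing to their groups' sums, while the `(null, 2)` row forms its own group summing to 2. A quick sanity check against the fixture (same `sqlContext` assumption as the sketch above):

```scala
// Expected groups per the test: 1 -> 1, 2 -> 1, 3 -> 3, null -> 2
decimalDataWithNulls.groupBy("a").agg(sum("b")).show()
```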
@@ -103,6 +103,19 @@ private[sql] trait SQLTestData { self =>
     df
   }

+  protected lazy val decimalDataWithNulls: DataFrame = {
+    val df = sqlContext.sparkContext.parallelize(
+      DecimalDataWithNulls(1, 1) ::
+      DecimalDataWithNulls(1, null) ::
+      DecimalDataWithNulls(2, 1) ::
+      DecimalDataWithNulls(2, null) ::
+      DecimalDataWithNulls(3, 1) ::
+      DecimalDataWithNulls(3, 2) ::
+      DecimalDataWithNulls(null, 2) :: Nil).toDF()
+    df.registerTempTable("decimalDataWithNulls")
+    df
+  }
+
   protected lazy val binaryData: DataFrame = {
     val df = sqlContext.sparkContext.parallelize(
       BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
@@ -267,6 +280,7 @@ private[sql] trait SQLTestData { self =>
     negativeData
     largeAndSmallInts
     decimalData
+    decimalDataWithNulls
     binaryData
     upperCaseData
     lowerCaseData
@@ -296,6 +310,7 @@ private[sql] object SQLTestData {
   case class TestData3(a: Int, b: Option[Int])
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
+  case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)
   case class UpperCaseData(N: Int, L: String)
   case class LowerCaseData(n: Int, l: String)