From cbdcd4edab48593f6331bc267eb94e40908733e5 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal <sameer@databricks.com>
Date: Sun, 24 Apr 2016 22:52:50 -0700
Subject: [PATCH] [SPARK-14870] [SQL] Fix NPE in TPCDS q14a

## What changes were proposed in this pull request?

This PR fixes a bug in `TungstenAggregate` that manifests while aggregating by keys over nullable `BigDecimal` columns. This causes a null pointer exception while executing TPCDS q14a.

## How was this patch tested?

1. Added regression test in `DataFrameAggregateSuite`.
2. Verified that TPCDS q14a works

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12651 from sameeragarwal/tpcds-fix.
---
 .../expressions/codegen/CodeGenerator.scala       |  7 +++++--
 .../execution/aggregate/TungstenAggregate.scala   |  3 ++-
 .../spark/sql/DataFrameAggregateSuite.scala       | 15 +++++++++++++++
 .../org/apache/spark/sql/test/SQLTestData.scala   | 15 +++++++++++++++
 4 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fa09f821fc..e4fa429b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -239,16 +239,19 @@ class CodegenContext {
 
   /**
    * Update a column in MutableRow from ExprCode.
+   *
+   * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
    */
   def updateColumn(
       row: String,
       dataType: DataType,
       ordinal: Int,
       ev: ExprCode,
-      nullable: Boolean): String = {
+      nullable: Boolean,
+      isVectorized: Boolean = false): String = {
     if (nullable) {
       // Can't call setNullAt on DecimalType, because we need to keep the offset
-      if (dataType.isInstanceOf[DecimalType]) {
+      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
         s"""
            if (!${ev.isNull}) {
              ${setColumn(row, dataType, ordinal, ev.value)};
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 782da0ea60..49db75e141 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -633,7 +633,8 @@ case class TungstenAggregate(
           updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
         val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
           val dt = updateExpr(i).dataType
-          ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
+          ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
+            isVectorized = true)
         }
         Option(
           s"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7d96ef6fe0..0fcfb97d2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -61,6 +61,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df1.groupBy("key").min("value2"),
       Seq(Row("a", 0), Row("b", 4))
     )
+
+    checkAnswer(
+      decimalData.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
+    )
+
+    checkAnswer(
+      decimalDataWithNulls.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)),
+        Row(null, new java.math.BigDecimal(2.0)))
+    )
   }
 
   test("rollup") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 7fa6760b71..c5f25fa1df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -103,6 +103,19 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val decimalDataWithNulls: DataFrame = {
+    val df = sqlContext.sparkContext.parallelize(
+      DecimalDataWithNulls(1, 1) ::
+      DecimalDataWithNulls(1, null) ::
+      DecimalDataWithNulls(2, 1) ::
+      DecimalDataWithNulls(2, null) ::
+      DecimalDataWithNulls(3, 1) ::
+      DecimalDataWithNulls(3, 2) ::
+      DecimalDataWithNulls(null, 2) :: Nil).toDF()
+    df.registerTempTable("decimalDataWithNulls")
+    df
+  }
+
   protected lazy val binaryData: DataFrame = {
     val df = sqlContext.sparkContext.parallelize(
       BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
@@ -267,6 +280,7 @@ private[sql] trait SQLTestData { self =>
     negativeData
     largeAndSmallInts
     decimalData
+    decimalDataWithNulls
     binaryData
     upperCaseData
     lowerCaseData
@@ -296,6 +310,7 @@ private[sql] object SQLTestData {
   case class TestData3(a: Int, b: Option[Int])
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
+  case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)
   case class UpperCaseData(N: Int, L: String)
   case class LowerCaseData(n: Int, l: String)
-- 
GitLab