diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fa09f821fc9977587318cde8910e7aca117b5ad3..e4fa429b37546164d598af1bc4fb5bf3dff8a8d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -239,16 +239,19 @@ class CodegenContext {
   /**
    * Update a column in MutableRow from ExprCode.
+   *
+   * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
    */
   def updateColumn(
       row: String,
       dataType: DataType,
       ordinal: Int,
       ev: ExprCode,
-      nullable: Boolean): String = {
+      nullable: Boolean,
+      isVectorized: Boolean = false): String = {
     if (nullable) {
       // Can't call setNullAt on DecimalType, because we need to keep the offset
-      if (dataType.isInstanceOf[DecimalType]) {
+      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
         s"""
           if (!${ev.isNull}) {
             ${setColumn(row, dataType, ordinal, ev.value)};
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 782da0ea604d335c5acc45b7a29a333dbc70c79e..49db75e141e90aeb3628118203c9de8a028bb918 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -633,7 +633,8 @@ case class TungstenAggregate(
           updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
         val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
           val dt = updateExpr(i).dataType
-          ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
+          ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
+            isVectorized = true)
         }
         Option(
           s"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7d96ef6fe0a10e9c29f0347821b9e9aaf216ac4d..0fcfb97d2bd905944810c1aef277bcf0097d2d00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -61,6 +61,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df1.groupBy("key").min("value2"),
       Seq(Row("a", 0), Row("b", 4))
     )
+
+    checkAnswer(
+      decimalData.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
+    )
+
+    checkAnswer(
+      decimalDataWithNulls.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)),
+        Row(null, new java.math.BigDecimal(2.0)))
+    )
   }
 
   test("rollup") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 7fa6760b71c8bd26c9934020caf1a7459813512a..c5f25fa1df3b190f282a8a11a4d959e5b9e63d6b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -103,6 +103,19 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val decimalDataWithNulls: DataFrame = {
+    val df = sqlContext.sparkContext.parallelize(
+      DecimalDataWithNulls(1, 1) ::
+      DecimalDataWithNulls(1, null) ::
+      DecimalDataWithNulls(2, 1) ::
+      DecimalDataWithNulls(2, null) ::
+      DecimalDataWithNulls(3, 1) ::
+      DecimalDataWithNulls(3, 2) ::
+      DecimalDataWithNulls(null, 2) :: Nil).toDF()
+    df.registerTempTable("decimalDataWithNulls")
+    df
+  }
+
   protected lazy val binaryData: DataFrame = {
     val df = sqlContext.sparkContext.parallelize(
       BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
@@ -267,6 +280,7 @@ private[sql] trait SQLTestData { self =>
     negativeData
     largeAndSmallInts
     decimalData
+    decimalDataWithNulls
     binaryData
     upperCaseData
     lowerCaseData
@@ -296,6 +310,7 @@ private[sql] object SQLTestData {
   case class TestData3(a: Int, b: Option[Int])
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
+  case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)
   case class UpperCaseData(N: Int, L: String)
   case class LowerCaseData(n: Int, l: String)
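
For illustration only, a minimal sketch of the query shape these changes cover. It assumes a Spark 2.0-era SQLContext bound to `sqlContext` and the `decimalDataWithNulls` temp table registered as in SQLTestData above; with the vectorized hash-map aggregation path enabled, grouping and summing nullable decimal columns goes through the new `isVectorized = true` branch of `updateColumn`.

import org.apache.spark.sql.functions.sum

// Sketch only: relies on SQLTestData having registered the
// decimalDataWithNulls temp table on this sqlContext (see above).
val grouped = sqlContext.table("decimalDataWithNulls")
  .groupBy("a")
  .agg(sum("b"))

// Null decimal keys and null decimal inputs are written to the
// ColumnarBatch.Row aggregation buffer, which does not need the
// offset-preserving DecimalType code path used for mutable rows.
grouped.collect().foreach(println)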