From 1e43851d6455f65b850ea0327d0e92f65395d23f Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Thu, 16 Apr 2015 17:50:20 -0700
Subject: [PATCH] [SPARK-6899][SQL] Fix type mismatch when using codegen with
 Average on DecimalType

JIRA https://issues.apache.org/jira/browse/SPARK-6899

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #5517 from viirya/fix_codegen_average and squashes the following commits:

8ae5f65 [Liang-Chi Hsieh] Add the case of DecimalType.Unlimited to Average.
---
 .../spark/sql/catalyst/expressions/aggregates.scala      | 2 +-
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 14a855054b..f3830c6d3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -326,7 +326,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
 
   override def asPartial: SplitEvaluation = {
     child.dataType match {
-      case DecimalType.Fixed(_, _) =>
+      case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
         // Turn the child to unlimited decimals for calculation, before going back to fixed
         val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
         val partialCount = Alias(Count(child), "PartialCount")()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 34b2cb054a..44a7d1e7bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -537,4 +537,13 @@ class DataFrameSuite extends QueryTest {
     val df = TestSQLContext.createDataFrame(rowRDD, schema)
     df.rdd.collect()
   }
+
+  test("SPARK-6899") {
+    val originalValue = TestSQLContext.conf.codegenEnabled
+    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+    checkAnswer(
+      decimalData.agg(avg('a)),
+      Row(new java.math.BigDecimal(2.0)))
+    TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+  }
 }
-- 
GitLab