From 772e7c18fb1a79c0f080408cb43307fe89a4fa04 Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Mon, 17 Aug 2015 15:30:50 -0700
Subject: [PATCH] [SPARK-9592] [SQL] Fix Last function implemented based on
 AggregateExpression1.

https://issues.apache.org/jira/browse/SPARK-9592

#8113 has the fundamental fix. But, if we want to minimize the number of changed lines, we can go with this one. Then, in 1.6, we merge #8113.

Author: Yin Huai <yhuai@databricks.com>

Closes #8172 from yhuai/lastFix and squashes the following commits:

b28c42a [Yin Huai] Regression test.
af87086 [Yin Huai] Fix last.
---
 .../sql/catalyst/expressions/aggregates.scala     |  9 +++++++--
 .../hive/execution/AggregationQuerySuite.scala    | 15 +++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 2cf8312ea5..5e8298aaaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -650,6 +650,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
+    // We ignore null values.
     if (result == null) {
       result = expr.eval(input)
     }
@@ -679,10 +680,14 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
-    result = input
+    val value = expr.eval(input)
+    // We ignore null values.
+    if (value != null) {
+      result = value
+    }
   }
 
   override def eval(input: InternalRow): Any = {
-    if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null
+    result
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a312f84958..119663af18 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -480,6 +480,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
         Row(0, null, 1, 1, null, 0) :: Nil)
   }
 
+  test("test Last implemented based on AggregateExpression1") {
+    // TODO: Remove this test once we remove AggregateExpression1.
+    import org.apache.spark.sql.functions._
+    val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+      SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+
+      checkAnswer(
+        df.groupBy("i").agg(last("j")),
+        df
+      )
+    }
+  }
+
   test("error handling") {
     withSQLConf("spark.sql.useAggregate2" -> "false") {
       val errorMessage = intercept[AnalysisException] {
-- 
GitLab