diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 2cf8312ea59aa27a31a14b28256dd956752738d1..5e8298aaaa9cb08173a60855363afba12cc6c0a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -650,6 +650,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
+    // We ignore null values.
     if (result == null) {
       result = expr.eval(input)
     }
@@ -679,10 +680,14 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
-    result = input
+    val value = expr.eval(input)
+    // We ignore null values.
+    if (value != null) {
+      result = value
+    }
   }
 
   override def eval(input: InternalRow): Any = {
-    if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null
+    result
   }
 }
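
For context, here is a minimal standalone sketch of the semantics the FirstFunction/LastFunction hunks above implement: the child expression is evaluated eagerly on every input row and null results are skipped, instead of LastFunction storing a reference to the (reused) input row and re-evaluating it later in eval. The object and helper names below are hypothetical illustrations, not part of the patch:

// Hypothetical, simplified model of the fixed Last null handling (not Spark code).
object LastSketch {
  // Stand-in for a row and for expr.eval(input): pull column "j" out of a map-backed row.
  type Row = Map[String, Any]
  def evalJ(row: Row): Any = row.getOrElse("j", null)

  // Mirrors the new LastFunction.update/eval: keep the latest non-null value.
  def lastNonNull(rows: Seq[Row]): Any = {
    var result: Any = null
    for (row <- rows) {
      val value = evalJ(row)   // evaluate eagerly, like the new update()
      if (value != null) {     // ignore null values, matching the comment in the diff
        result = value
      }
    }
    result                     // eval() now simply returns the captured value
  }

  def main(args: Array[String]): Unit = {
    val rows: Seq[Row] = Seq(Map("j" -> 1), Map("j" -> null), Map("j" -> 3))
    println(lastNonNull(rows)) // prints 3: the last non-null value wins
  }
}
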
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a312f849582485e4fb420aab5af7816a24dc8181..119663af1887a2a274384b5a28e769868729e4bf 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -480,6 +480,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
         Row(0, null, 1, 1, null, 0) :: Nil)
   }
 
+  test("test Last implemented based on AggregateExpression1") {
+    // TODO: Remove this test once we remove AggregateExpression1.
+    import org.apache.spark.sql.functions._
+    val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+      SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+
+      checkAnswer(
+        df.groupBy("i").agg(last("j")),
+        df
+      )
+    }
+  }
+
   test("error handling") {
     withSQLConf("spark.sql.useAggregate2" -> "false") {
       val errorMessage = intercept[AnalysisException] {
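
A hedged note on the new test: repartition(1) together with one shuffle partition keeps all rows in a single partition, so the value picked by last is deterministic, and disabling the new aggregation path routes the query through the AggregateExpression1-based LastFunction fixed above; since each i occurs exactly once, last("j") per group must simply reproduce the original rows. To poke at the same path interactively, something like the following should work (a sketch assuming a Spark 1.5-era sqlContext; only the spark.sql.useAggregate2 key is taken verbatim from the existing "error handling" test above, the shuffle-partitions key is assumed):

// Sketch only: exercise the old AggregateExpression1 code path by hand.
import org.apache.spark.sql.functions.last
sqlContext.setConf("spark.sql.shuffle.partitions", "1")  // assumed literal key for SQLConf.SHUFFLE_PARTITIONS
sqlContext.setConf("spark.sql.useAggregate2", "false")   // key shown in the "error handling" test
val df = sqlContext.createDataFrame(Seq((1, 1), (2, 2), (3, 3))).toDF("i", "j").repartition(1)
df.groupBy("i").agg(last("j")).show()                    // expect each group's own j back
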