From 772e7c18fb1a79c0f080408cb43307fe89a4fa04 Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Mon, 17 Aug 2015 15:30:50 -0700
Subject: [PATCH] [SPARK-9592] [SQL] Fix Last function implemented based on AggregateExpression1.

https://issues.apache.org/jira/browse/SPARK-9592

#8113 has the fundamental fix. But, if we want to minimize the number of changed lines, we can go with this one. Then, in 1.6, we merge #8113.

Author: Yin Huai <yhuai@databricks.com>

Closes #8172 from yhuai/lastFix and squashes the following commits:

b28c42a [Yin Huai] Regression test.
af87086 [Yin Huai] Fix last.
---
 .../sql/catalyst/expressions/aggregates.scala      |  9 +++++++--
 .../hive/execution/AggregationQuerySuite.scala     | 15 +++++++++++++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 2cf8312ea5..5e8298aaaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -650,6 +650,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression1) extends A
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
+    // We ignore null values.
     if (result == null) {
       result = expr.eval(input)
     }
@@ -679,10 +680,14 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
-    result = input
+    val value = expr.eval(input)
+    // We ignore null values.
+    if (value != null) {
+      result = value
+    }
   }
 
   override def eval(input: InternalRow): Any = {
-    if (result != null) expr.eval(result.asInstanceOf[InternalRow]) else null
+    result
   }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a312f84958..119663af18 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -480,6 +480,21 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
       Row(0, null, 1, 1, null, 0) :: Nil)
   }
 
+  test("test Last implemented based on AggregateExpression1") {
+    // TODO: Remove this test once we remove AggregateExpression1.
+    import org.apache.spark.sql.functions._
+    val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+      SQLConf.USE_SQL_AGGREGATE2.key -> "false") {
+
+      checkAnswer(
+        df.groupBy("i").agg(last("j")),
+        df
+      )
+    }
+  }
+
   test("error handling") {
     withSQLConf("spark.sql.useAggregate2" -> "false") {
       val errorMessage = intercept[AnalysisException] {
--
GitLab
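
Note (not part of the patch): a minimal sketch of the behavior this change restores, written against the 1.5-era DataFrame API. It assumes an existing SQLContext named `sqlContext`; the data and column names mirror the regression test above. With the fix, last("j") returns the last non-null value of "j" seen per group, instead of re-evaluating the expression against a stored (and possibly reused) input row.

// Sketch only: assumes a running 1.5-era Spark application with `sqlContext` in scope.
import sqlContext.implicits._
import org.apache.spark.sql.functions.last

// One row per group, all in a single partition, as in the regression test.
val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1)

// Each group "i" contains exactly one value of "j", so last("j") should return that value,
// making the aggregated result equal to the original data frame.
df.groupBy("i").agg(last("j")).show()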