From 3afe448d39dc4877b2f2c62b3059aeb3ced0bd96 Mon Sep 17 00:00:00 2001
From: Yin Huai <yhuai@databricks.com>
Date: Wed, 21 Oct 2015 13:43:17 -0700
Subject: [PATCH] [SPARK-9740][SPARK-9592][SPARK-9210][SQL] Change the default
 behavior of First/Last to RESPECT NULLS.

I am changing the default behavior of `First`/`Last` to respect null values (the SQL standard default behavior).

https://issues.apache.org/jira/browse/SPARK-9740

Author: Yin Huai <yhuai@databricks.com>

Closes #8113 from yhuai/firstLast.
---
 .../sql/catalyst/analysis/CheckAnalysis.scala |   3 +-
 .../catalyst/analysis/FunctionRegistry.scala  |   2 +
 .../expressions/aggregate/functions.scala     | 105 +++++++++++++++---
 .../expressions/aggregate/utils.scala         |   8 +-
 .../sql/catalyst/expressions/aggregates.scala |  95 ++++++++++++----
 .../spark/sql/expressions/WindowSpec.scala    |  13 ++-
 .../execution/AggregationQuerySuite.scala     |  38 +++++++
 7 files changed, 219 insertions(+), 45 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index ab215407f7..98d6637c06 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -113,7 +113,8 @@ trait CheckAnalysis {
                 failAnalysis(
                   s"expression '${e.prettyString}' is neither present in the group by, " +
                     s"nor is it an aggregate function. " +
-                    "Add to group by or wrap in first() if you don't care which value you get.")
+                    "Add to group by or wrap in first() (or first_value) if you don't care " +
+                    "which value you get.")
               case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
               case e if e.references.isEmpty => // OK
               case e => e.children.foreach(checkValidAggregateExpression)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ba77b70a37..f73b24e363 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -179,7 +179,9 @@ object FunctionRegistry {
     expression[Average]("avg"),
     expression[Count]("count"),
     expression[First]("first"),
+    expression[First]("first_value"),
     expression[Last]("last"),
+    expression[Last]("last_value"),
     expression[Max]("max"),
     expression[Min]("min"),
     expression[Stddev]("stddev"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index c0bc7ec09c..515246d344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -21,6 +21,8 @@ import java.lang.{Long => JLong}
 import java.util
 
 import com.clearspring.analytics.hash.MurmurHash
+
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
@@ -118,7 +120,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
   override val evaluateExpression = Cast(currentCount, LongType)
 }
 
-case class First(child: Expression) extends DeclarativeAggregate {
+/**
+ * Returns the first value of `child` for a group of rows. If the first value of `child`
+ * is `null`, it returns `null` (respecting nulls). Even if [[First]] is used on a already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpression
+ * is used) its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation.).
+ * @param child
+ */
+case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of First should be a boolean literal.")
+  }
 
   override def children: Seq[Expression] = child :: Nil
 
@@ -135,24 +153,61 @@ case class First(child: Expression) extends DeclarativeAggregate {
 
   private val first = AttributeReference("first", child.dataType)()
 
-  override val aggBufferAttributes = first :: Nil
+  private val valueSet = AttributeReference("valueSet", BooleanType)()
+
+  override val aggBufferAttributes = first :: valueSet :: Nil
 
   override val initialValues = Seq(
-    /* first = */ Literal.create(null, child.dataType)
+    /* first = */ Literal.create(null, child.dataType),
+    /* valueSet = */ Literal.create(false, BooleanType)
   )
 
-  override val updateExpressions = Seq(
-    /* first = */ If(IsNull(first), child, first)
-  )
+  override val updateExpressions = {
+    if (ignoreNulls) {
+      Seq(
+        /* first = */ If(Or(valueSet, IsNull(child)), first, child),
+        /* valueSet = */ Or(valueSet, IsNotNull(child))
+      )
+    } else {
+      Seq(
+        /* first = */ If(valueSet, first, child),
+        /* valueSet = */ Literal.create(true, BooleanType)
+      )
+    }
+  }
 
-  override val mergeExpressions = Seq(
-    /* first = */ If(IsNull(first.left), first.right, first.left)
-  )
+  override val mergeExpressions = {
+    // For first, we can just check if valueSet.left is set to true. If it is set
+    // to true, we use first.right. If not, we use first.right (even if valueSet.right is
+    // false, we are safe to do so because first.right will be null in this case).
+    Seq(
+      /* first = */ If(valueSet.left, first.left, first.right),
+      /* valueSet = */ Or(valueSet.left, valueSet.right)
+    )
+  }
 
   override val evaluateExpression = first
+
+  override def toString: String = s"FIRST($child)${if (ignoreNulls) " IGNORE NULLS"}"
 }
 
-case class Last(child: Expression) extends DeclarativeAggregate {
+/**
+ * Returns the last value of `child` for a group of rows. If the last value of `child`
+ * is `null`, it returns `null` (respecting nulls). Even if [[Last]] is used on a already
+ * sorted column, if we do partial aggregation and final aggregation (when mergeExpression
+ * is used) its result will not be deterministic (unless the input table is sorted and has
+ * a single partition, and we use a single reducer to do the aggregation.).
+ * @param child
+ */
+case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of First should be a boolean literal.")
+  }
 
   override def children: Seq[Expression] = child :: Nil
 
@@ -175,15 +230,33 @@ case class Last(child: Expression) extends DeclarativeAggregate {
     /* last = */ Literal.create(null, child.dataType)
   )
 
-  override val updateExpressions = Seq(
-    /* last = */ If(IsNull(child), last, child)
-  )
+  override val updateExpressions = {
+    if (ignoreNulls) {
+      Seq(
+        /* last = */ If(IsNull(child), last, child)
+      )
+    } else {
+      Seq(
+        /* last = */ child
+      )
+    }
+  }
 
-  override val mergeExpressions = Seq(
-    /* last = */ If(IsNull(last.right), last.left, last.right)
-  )
+  override val mergeExpressions = {
+    if (ignoreNulls) {
+      Seq(
+        /* last = */ If(IsNull(last.right), last.left, last.right)
+      )
+    } else {
+      Seq(
+        /* last = */ last.right
+      )
+    }
+  }
 
   override val evaluateExpression = last
+
+  override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}"
 }
 
 case class Max(child: Expression) extends DeclarativeAggregate {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index f656ccf13b..12bdab0915 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -61,15 +61,15 @@ object Utils {
             mode = aggregate.Complete,
             isDistinct = true)
 
-        case expressions.First(child) =>
+        case expressions.First(child, ignoreNulls) =>
           aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.First(child),
+            aggregateFunction = aggregate.First(child, ignoreNulls),
             mode = aggregate.Complete,
             isDistinct = false)
 
-        case expressions.Last(child) =>
+        case expressions.Last(child, ignoreNulls) =>
           aggregate.AggregateExpression2(
-            aggregateFunction = aggregate.Last(child),
+            aggregateFunction = aggregate.Last(child, ignoreNulls),
             mode = aggregate.Complete,
             isDistinct = false)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index f1c47f3904..95061c4635 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
@@ -630,59 +631,113 @@ case class CombineSetsAndSumFunction(
   }
 }
 
-case class First(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class First(
+    child: Expression,
+    ignoreNullsExpr: Expression)
+  extends UnaryExpression with PartialAggregate1 {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of First should be a boolean literal.")
+  }
+
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
-  override def toString: String = s"FIRST($child)"
+  override def toString: String = s"FIRST(${child}${if (ignoreNulls) " IGNORE NULLS"})"
 
   override def asPartial: SplitEvaluation = {
-    val partialFirst = Alias(First(child), "PartialFirst")()
+    val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")()
     SplitEvaluation(
-      First(partialFirst.toAttribute),
+      First(partialFirst.toAttribute, ignoreNulls),
       partialFirst :: Nil)
   }
-  override def newInstance(): FirstFunction = new FirstFunction(child, this)
+  override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this)
 }
 
-case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
+object First {
+  def apply(child: Expression): First = First(child, ignoreNulls = false)
 
-  var result: Any = null
+  def apply(child: Expression, ignoreNulls: Boolean): First =
+    First(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class FirstFunction(
+    expr: Expression,
+    ignoreNulls: Boolean,
+    base: AggregateExpression1)
+  extends AggregateFunction1 {
+
+  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
+
+  private[this] var result: Any = null
+
+  private[this] var valueSet: Boolean = false
 
   override def update(input: InternalRow): Unit = {
-    // We ignore null values.
-    if (result == null) {
-      result = expr.eval(input)
+    if (!valueSet) {
+      val value = expr.eval(input)
+      // When we have not set the result, we will set the result if we respect nulls
+      // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null.
+      if (!ignoreNulls || (ignoreNulls && value != null)) {
+        result = value
+        valueSet = true
+      }
     }
   }
 
   override def eval(input: InternalRow): Any = result
 }
 
-case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 {
+case class Last(
+    child: Expression,
+    ignoreNullsExpr: Expression)
+  extends UnaryExpression with PartialAggregate1 {
+
+  def this(child: Expression) = this(child, Literal.create(false, BooleanType))
+
+  private val ignoreNulls: Boolean = ignoreNullsExpr match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ =>
+      throw new AnalysisException("The second argument of First should be a boolean literal.")
+  }
+
   override def references: AttributeSet = child.references
   override def nullable: Boolean = true
   override def dataType: DataType = child.dataType
-  override def toString: String = s"LAST($child)"
+  override def toString: String = s"LAST($child)${if (ignoreNulls) " IGNORE NULLS"}"
 
   override def asPartial: SplitEvaluation = {
-    val partialLast = Alias(Last(child), "PartialLast")()
+    val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")()
     SplitEvaluation(
-      Last(partialLast.toAttribute),
+      Last(partialLast.toAttribute, ignoreNulls),
       partialLast :: Nil)
   }
-  override def newInstance(): LastFunction = new LastFunction(child, this)
+  override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this)
 }
 
-case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 {
-  def this() = this(null, null) // Required for serialization.
+object Last {
+  def apply(child: Expression): Last = Last(child, ignoreNulls = false)
+
+  def apply(child: Expression, ignoreNulls: Boolean): Last =
+    Last(child, Literal.create(ignoreNulls, BooleanType))
+}
+
+case class LastFunction(
+    expr: Expression,
+    ignoreNulls: Boolean,
+    base: AggregateExpression1)
+  extends AggregateFunction1 {
+
+  def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization.
 
   var result: Any = null
 
   override def update(input: InternalRow): Unit = {
     val value = expr.eval(input)
-    // We ignore null values.
-    if (value != null) {
+    if (!ignoreNulls || (ignoreNulls && value != null)) {
       result = value
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index c3d2246297..8b9247adea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.expressions
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.types.BooleanType
 import org.apache.spark.sql.{Column, catalyst}
 import org.apache.spark.sql.catalyst.expressions._
 
@@ -149,13 +150,17 @@ class WindowSpec private[sql](
       case Count(child) => WindowExpression(
         UnresolvedWindowFunction("count", child :: Nil),
         WindowSpecDefinition(partitionSpec, orderSpec, frame))
-      case First(child) => WindowExpression(
+      case First(child, ignoreNulls) => WindowExpression(
         // TODO this is a hack for Hive UDAF first_value
-        UnresolvedWindowFunction("first_value", child :: Nil),
+        UnresolvedWindowFunction(
+          "first_value",
+          child :: ignoreNulls :: Nil),
         WindowSpecDefinition(partitionSpec, orderSpec, frame))
-      case Last(child) => WindowExpression(
+      case Last(child, ignoreNulls) => WindowExpression(
         // TODO this is a hack for Hive UDAF last_value
-        UnresolvedWindowFunction("last_value", child :: Nil),
+        UnresolvedWindowFunction(
+          "last_value",
+          child :: ignoreNulls :: Nil),
         WindowSpecDefinition(partitionSpec, orderSpec, frame))
       case Min(child) => WindowExpression(
         UnresolvedWindowFunction("min", child :: Nil),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index c9e1bb1995..f38a3f63c3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -323,6 +323,44 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       Row(11.125) :: Nil)
   }
 
+  test("first_value and last_value") {
+    // We force to use a single partition for the sort and aggregate to make result
+    // deterministic.
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT
+            |  first_valUE(key),
+            |  lasT_value(key),
+            |  firSt(key),
+            |  lASt(key),
+            |  first_valUE(key, true),
+            |  lasT_value(key, true),
+            |  firSt(key, true),
+            |  lASt(key, true)
+            |FROM (SELECT key FROM agg1 ORDER BY key) tmp
+          """.stripMargin),
+        Row(null, 3, null, 3, 1, 3, 1, 3) :: Nil)
+
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT
+            |  first_valUE(key),
+            |  lasT_value(key),
+            |  firSt(key),
+            |  lASt(key),
+            |  first_valUE(key, true),
+            |  lasT_value(key, true),
+            |  firSt(key, true),
+            |  lASt(key, true)
+            |FROM (SELECT key FROM agg1 ORDER BY key DESC) tmp
+          """.stripMargin),
+        Row(3, null, 3, null, 3, 1, 3, 1) :: Nil)
+    }
+  }
+
   test("udaf") {
     checkAnswer(
       sqlContext.sql(
-- 
GitLab