From 5a8b978fabb60aa178274f86432c63680c8b351a Mon Sep 17 00:00:00 2001
From: Herman van Hovell <hvanhovell@questtec.nl>
Date: Sun, 31 Jan 2016 13:56:13 -0800
Subject: [PATCH] [SPARK-13049] Add First/last with ignore nulls to
 functions.scala

This PR adds the ability to specify the ```ignoreNulls``` option to the functions dsl, e.g:
```df.select($"id", last($"value", ignoreNulls = true).over(Window.partitionBy($"id").orderBy($"other"))```

This PR is some where between a bug fix (see the JIRA) and a new feature. I am not sure if we should backport to 1.6.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #10957 from hvanhovell/SPARK-13049.
---
 python/pyspark/sql/functions.py               |  26 +++-
 python/pyspark/sql/tests.py                   |  10 ++
 .../org/apache/spark/sql/functions.scala      | 118 ++++++++++++++----
 .../spark/sql/DataFrameWindowSuite.scala      |  32 +++++
 4 files changed, 157 insertions(+), 29 deletions(-)

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 719eca8f55..0d57085267 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -81,8 +81,6 @@ _functions = {
 
     'max': 'Aggregate function: returns the maximum value of the expression in a group.',
     'min': 'Aggregate function: returns the minimum value of the expression in a group.',
-    'first': 'Aggregate function: returns the first value in a group.',
-    'last': 'Aggregate function: returns the last value in a group.',
     'count': 'Aggregate function: returns the number of items in a group.',
     'sum': 'Aggregate function: returns the sum of all values in the expression.',
     'avg': 'Aggregate function: returns the average of the values in a group.',
@@ -278,6 +276,18 @@ def countDistinct(col, *cols):
     return Column(jc)
 
 
+@since(1.3)
+def first(col, ignorenulls=False):
+    """Aggregate function: returns the first value in a group.
+
+    The function by default returns the first values it sees. It will return the first non-null
+    value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
+    return Column(jc)
+
+
 @since(1.6)
 def input_file_name():
     """Creates a string column for the file name of the current Spark task.
@@ -310,6 +320,18 @@ def isnull(col):
     return Column(sc._jvm.functions.isnull(_to_java_column(col)))
 
 
+@since(1.3)
+def last(col, ignorenulls=False):
+    """Aggregate function: returns the last value in a group.
+
+    The function by default returns the last values it sees. It will return the last non-null
+    value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
+    return Column(jc)
+
+
 @since(1.6)
 def monotonically_increasing_id():
     """A column that generates monotonically increasing 64-bit integers.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 410efbafe0..e30aa0a796 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -641,6 +641,16 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_first_last_ignorenulls(self):
+        from pyspark.sql import functions
+        df = self.sqlCtx.range(0, 100)
+        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
+        df3 = df2.select(functions.first(df2.id, False).alias('a'),
+                         functions.first(df2.id, True).alias('b'),
+                         functions.last(df2.id, False).alias('c'),
+                         functions.last(df2.id, True).alias('d'))
+        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+
     def test_corr(self):
         import math
         df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 3a27466176..b970eee4e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -349,19 +349,51 @@ object functions extends LegacyFunctions {
   }
 
   /**
-   * Aggregate function: returns the first value in a group.
-   *
-   * @group agg_funcs
-   * @since 1.3.0
-   */
-  def first(e: Column): Column = withAggregateFunction { new First(e.expr) }
-
-  /**
-   * Aggregate function: returns the first value of a column in a group.
-   *
-   * @group agg_funcs
-   * @since 1.3.0
-   */
+    * Aggregate function: returns the first value in a group.
+    *
+    * The function by default returns the first values it sees. It will return the first non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
+    new First(e.expr, Literal(ignoreNulls))
+  }
+
+  /**
+    * Aggregate function: returns the first value of a column in a group.
+    *
+    * The function by default returns the first values it sees. It will return the first non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def first(columnName: String, ignoreNulls: Boolean): Column = {
+    first(Column(columnName), ignoreNulls)
+  }
+
+  /**
+    * Aggregate function: returns the first value in a group.
+    *
+    * The function by default returns the first values it sees. It will return the first non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 1.3.0
+    */
+  def first(e: Column): Column = first(e, ignoreNulls = false)
+
+  /**
+    * Aggregate function: returns the first value of a column in a group.
+    *
+    * The function by default returns the first values it sees. It will return the first non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 1.3.0
+    */
   def first(columnName: String): Column = first(Column(columnName))
 
   /**
@@ -381,20 +413,52 @@ object functions extends LegacyFunctions {
   def kurtosis(columnName: String): Column = kurtosis(Column(columnName))
 
   /**
-   * Aggregate function: returns the last value in a group.
-   *
-   * @group agg_funcs
-   * @since 1.3.0
-   */
-  def last(e: Column): Column = withAggregateFunction { new Last(e.expr) }
-
-  /**
-   * Aggregate function: returns the last value of the column in a group.
-   *
-   * @group agg_funcs
-   * @since 1.3.0
-   */
-  def last(columnName: String): Column = last(Column(columnName))
+    * Aggregate function: returns the last value in a group.
+    *
+    * The function by default returns the last values it sees. It will return the last non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
+    new Last(e.expr, Literal(ignoreNulls))
+  }
+
+  /**
+    * Aggregate function: returns the last value of the column in a group.
+    *
+    * The function by default returns the last values it sees. It will return the last non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 2.0.0
+    */
+  def last(columnName: String, ignoreNulls: Boolean): Column = {
+    last(Column(columnName), ignoreNulls)
+  }
+
+  /**
+    * Aggregate function: returns the last value in a group.
+    *
+    * The function by default returns the last values it sees. It will return the last non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 1.3.0
+    */
+  def last(e: Column): Column = last(e, ignoreNulls = false)
+
+  /**
+    * Aggregate function: returns the last value of the column in a group.
+    *
+    * The function by default returns the last values it sees. It will return the last non-null
+    * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+    *
+    * @group agg_funcs
+    * @since 1.3.0
+    */
+  def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false)
 
   /**
    * Aggregate function: returns the maximum value of the expression in a group.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
index 09a56f6f3a..d38842c3c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
@@ -312,4 +312,36 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
         Row("b", 3, null, null),
         Row("b", 2, null, null)))
   }
+
+  test("last/first with ignoreNulls") {
+    val nullStr: String = null
+    val df = Seq(
+      ("a", 0, nullStr),
+      ("a", 1, "x"),
+      ("a", 2, "y"),
+      ("a", 3, "z"),
+      ("a", 4, nullStr),
+      ("b", 1, nullStr),
+      ("b", 2, nullStr)).
+      toDF("key", "order", "value")
+    val window = Window.partitionBy($"key").orderBy($"order")
+    checkAnswer(
+      df.select(
+        $"key",
+        $"order",
+        first($"value").over(window),
+        first($"value", ignoreNulls = false).over(window),
+        first($"value", ignoreNulls = true).over(window),
+        last($"value").over(window),
+        last($"value", ignoreNulls = false).over(window),
+        last($"value", ignoreNulls = true).over(window)),
+      Seq(
+        Row("a", 0, null, null, null, null, null, null),
+        Row("a", 1, null, null, "x", "x", "x", "x"),
+        Row("a", 2, null, null, "x", "y", "y", "y"),
+        Row("a", 3, null, null, "x", "z", "z", "z"),
+        Row("a", 4, null, null, "x", null, null, "z"),
+        Row("b", 1, null, null, null, null, null, null),
+        Row("b", 2, null, null, null, null, null, null)))
+  }
 }
-- 
GitLab