From de333d121da4cb80d45819cbcf8b4246e48ec4d0 Mon Sep 17 00:00:00 2001
From: xin wu <xinwu@us.ibm.com>
Date: Sun, 25 Sep 2016 16:46:12 -0700
Subject: [PATCH] [SPARK-17551][SQL] Add DataFrame API for null ordering

## What changes were proposed in this pull request?
This pull request adds Scala/Java DataFrame API for null ordering (NULLS FIRST | LAST).

It also does some minor cleanup of related code (e.g. fixing incorrect indentation) and renames "orderby-nulls-ordering.sql" to "order-by-nulls-ordering.sql" to be consistent with existing test files.

## How was this patch tested?
Added a new test case in DataFrameSuite.

Author: petermaxlee <petermaxlee@gmail.com>
Author: Xin Wu <xinwu@us.ibm.com>

Closes #15123 from petermaxlee/SPARK-17551.
---
 .../sql/catalyst/expressions/SortOrder.scala  | 28 ++------
 .../codegen/GenerateOrdering.scala            | 16 ++---
 .../scala/org/apache/spark/sql/Column.scala   | 64 ++++++++++++++++++-
 .../org/apache/spark/sql/functions.scala      | 51 ++++++++++++++-
 ...dering.sql => order-by-nulls-ordering.sql} |  0
 ...ql.out => order-by-nulls-ordering.sql.out} |  0
 .../org/apache/spark/sql/DataFrameSuite.scala | 18 ++++++
 7 files changed, 144 insertions(+), 33 deletions(-)
 rename sql/core/src/test/resources/sql-tests/inputs/{orderby-nulls-ordering.sql => order-by-nulls-ordering.sql} (100%)
 rename sql/core/src/test/resources/sql-tests/results/{orderby-nulls-ordering.sql.out => order-by-nulls-ordering.sql.out} (100%)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index d015125bac..3bebd552ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -54,10 +54,7 @@ case object NullsLast extends NullOrdering{
  * An expression that can be used to sort a tuple.  This class extends expression primarily so that
  * transformations over expression will descend into its child.
  */
-case class SortOrder(
-  child: Expression,
-  direction: SortDirection,
-  nullOrdering: NullOrdering)
+case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
   extends UnaryExpression with Unevaluable {
 
   /** Sort order is not foldable because we don't have an eval for it. */
@@ -94,17 +91,9 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
 
   val nullValue = child.child.dataType match {
     case BooleanType | DateType | TimestampType | _: IntegralType =>
-      if (nullAsSmallest) {
-        Long.MinValue
-      } else {
-        Long.MaxValue
-      }
+      if (nullAsSmallest) Long.MinValue else Long.MaxValue
     case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
-      if (nullAsSmallest) {
-        Long.MinValue
-      } else {
-        Long.MaxValue
-      }
+      if (nullAsSmallest) Long.MinValue else Long.MaxValue
     case _: DecimalType =>
       if (nullAsSmallest) {
         DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
@@ -112,16 +101,13 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
         DoublePrefixComparator.computePrefix(Double.NaN)
       }
     case _ =>
-      if (nullAsSmallest) {
-        0L
-      } else {
-        -1L
-      }
+      if (nullAsSmallest) 0L else -1L
   }
 
-  private def nullAsSmallest: Boolean = (child.isAscending && child.nullOrdering == NullsFirst) ||
+  private def nullAsSmallest: Boolean = {
+    (child.isAscending && child.nullOrdering == NullsFirst) ||
       (!child.isAscending && child.nullOrdering == NullsLast)
-
+  }
 
   override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index e7df95e114..f1c30ef6c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -100,16 +100,16 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
             // Nothing
           } else if ($isNullA) {
             return ${
-        order.nullOrdering match {
-          case NullsFirst => "-1"
-          case NullsLast => "1"
-        }};
+              order.nullOrdering match {
+                case NullsFirst => "-1"
+                case NullsLast => "1"
+              }};
           } else if ($isNullB) {
             return ${
-        order.nullOrdering match {
-          case NullsFirst => "1"
-          case NullsLast => "-1"
-        }};
+              order.nullOrdering match {
+                case NullsFirst => "1"
+                case NullsLast => "-1"
+              }};
           } else {
             int comp = ${ctx.genComp(order.child.dataType, primitiveA, primitiveB)};
             if (comp != 0) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 844ca7a8e9..63da501f18 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -1007,7 +1007,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   /**
    * Returns an ordering used in sorting.
    * {{{
-   *   // Scala: sort a DataFrame by age column in descending order.
+   *   // Scala
    *   df.sort(df("age").desc)
    *
    *   // Java
@@ -1020,7 +1020,37 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def desc: Column = withExpr { SortOrder(expr, Descending) }
 
   /**
-   * Returns an ordering used in sorting.
+   * Returns a descending ordering used in sorting, where null values appear before non-null values.
+   * {{{
+   *   // Scala: sort a DataFrame by age column in descending order and null values appearing first.
+   *   df.sort(df("age").desc_nulls_first)
+   *
+   *   // Java
+   *   df.sort(df.col("age").desc_nulls_first());
+   * }}}
+   *
+   * @group expr_ops
+   * @since 2.1.0
+   */
+  def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }
+
+  /**
+   * Returns a descending ordering used in sorting, where null values appear after non-null values.
+   * {{{
+   *   // Scala: sort a DataFrame by age column in descending order and null values appearing last.
+   *   df.sort(df("age").desc_nulls_last)
+   *
+   *   // Java
+   *   df.sort(df.col("age").desc_nulls_last());
+   * }}}
+   *
+   * @group expr_ops
+   * @since 2.1.0
+   */
+  def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) }
+
+  /**
+   * Returns an ascending ordering used in sorting.
    * {{{
    *   // Scala: sort a DataFrame by age column in ascending order.
    *   df.sort(df("age").asc)
@@ -1034,6 +1064,36 @@ class Column(protected[sql] val expr: Expression) extends Logging {
    */
   def asc: Column = withExpr { SortOrder(expr, Ascending) }
 
+  /**
+   * Returns an ascending ordering used in sorting, where null values appear before non-null values.
+   * {{{
+   *   // Scala: sort a DataFrame by age column in ascending order and null values appearing first.
+   *   df.sort(df("age").asc_nulls_first)
+   *
+   *   // Java
+   *   df.sort(df.col("age").asc_nulls_first());
+   * }}}
+   *
+   * @group expr_ops
+   * @since 2.1.0
+   */
+  def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) }
+
+  /**
+   * Returns an ascending ordering used in sorting, where null values appear after non-null values.
+   * {{{
+   *   // Scala: sort a DataFrame by age column in ascending order and null values appearing last.
+   *   df.sort(df("age").asc_nulls_last)
+   *
+   *   // Java
+   *   df.sort(df.col("age").asc_nulls_last());
+   * }}}
+   *
+   * @group expr_ops
+   * @since 2.1.0
+   */
+  def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }
+
   /**
    * Prints the expression to the console for debugging purpose.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 960c87f60e..47bf41a2da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -109,7 +109,6 @@ object functions {
   /**
    * Returns a sort expression based on ascending order of the column.
    * {{{
-   *   // Sort by dept in ascending order, and then age in descending order.
    *   df.sort(asc("dept"), desc("age"))
    * }}}
    *
@@ -118,10 +117,33 @@ object functions {
    */
   def asc(columnName: String): Column = Column(columnName).asc
 
+  /**
+   * Returns a sort expression based on ascending order of the column,
+   * and null values appear before non-null values.
+   * {{{
+   *   df.sort(asc_nulls_first("dept"), desc("age"))
+   * }}}
+   *
+   * @group sort_funcs
+   * @since 2.1.0
+   */
+  def asc_nulls_first(columnName: String): Column = Column(columnName).asc_nulls_first
+
+  /**
+   * Returns a sort expression based on ascending order of the column,
+   * and null values appear after non-null values.
+   * {{{
+   *   df.sort(asc_nulls_last("dept"), desc("age"))
+   * }}}
+   *
+   * @group sort_funcs
+   * @since 2.1.0
+   */
+  def asc_nulls_last(columnName: String): Column = Column(columnName).asc_nulls_last
+
   /**
    * Returns a sort expression based on the descending order of the column.
    * {{{
-   *   // Sort by dept in ascending order, and then age in descending order.
    *   df.sort(asc("dept"), desc("age"))
    * }}}
    *
@@ -130,6 +152,31 @@ object functions {
    */
   def desc(columnName: String): Column = Column(columnName).desc
 
+  /**
+   * Returns a sort expression based on the descending order of the column,
+   * and null values appear before non-null values.
+   * {{{
+   *   df.sort(asc("dept"), desc_nulls_first("age"))
+   * }}}
+   *
+   * @group sort_funcs
+   * @since 2.1.0
+   */
+  def desc_nulls_first(columnName: String): Column = Column(columnName).desc_nulls_first
+
+  /**
+   * Returns a sort expression based on the descending order of the column,
+   * and null values appear after non-null values.
+   * {{{
+   *   df.sort(asc("dept"), desc_nulls_last("age"))
+   * }}}
+   *
+   * @group sort_funcs
+   * @since 2.1.0
+   */
+  def desc_nulls_last(columnName: String): Column = Column(columnName).desc_nulls_last
+
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Aggregate functions
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql b/sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql
similarity index 100%
rename from sql/core/src/test/resources/sql-tests/inputs/orderby-nulls-ordering.sql
rename to sql/core/src/test/resources/sql-tests/inputs/order-by-nulls-ordering.sql
diff --git a/sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out b/sql/core/src/test/resources/sql-tests/results/order-by-nulls-ordering.sql.out
similarity index 100%
rename from sql/core/src/test/resources/sql-tests/results/orderby-nulls-ordering.sql.out
rename to sql/core/src/test/resources/sql-tests/results/order-by-nulls-ordering.sql.out
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2c60a7dd92..16cc368208 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -326,6 +326,24 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(6))
   }
 
+  test("sorting with null ordering") {
+    val data = Seq[java.lang.Integer](2, 1, null).toDF("key")
+
+    checkAnswer(data.orderBy('key.asc), Row(null) :: Row(1) :: Row(2) :: Nil)
+    checkAnswer(data.orderBy(asc("key")), Row(null) :: Row(1) :: Row(2) :: Nil)
+    checkAnswer(data.orderBy('key.asc_nulls_first), Row(null) :: Row(1) :: Row(2) :: Nil)
+    checkAnswer(data.orderBy(asc_nulls_first("key")), Row(null) :: Row(1) :: Row(2) :: Nil)
+    checkAnswer(data.orderBy('key.asc_nulls_last), Row(1) :: Row(2) :: Row(null) :: Nil)
+    checkAnswer(data.orderBy(asc_nulls_last("key")), Row(1) :: Row(2) :: Row(null) :: Nil)
+
+    checkAnswer(data.orderBy('key.desc), Row(2) :: Row(1) :: Row(null) :: Nil)
+    checkAnswer(data.orderBy(desc("key")), Row(2) :: Row(1) :: Row(null) :: Nil)
+    checkAnswer(data.orderBy('key.desc_nulls_first), Row(null) :: Row(2) :: Row(1) :: Nil)
+    checkAnswer(data.orderBy(desc_nulls_first("key")), Row(null) :: Row(2) :: Row(1) :: Nil)
+    checkAnswer(data.orderBy('key.desc_nulls_last), Row(2) :: Row(1) :: Row(null) :: Nil)
+    checkAnswer(data.orderBy(desc_nulls_last("key")), Row(2) :: Row(1) :: Row(null) :: Nil)
+  }
+
   test("global sorting") {
     checkAnswer(
       testData2.orderBy('a.asc, 'b.asc),
-- 
GitLab