From 14869ae64eb27830179d4954a5dc3e0a1e1330b4 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun <dongjoon@apache.org>
Date: Tue, 19 Apr 2016 22:28:11 -0700
Subject: [PATCH] [SPARK-14639] [PYTHON] [R] Add `bround` function in Python/R.

## What changes were proposed in this pull request?

This issue aims to expose the Scala `bround` function in the Python/R APIs.
The `bround` function was implemented in SPARK-14614 by extending the existing `round` function.
It follows the semantics of Hive's implementation:
```java
public static double bround(double input, int scale) {
  if (Double.isNaN(input) || Double.isInfinite(input)) {
    return input;
  }
  return BigDecimal.valueOf(input).setScale(scale, RoundingMode.HALF_EVEN).doubleValue();
}
```
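
As a quick illustration of HALF_EVEN (bankers') rounding, here is a minimal sketch using Python's standard `decimal` module, which exposes the same rounding mode: ties are rounded to the nearest even digit.
```python
from decimal import Decimal, ROUND_HALF_EVEN

# HALF_EVEN: ties round to the nearest even digit, so 2.5 and 3.5 both land on even integers.
for v in ('0.5', '1.5', '2.5', '3.5'):
    print(v, '->', Decimal(v).quantize(Decimal('1'), rounding=ROUND_HALF_EVEN))
# 0.5 -> 0, 1.5 -> 2, 2.5 -> 2, 3.5 -> 4
```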

After this PR, `pyspark` and `sparkR` also support the `bround` function.

**PySpark**
```python
>>> from pyspark.sql.functions import bround
>>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
[Row(r=2.0)]
```
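
For comparison, the existing `round` function uses HALF_UP, so the same input rounds up; the doctest updated in this patch shows the difference:
```python
>>> from pyspark.sql.functions import round
>>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
[Row(r=3.0)]
```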

**SparkR**
```r
> df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5)))
> head(collect(select(df, bround(df$x, 0))))
  bround(x, 0)
1            2
2            4
```
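
Both APIs also accept a negative `scale`, which rounds at the integral part. The following PySpark sketch is untested here and assumes HALF_EVEN applies the same way at the tens digit:
```python
>>> sqlContext.createDataFrame([(25.0,)], ['a']).select(bround('a', -1).alias('r')).collect()
[Row(r=20.0)]
```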

## How was this patch tested?

Pass the Jenkins tests (including new test cases).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12509 from dongjoon-hyun/SPARK-14639.
---
 R/pkg/NAMESPACE                           |  1 +
 R/pkg/R/functions.R                       | 22 +++++++++++++++++++++-
 R/pkg/R/generics.R                        |  4 ++++
 R/pkg/inst/tests/testthat/test_sparkSQL.R |  5 +++++
 python/pyspark/sql/functions.py           | 19 ++++++++++++++++---
 5 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 10b9d16279..667fff7192 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -126,6 +126,7 @@ exportMethods("%in%",
               "between",
               "bin",
               "bitwiseNOT",
+              "bround",
               "cast",
               "cbrt",
               "ceil",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index db877b2d63..54234b0455 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -994,7 +994,7 @@ setMethod("rint",
 
 #' round
 #'
-#' Returns the value of the column `e` rounded to 0 decimal places.
+#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode.
 #'
 #' @rdname round
 #' @name round
@@ -1008,6 +1008,26 @@ setMethod("round",
             column(jc)
           })
 
+#' bround
+#'
+#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding
+#' mode if `scale` >= 0 or at integral part when `scale` < 0.
+#' Also known as Gaussian rounding or bankers' rounding, which rounds ties to the nearest even
+#' number. For example, bround(2.5, 0) = 2 and bround(3.5, 0) = 4.
+#'
+#' @rdname bround
+#' @name bround
+#' @family math_funcs
+#' @export
+#' @examples \dontrun{bround(df$c, 0)}
+setMethod("bround",
+          signature(x = "Column"),
+          function(x, scale = 0) {
+            jc <- callJStatic("org.apache.spark.sql.functions", "bround", x@jc, as.integer(scale))
+            column(jc)
+          })
+
+
 #' rtrim
 #'
 #' Trim the spaces from right end for the specified string value.
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index a71be55bca..6b67258d77 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -760,6 +760,10 @@ setGeneric("bin", function(x) { standardGeneric("bin") })
 #' @export
 setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") })
 
+#' @rdname bround
+#' @export
+setGeneric("bround", function(x, ...) { standardGeneric("bround") })
+
 #' @rdname cbrt
 #' @export
 setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 2f65484fcb..b923ccf6bb 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1087,6 +1087,11 @@ test_that("column functions", {
   expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19)
   expect_equal(collect(select(df, last("age")))[[1]], 19)
   expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19)
+
+  # Test bround()
+  df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5)))
+  expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2)
+  expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4)
 })
 
 test_that("column binary mathfunctions", {
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 5017ab5b36..dac842c0ce 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -467,16 +467,29 @@ def randn(seed=None):
 @since(1.5)
 def round(col, scale=0):
     """
-    Round the value of `e` to `scale` decimal places if `scale` >= 0
+    Round the given value to `scale` decimal places using HALF_UP rounding mode if `scale` >= 0
     or at integral part when `scale` < 0.
 
-    >>> sqlContext.createDataFrame([(2.546,)], ['a']).select(round('a', 1).alias('r')).collect()
-    [Row(r=2.5)]
+    >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(round('a', 0).alias('r')).collect()
+    [Row(r=3.0)]
     """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.round(_to_java_column(col), scale))
 
 
+@since(2.0)
+def bround(col, scale=0):
+    """
+    Round the given value to `scale` decimal places using HALF_EVEN rounding mode if `scale` >= 0
+    or at integral part when `scale` < 0.
+
+    >>> sqlContext.createDataFrame([(2.5,)], ['a']).select(bround('a', 0).alias('r')).collect()
+    [Row(r=2.0)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.bround(_to_java_column(col), scale))
+
+
 @since(1.5)
 def shiftLeft(col, numBits):
     """Shift the given value numBits left.
-- 
GitLab