From 60e2d9e2902b132b14191c9791c71e8f0d42ce9d Mon Sep 17 00:00:00 2001
From: DB Tsai <dbtsai@alpinenow.com>
Date: Wed, 7 Jan 2015 10:13:41 -0800
Subject: [PATCH] [SPARK-5128][MLLib] Add commonly used log1pExp API in MLUtils

When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic
overflow. This will happen when `x > 709.78` which is not a very large number.
It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #3915 from dbtsai/mathutil and squashes the following commits:

bec6a84 [DB Tsai] remove empty line
3239541 [DB Tsai] revert part of patch into another PR
23144f3 [DB Tsai] doc
49f3658 [DB Tsai] temp
6c29ed3 [DB Tsai] formatting
f8447f9 [DB Tsai] address another overflow issue in gradientMultiplier in LOR gradient code
64eefd0 [DB Tsai] first commit
---
 .../spark/mllib/optimization/Gradient.scala   | 19 ++++++++-----------
 .../spark/mllib/tree/loss/LogLoss.scala       | 10 +++-------
 .../org/apache/spark/mllib/util/MLUtils.scala | 16 ++++++++++++++++
 .../spark/mllib/util/MLUtilsSuite.scala       | 13 ++++++++++---
 4 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index aaacf3a8a2..1ca0f36c6a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.optimization
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
+import org.apache.spark.mllib.util.MLUtils
 
 /**
  * :: DeveloperApi ::
@@ -64,17 +65,12 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     val gradient = data.copy
     scal(gradientMultiplier, gradient)
-    val minusYP = if (label > 0) margin else -margin
-
-    // log1p is log(1+p) but more accurate for small p
-    // Following two equations are the same analytically but not numerically, e.g.,
-    // math.log1p(math.exp(1000)) == Infinity
-    // 1000 + math.log1p(math.exp(-1000)) == 1000.0
     val loss =
-      if (minusYP < 0) {
-        math.log1p(math.exp(minusYP))
+      if (label > 0) {
+        // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+        MLUtils.log1pExp(margin)
       } else {
-        math.log1p(math.exp(-minusYP)) + minusYP
+        MLUtils.log1pExp(margin) - margin
       }
 
     (gradient, loss)
@@ -89,9 +85,10 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     axpy(gradientMultiplier, data, cumGradient)
     if (label > 0) {
-      math.log1p(math.exp(margin))
+      // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+      MLUtils.log1pExp(margin)
     } else {
-      math.log1p(math.exp(margin)) - margin
+      MLUtils.log1pExp(margin) - margin
     }
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 7ce9fa6f86..55213e6956 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree.loss
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 
 /**
@@ -61,13 +62,8 @@ object LogLoss extends Loss {
     data.map { case point =>
       val prediction = model.predict(point.features)
       val margin = 2.0 * point.label * prediction
-      // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
-      // stable.
-      if (margin >= 0) {
-        2.0 * math.log1p(math.exp(-margin))
-      } else {
-        2.0 * (-margin + math.log1p(math.exp(margin)))
-      }
+      // The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
+      2.0 * MLUtils.log1pExp(-margin)
     }.mean()
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index c7843464a7..5d6ddd47f6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -322,4 +322,20 @@ object MLUtils {
     }
     sqDist
   }
+
+  /**
+   * When `x` is positive and large, computing `math.log(1 + math.exp(x))` will lead to arithmetic
+   * overflow. This will happen when `x > 709.78` which is not a very large number.
+   * It can be addressed by rewriting the formula into `x + math.log1p(math.exp(-x))` when `x > 0`.
+   *
+   * @param x a floating-point value as input.
+   * @return the result of `math.log(1 + math.exp(x))`.
+   */
+  private[mllib] def log1pExp(x: Double): Double = {
+    if (x > 0) {
+      x + math.log1p(math.exp(-x))
+    } else {
+      math.log1p(math.exp(x))
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 7778847f8b..668fc1d43c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -20,18 +20,17 @@ package org.apache.spark.mllib.util
 import java.io.File
 
 import scala.io.Source
-import scala.math
 
 import org.scalatest.FunSuite
 
-import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
-  squaredDistance => breezeSquaredDistance}
+import breeze.linalg.{squaredDistance => breezeSquaredDistance}
 import com.google.common.base.Charsets
 import com.google.common.io.Files
 
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils._
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.util.Utils
 
 class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
@@ -204,4 +203,12 @@ class MLUtilsSuite extends FunSuite with MLlibTestSparkContext {
     assert(points.collect().toSet === loaded.collect().toSet)
     Utils.deleteRecursively(tempDir)
   }
+
+  test("log1pExp") {
+    assert(log1pExp(76.3) ~== math.log1p(math.exp(76.3)) relTol 1E-10)
+    assert(log1pExp(87296763.234) ~== 87296763.234 relTol 1E-10)
+
+    assert(log1pExp(-13.8) ~== math.log1p(math.exp(-13.8)) absTol 1E-10)
+    assert(log1pExp(-238423789.865) ~== math.log1p(math.exp(-238423789.865)) absTol 1E-10)
+  }
 }
-- 
GitLab