From e21acc1978a6f4a57ef2e08490692b0ffe05fa9e Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh <viirya@gmail.com>
Date: Tue, 6 Jan 2015 21:23:31 -0800
Subject: [PATCH] [SPARK-5099][Mllib] Simplify logistic loss function

This is a minor PR: instead of subtracting `margin`, I think we can simply negate it.

Mathematically, the two forms are equal, but the modified equation is the common form of the logistic loss function and therefore more readable. It also computes a more accurate value, as some quick tests show.
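The accuracy claim can be checked standalone. The sketch below (not part of the patch; `stableLoss`/`naiveLoss` are hypothetical names) compares the rewritten branch-on-sign form against the naive `log1p(exp(margin))` form: the two agree for moderate margins, but the naive form overflows to `Infinity` once `math.exp(margin)` exceeds `Double.MaxValue`, while the stable form returns the correct value.

```scala
object LogLossCheck {
  // Stable form, as in the patch: branch so that exp() is only ever
  // applied to a non-positive argument, which cannot overflow.
  def stableLoss(margin: Double, label: Double): Double = {
    val minusYP = if (label > 0) margin else -margin
    if (minusYP < 0) math.log1p(math.exp(minusYP))
    else math.log1p(math.exp(-minusYP)) + minusYP
  }

  // Naive form for label > 0: exp(margin) overflows for large margins.
  def naiveLoss(margin: Double): Double =
    math.log1p(math.exp(margin))

  def main(args: Array[String]): Unit = {
    println(stableLoss(1000.0, 1.0)) // 1000.0
    println(naiveLoss(1000.0))       // Infinity
  }
}
```

For `margin = 1000` and a positive label, `exp(-1000)` underflows harmlessly to `0.0`, so the stable form yields exactly `1000.0`, whereas the naive form evaluates `log1p(exp(1000))` and blows up to `Infinity`.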

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #3899 from viirya/logit_func and squashes the following commits:

91a3860 [Liang-Chi Hsieh] Modified for comment.
0aa51e4 [Liang-Chi Hsieh] Further simplified.
72a295e [Liang-Chi Hsieh] Revert LogLoss back and add more considerations in Logistic Loss.
a3f83ca [Liang-Chi Hsieh] Fix a bug.
2bc5712 [Liang-Chi Hsieh] Simplify loss function.
---
 .../apache/spark/mllib/optimization/Gradient.scala   | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 5a419d1640..aaacf3a8a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -64,11 +64,17 @@ class LogisticGradient extends Gradient {
     val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
     val gradient = data.copy
     scal(gradientMultiplier, gradient)
+    val minusYP = if (label > 0) margin else -margin
+
+    // log1p is log(1+p) but more accurate for small p
+    // Following two equations are the same analytically but not numerically, e.g.,
+    // math.log1p(math.exp(1000)) == Infinity
+    // 1000 + math.log1p(math.exp(-1000)) == 1000.0
     val loss =
-      if (label > 0) {
-        math.log1p(math.exp(margin)) // log1p is log(1+p) but more accurate for small p
+      if (minusYP < 0) {
+        math.log1p(math.exp(minusYP))
       } else {
-        math.log1p(math.exp(margin)) - margin
+        math.log1p(math.exp(-minusYP)) + minusYP
       }
 
     (gradient, loss)
-- 
GitLab