From bc8890b357811612ba6c10d96374902b9e08134f Mon Sep 17 00:00:00 2001
From: Gary King <gary@idibon.com>
Date: Sun, 7 Feb 2016 09:13:28 +0000
Subject: [PATCH] [SPARK-13132][MLLIB] cache standardization param value in
 LogisticRegression

cache the value of the standardization Param in LogisticRegression, rather than re-fetching it from the ParamMap for every index and every optimization step in the quasi-newton optimizer

also, fix Param#toString to cache the stringified representation, rather than re-interpolating it on every call, so any other implementations that have similar repeated access patterns will see a benefit.

this change improves training times for one of my test sets from ~7m30s to ~4m30s

Author: Gary King <gary@idibon.com>

Closes #11027 from idigary/spark-13132-optimize-logistic-regression.
---
 .../apache/spark/ml/classification/LogisticRegression.scala   | 3 ++-
 mllib/src/main/scala/org/apache/spark/ml/param/params.scala   | 4 +++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 9b2340a1f1..ac0124513f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") (
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
         } else {
+          val standardizationParam = $(standardization)
           def regParamL1Fun = (index: Int) => {
             // Remove the L1 penalization on the intercept
             if (index == numFeatures) {
               0.0
             } else {
-              if ($(standardization)) {
+              if (standardizationParam) {
                 regParamL1
               } else {
                 // If `standardization` is false, we still standardize the data
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index f48923d699..d7d6c0f5fa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
     }
   }
 
-  override final def toString: String = s"${parent}__$name"
+  private[this] val stringRepresentation = s"${parent}__$name"
+
+  override final def toString: String = stringRepresentation
 
   override final def hashCode: Int = toString.##
 
-- 
GitLab