From bc8890b357811612ba6c10d96374902b9e08134f Mon Sep 17 00:00:00 2001 From: Gary King <gary@idibon.com> Date: Sun, 7 Feb 2016 09:13:28 +0000 Subject: [PATCH] [SPARK-13132][MLLIB] cache standardization param value in LogisticRegression cache the value of the standardization Param in LogisticRegression, rather than re-fetching it from the ParamMap for every index and every optimization step in the quasi-newton optimizer also, fix Param#toString to cache the stringified representation, rather than re-interpolating it on every call, so any other implementations that have similar repeated access patterns will see a benefit. this change improves training times for one of my test sets from ~7m30s to ~4m30s Author: Gary King <gary@idibon.com> Closes #11027 from idigary/spark-13132-optimize-logistic-regression. --- .../apache/spark/ml/classification/LogisticRegression.scala | 3 ++- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 9b2340a1f1..ac0124513f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -332,12 +332,13 @@ class LogisticRegression @Since("1.2.0") ( val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) } else { + val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept if (index == numFeatures) { 0.0 } else { - if ($(standardization)) { + if (standardizationParam) { regParamL1 } else { // If `standardization` is false, we still standardize the data diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index f48923d699..d7d6c0f5fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -117,7 +117,9 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali } } - override final def toString: String = s"${parent}__$name" + private[this] val stringRepresentation = s"${parent}__$name" + + override final def toString: String = stringRepresentation override final def hashCode: Int = toString.## -- GitLab