diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala index 7ac7c225e5acb65db5b122a75160f0f0d716d40c..929374eda13a85746569b051e62bdc6d73ad9f69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala @@ -39,9 +39,13 @@ private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] { * * @param regParam The magnitude of the regularization. * @param shouldApply A function (Int => Boolean) indicating whether a given index should have - * regularization applied to it. + * regularization applied to it. Usually we don't apply regularization to + * the intercept. * @param applyFeaturesStd Option for a function which maps coefficient index (column major) to the - * feature standard deviation. If `None`, no standardization is applied. + * feature standard deviation. Since we always standardize the data during + * training, if `standardization` is false, we have to reverse + * standardization by penalizing each component differently by this param. + * If `standardization` is true, this should be `None`. */ private[ml] class L2Regularization( override val regParam: Double, @@ -57,6 +61,11 @@ private[ml] class L2Regularization( val coef = coefficients(j) applyFeaturesStd match { case Some(getStd) => + // If `standardization` is false, we still standardize the data + // to improve the rate of convergence; as a result, we have to + // perform this reverse standardization by penalizing each component + // differently to get effectively the same objective function when + // the training dataset is not standardized. val std = getStd(j) if (std != 0.0) { val temp = coef / (std * std) @@ -66,6 +75,7 @@ private[ml] class L2Regularization( 0.0 } case None => + // If `standardization` is true, compute L2 regularization normally. sum += coef * coef gradient(j) = coef * regParam }