diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
index 4c51f4f881f76e1f02b8342219a1fd94ffed3a9c..37124f261eeb90f0aff06a9478adf84f009f024d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala
@@ -86,13 +86,17 @@ class L1Updater extends Updater {
 
 /**
  * Updater that adjusts the learning rate and performs L2 regularization
+ *
+ * See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
+ * of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
+ * these slides</a>.
  */
 class SquaredL2Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
       stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
-    val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
+    val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
     (newWeights, pow(newWeights.norm2, 2.0) * regParam)
   }
 }
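
For reference, the reasoning behind the new line: with the regularized objective L(w) + regParam * ||w||^2, the regularizer contributes 2 * regParam * w to the gradient, so a plain gradient step with eta = stepSize / sqrt(iter) is w_new = w - eta * (gradient + 2 * regParam * w) = w * (1 - 2 * eta * regParam) - eta * gradient, which is what the patched line computes; the old line divided (w - eta * gradient) by (2 * eta * regParam + 1) instead. Below is a minimal, self-contained sketch (plain Scala arrays standing in for jblas DoubleMatrix, with made-up example values) that checks the patched one-liner against the explicit gradient step written out from the definition. It is an illustration only, not part of the patch.

// Sketch only: verifies the patched SquaredL2Updater formula against the
// explicit gradient step on L(w) + regParam * ||w||^2, using plain arrays
// instead of jblas DoubleMatrix and hypothetical example values.
object SquaredL2UpdateCheck {
  def main(args: Array[String]): Unit = {
    val weightsOld = Array(0.5, -1.2, 3.0)   // hypothetical current weights
    val gradient   = Array(0.1, -0.4, 0.25)  // hypothetical loss gradient (without regularization)
    val stepSize   = 1.0
    val iter       = 4
    val regParam   = 0.1

    val thisIterStepSize = stepSize / math.sqrt(iter)

    // Explicit gradient step on the regularized objective:
    //   w_new = w - eta * (grad + 2 * regParam * w)
    val byDefinition = weightsOld.zip(gradient).map { case (w, g) =>
      w - thisIterStepSize * (g + 2.0 * regParam * w)
    }

    // The patched one-liner: w_new = w * (1 - 2 * eta * regParam) - eta * grad
    val byPatch = weightsOld.zip(gradient).map { case (w, g) =>
      w * (1.0 - 2.0 * thisIterStepSize * regParam) - thisIterStepSize * g
    }

    // The two forms should agree elementwise up to floating-point error.
    assert(byDefinition.zip(byPatch).forall { case (a, b) => math.abs(a - b) < 1e-12 })
    println(byPatch.mkString("newWeights = [", ", ", "]"))
  }
}

Plain arrays are used here only to keep the sketch dependency-free; in the actual MLlib code the same arithmetic runs on DoubleMatrix via mul and sub, as shown in the diff above.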