Commit fe8a3546 authored by Patrick Wendell's avatar Patrick Wendell

Merge pull request #459 from srowen/UpdaterL2Regularization

Correct L2 regularized weight update with canonical form

Per thread on the user@ mailing list, and comments from Ameet, I believe the weight update for L2 regularization needs to be corrected. See http://mail-archives.apache.org/mod_mbox/spark-user/201401.mbox/%3CCAH3_EVMetuQuhj3__NdUniDLc4P-FMmmrmxw9TS14or8nT4BNQ%40mail.gmail.com%3E
parents 73dfd42f e91ad3f1
@@ -86,13 +86,17 @@ class L1Updater extends Updater {
 /**
  * Updater that adjusts the learning rate and performs L2 regularization
+ *
+ * See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
+ * of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
+ * these slides</a>.
  */
 class SquaredL2Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
       stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
-    val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
+    val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
     (newWeights, pow(newWeights.norm2, 2.0) * regParam)
   }
 }
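To see why the new line is the canonical form: with a regularized objective F(w) = f(w) + regParam * ||w||^2, the gradient is grad(f) + 2 * regParam * w, so a plain gradient step with rate eta gives w' = w - eta * (grad + 2 * regParam * w) = w * (1 - 2 * eta * regParam) - eta * grad, which is exactly the corrected update. A minimal sketch (plain Python, not Spark code; function names are hypothetical) checks this equivalence on a scalar weight:

```python
import math

def corrected_update(w, grad, step_size, iter_num, reg_param):
    # Step size decays as stepSize / sqrt(iter), as in the Scala code.
    eta = step_size / math.sqrt(iter_num)
    # Canonical form from the patch:
    # w' = w * (1 - 2 * eta * regParam) - eta * grad
    return w * (1.0 - 2.0 * eta * reg_param) - eta * grad

def full_gradient_step(w, grad, step_size, iter_num, reg_param):
    # Direct gradient step on F(w) = f(w) + regParam * w^2:
    # w' = w - eta * (grad + 2 * regParam * w)
    eta = step_size / math.sqrt(iter_num)
    return w - eta * (grad + 2.0 * reg_param * w)

w, grad = 1.5, 0.4
a = corrected_update(w, grad, 1.0, 4, 0.1)
b = full_gradient_step(w, grad, 1.0, 4, 0.1)
print(a, b, abs(a - b) < 1e-12)  # both 1.15: the two forms agree
```

The old code instead divided by (2 * eta * regParam + 1), which is the update for an implicit (proximal) step on the L2 term rather than the explicit gradient step the rest of the method assumes.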