Skip to content
Snippets Groups Projects
Commit e91ad3f1 authored by Sean Owen's avatar Sean Owen
Browse files

Correct L2 regularized weight update with canonical form

parent d749d472
No related branches found
No related tags found
No related merge requests found
......@@ -86,13 +86,17 @@ class L1Updater extends Updater {
/**
* Updater that adjusts the learning rate and performs L2 regularization
*
* See, for example, explanation of gradient and loss with L2 regularization on slide 21-22
* of <a href="http://people.cs.umass.edu/~sheldon/teaching/2012fa/ml/files/lec7-annotated.pdf">
* these slides</a>.
*/
class SquaredL2Updater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val thisIterStepSize = stepSize / math.sqrt(iter)
val normGradient = gradient.mul(thisIterStepSize)
val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
val newWeights = weightsOld.mul(1.0 - 2.0 * thisIterStepSize * regParam).sub(normGradient)
(newWeights, pow(newWeights.norm2, 2.0) * regParam)
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment