diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala index e916a92c332088638c02966ab6ffdf07ca417217..bf506d2f24764b2079e65bb0e4fa89fb677d328e 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala @@ -41,7 +41,8 @@ abstract class Updater extends Serializable { class SimpleUpdater extends Updater { override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val normGradient = gradient.mul(stepSize / math.sqrt(iter)) + val thisIterStepSize = stepSize / math.sqrt(iter) + val normGradient = gradient.mul(thisIterStepSize) (weightsOld.sub(normGradient), 0) } }