Commit 1e1ffb19 authored by atalwalkar

Merge pull request #745 from shivaram/loss-update-fix

Remove duplicate loss history in Gradient Descent
parents 207548b6 3ca9faa3
@@ -151,7 +151,6 @@ object LogisticRegressionLocalRandomSGD {
      input: RDD[(Int, Array[Double])],
      numIterations: Int,
      stepSize: Double,
      miniBatchFraction: Double,
      initialWeights: Array[Double])
    : LogisticRegressionModel =
@@ -174,7 +173,6 @@ object LogisticRegressionLocalRandomSGD {
      input: RDD[(Int, Array[Double])],
      numIterations: Int,
      stepSize: Double,
      miniBatchFraction: Double)
    : LogisticRegressionModel =
  {
@@ -195,8 +193,7 @@ object LogisticRegressionLocalRandomSGD {
   def train(
       input: RDD[(Int, Array[Double])],
       numIterations: Int,
-      stepSize: Double
-      )
+      stepSize: Double)
     : LogisticRegressionModel =
   {
     train(input, numIterations, stepSize, 1.0)
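For readers following the overload changes above, here is a minimal plain-Scala sketch of the delegation pattern these train methods use: the convenience overload without a miniBatchFraction parameter simply forwards to the fuller signature with 1.0, meaning every point is used in every iteration. Dataset and Model are hypothetical stand-ins for RDD[(Int, Array[Double])] and LogisticRegressionModel, and the body of the full train method is stubbed out.

object OverloadSketch {
  type Dataset = Seq[(Int, Array[Double])]      // stand-in for RDD[(Int, Array[Double])]
  case class Model(weights: Array[Double])      // stand-in for LogisticRegressionModel

  // Full signature: the caller controls the mini-batch fraction explicitly.
  def train(input: Dataset, numIterations: Int, stepSize: Double,
            miniBatchFraction: Double): Model =
    Model(Array.fill(input.headOption.map(_._2.length).getOrElse(0))(0.0))  // training stubbed out

  // Convenience overload: forwards with miniBatchFraction = 1.0,
  // mirroring train(input, numIterations, stepSize, 1.0) in the diff above.
  def train(input: Dataset, numIterations: Int, stepSize: Double): Model =
    train(input, numIterations, stepSize, 1.0)
}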
@@ -61,7 +61,7 @@ object GradientDescent {
     // Initialize weights as a column vector
     var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
-    var reg_val = 0.0
+    var regVal = 0.0
     for (i <- 1 to numIters) {
       val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42+i).map {
@@ -71,15 +71,14 @@ object GradientDescent {
         (grad, loss)
       }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
-      stochasticLossHistory.append(lossSum / miniBatchSize + reg_val)
+      /**
+       * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+       * and regVal is the regularization value computed in the previous iteration as well.
+       */
+      stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
       val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i, regParam)
       weights = update._1
-      reg_val = update._2
-      stochasticLossHistory.append(lossSum / miniBatchSize + reg_val)
-      /*
-       * NOTE(Xinghao): The loss here is sum of lossSum computed using the weights before applying updater,
-       * and reg_val using weights after applying updater
-       */
+      regVal = update._2
     }
     (weights.toArray, stochasticLossHistory.toArray)
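The hunk above is the heart of the fix: the loss is now appended exactly once per iteration, using lossSum computed with the pre-update weights together with the regularization value of those same weights, carried over in regVal from the previous iteration. Below is a minimal, self-contained sketch of that bookkeeping in plain Scala, assuming a squared-error loss and a caller-supplied computeUpdate function standing in for updater.compute; none of these names are Spark's.

import scala.collection.mutable.ArrayBuffer

object LossHistorySketch {
  // computeUpdate stands in for updater.compute:
  // (weights, gradient, stepSize, iter) => (newWeights, regularization value of newWeights).
  def run(data: Seq[(Double, Array[Double])],
          numIters: Int,
          stepSize: Double,
          computeUpdate: (Array[Double], Array[Double], Double, Int) => (Array[Double], Double))
      : (Array[Double], Array[Double]) = {
    var weights = Array.fill(data.head._2.length)(0.0)
    var regVal = 0.0                          // regularization value of the current weights
    val stochasticLossHistory = ArrayBuffer[Double]()

    for (i <- 1 to numIters) {
      // Squared-error gradient and loss, both computed with the current (pre-update) weights.
      val grad = Array.fill(weights.length)(0.0)
      var lossSum = 0.0
      for ((y, x) <- data) {
        val err = x.zip(weights).map { case (xi, wi) => xi * wi }.sum - y
        for (j <- grad.indices) grad(j) += err * x(j)
        lossSum += 0.5 * err * err
      }
      // Recorded once per iteration: lossSum uses the pre-update weights, and regVal is
      // the regularization value of those same weights, produced by the previous iteration.
      stochasticLossHistory.append(lossSum / data.size + regVal)

      val (newWeights, newRegVal) = computeUpdate(weights, grad.map(_ / data.size), stepSize, i)
      weights = newWeights
      regVal = newRegVal                      // consumed when the loss is recorded next iteration
    }
    (weights, stochasticLossHistory.toArray)
  }
}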
@@ -23,6 +23,7 @@ import org.jblas.DoubleMatrix
 abstract class Updater extends Serializable {
   /**
    * Compute an updated value for weights given the gradient, stepSize and iteration number.
+   * Also returns the regularization value computed using the *updated* weights.
    *
    * @param weightsOlds - Column matrix of size nx1 where n is the number of features.
    * @param gradient - Column matrix of size nx1 where n is the number of features.
@@ -31,7 +32,7 @@ abstract class Updater extends Serializable {
    * @param regParam - Regularization parameter
    *
    * @return A tuple of 2 elements. The first element is a column matrix containing updated weights,
-   *         and the second element is the regularization value.
+   *         and the second element is the regularization value computed using updated weights.
    */
   def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, regParam: Double):
     (DoubleMatrix, Double)
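To make the amended contract concrete, here is a minimal sketch of an updater honoring it, using Array[Double] in place of jblas DoubleMatrix; UpdaterSketch and StepDecayUpdater are hypothetical names, not Spark's. The second element of the returned tuple is the regularization value of the weights being returned, which the gradient-descent loop above stores and folds into the loss it records on the next iteration.

abstract class UpdaterSketch extends Serializable {
  /** Returns (updated weights, regularization value computed from the updated weights). */
  def compute(weightsOld: Array[Double], gradient: Array[Double],
              stepSize: Double, iter: Int, regParam: Double): (Array[Double], Double)
}

// A plain gradient step with a decaying step size and no regularization term,
// so the second tuple element is 0.0 (analogous to a SimpleUpdater).
class StepDecayUpdater extends UpdaterSketch {
  override def compute(weightsOld: Array[Double], gradient: Array[Double],
                       stepSize: Double, iter: Int, regParam: Double): (Array[Double], Double) = {
    val thisIterStepSize = stepSize / math.sqrt(iter)   // same 1/sqrt(iter) decay as in the diff
    val newWeights = weightsOld.zip(gradient).map { case (w, g) => w - thisIterStepSize * g }
    (newWeights, 0.0)
  }
}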
@@ -46,13 +47,13 @@ class SimpleUpdater extends Updater {
 }
 /**
-* L1 regularization -- corresponding proximal operator is the soft-thresholding function
-* That is, each weight component is shrunk towards 0 by shrinkageVal
-* If w > shrinkageVal, set weight component to w-shrinkageVal.
-* If w < -shrinkageVal, set weight component to w+shrinkageVal.
-* If -shrinkageVal < w < shrinkageVal, set weight component to 0.
-* Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
-**/
+ * L1 regularization -- corresponding proximal operator is the soft-thresholding function
+ * That is, each weight component is shrunk towards 0 by shrinkageVal
+ * If w > shrinkageVal, set weight component to w-shrinkageVal.
+ * If w < -shrinkageVal, set weight component to w+shrinkageVal.
+ * If -shrinkageVal < w < shrinkageVal, set weight component to 0.
+ * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal)
+ */
 class L1Updater extends Updater {
   override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
     stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
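The corrected comment describes the proximal operator for L1 regularization. As a sanity check, here is the soft-thresholding rule it states, sketched on a plain Array[Double] rather than a jblas matrix; softThreshold is a hypothetical helper, not the L1Updater itself.

object SoftThresholdSketch {
  // signum(w) * max(0.0, abs(w) - shrinkageVal), exactly as the comment above states.
  def softThreshold(weights: Array[Double], shrinkageVal: Double): Array[Double] =
    weights.map(w => math.signum(w) * math.max(0.0, math.abs(w) - shrinkageVal))
}

For example, softThreshold(Array(0.7, -0.2, 0.05), 0.1) yields Array(0.6, -0.1, 0.0): larger components are shrunk toward zero by shrinkageVal, and components within [-shrinkageVal, shrinkageVal] are set to zero.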
@@ -76,7 +77,7 @@ class SquaredL2Updater extends Updater {
     val thisIterStepSize = stepSize / math.sqrt(iter)
     val normGradient = gradient.mul(thisIterStepSize)
     val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
-    (newWeights, pow(newWeights.norm2,2.0) * regParam)
+    (newWeights, pow(newWeights.norm2, 2.0) * regParam)
   }
 }
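For reference, here is the same squared-L2 update written out on plain Array[Double] instead of DoubleMatrix (an assumption made only to keep the sketch self-contained). The arithmetic mirrors the diff: take the gradient step, divide by (2 * thisIterStepSize * regParam + 1), and return regParam times the squared L2 norm of the new weights as the regularization value.

object SquaredL2Sketch {
  def compute(weightsOld: Array[Double], gradient: Array[Double],
              stepSize: Double, iter: Int, regParam: Double): (Array[Double], Double) = {
    val thisIterStepSize = stepSize / math.sqrt(iter)
    // newWeights = (weightsOld - thisIterStepSize * gradient) / (2 * thisIterStepSize * regParam + 1)
    val newWeights = weightsOld.zip(gradient).map { case (w, g) =>
      (w - thisIterStepSize * g) / (2.0 * thisIterStepSize * regParam + 1.0)
    }
    val regVal = regParam * newWeights.map(x => x * x).sum   // regParam * ||newWeights||^2
    (newWeights, regVal)
  }
}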