Commit f515f943 authored by GuoQiang Li, committed by Xiangrui Meng

[SPARK-4526][MLLIB] GradientDescent gets a wrong gradient value according to the gradient formula.

This is caused by the miniBatchSize parameter: the number of elements that `RDD.sample` returns is not fixed.
cc mengxr
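
To see the failure mode in isolation, here is a minimal sketch (assuming a running SparkContext `sc`; the counts shown are illustrative) of why the expected batch size is not the actual one:

    // RDD.sample draws each element independently with probability `fraction`,
    // so the returned count only *averages* numExamples * miniBatchFraction.
    val data = sc.parallelize(1 to 100000)
    val expected = data.count() * 0.1                // 10000.0 -- the old miniBatchSize
    val actual = data.sample(false, 0.1, 42).count() // close to 10000, but varies per seed
    // Dividing gradientSum by `expected` instead of `actual` rescales the gradient
    // by the random factor actual / expected.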

Author: GuoQiang Li <witgo@qq.com>

Closes #3399 from witgo/GradientDescent and squashes the following commits:

13cb228 [GuoQiang Li] review commit
668ab66 [GuoQiang Li] Double to Long
b6aa11a [GuoQiang Li] Check miniBatchSize is greater than 0
0b5c3e3 [GuoQiang Li] Minor fix
12e7424 [GuoQiang Li] GradientDescent gets a wrong gradient value according to the gradient formula, which is caused by the miniBatchSize parameter.
parent 89f91226
@@ -160,14 +160,15 @@ object GradientDescent extends Logging {
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
 
     val numExamples = data.count()
-    val miniBatchSize = numExamples * miniBatchFraction
 
     // if no data, return initial weights to avoid NaNs
     if (numExamples == 0) {
-      logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+      logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
       return (initialWeights, stochasticLossHistory.toArray)
     }
 
+    if (numExamples * miniBatchFraction < 1) {
+      logWarning("The miniBatchFraction is too small")
+    }
+
     // Initialize weights as a column vector
@@ -185,25 +186,31 @@ object GradientDescent extends Logging {
       val bcWeights = data.context.broadcast(weights)
       // Sample a subset (fraction miniBatchFraction) of the total data
       // compute and sum up the subgradients on this subset (this is one map-reduce)
-      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .treeAggregate((BDV.zeros[Double](n), 0.0))(
-          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
-            (grad, loss + l)
+      val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i)
+        .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
+          seqOp = (c, v) => {
+            // c: (grad, loss, count), v: (label, features)
+            val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1))
+            (c._1, c._2 + l, c._3 + 1)
           },
-          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
-            (grad1 += grad2, loss1 + loss2)
+          combOp = (c1, c2) => {
+            // c: (grad, loss, count)
+            (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3)
           })
 
-      /**
-       * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
-       * and regVal is the regularization value computed in the previous iteration as well.
-       */
-      stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
-      val update = updater.compute(
-        weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam)
-      weights = update._1
-      regVal = update._2
+      if (miniBatchSize > 0) {
+        /**
+         * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration
+         * and regVal is the regularization value computed in the previous iteration as well.
+         */
+        stochasticLossHistory.append(lossSum / miniBatchSize + regVal)
+        val update = updater.compute(
+          weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam)
+        weights = update._1
+        regVal = update._2
+      } else {
+        logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero")
+      }
     }
 
     logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(
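
For reference, the aggregation pattern the fix relies on can be run standalone. The sketch below uses hypothetical toy data and a toy stand-in for Gradient.compute, and assumes `treeAggregate` is available on plain RDDs (Spark 1.3+; earlier versions import it from `org.apache.spark.mllib.rdd.RDDFunctions`). It carries a count through the fold and normalizes by the actual batch size, exactly as the patched loop does:

    import breeze.linalg.{DenseVector => BDV}
    import org.apache.spark.{SparkConf, SparkContext}

    object MiniBatchCountSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
        // Hypothetical (label, features) pairs standing in for the training RDD.
        val points = sc.parallelize(Seq(
          (1.0, Array(0.5, 1.0)), (0.0, Array(1.5, -0.5)),
          (1.0, Array(-1.0, 2.0)), (0.0, Array(0.0, 0.5))))
        val n = 2

        // Fold (gradientSum, lossSum, count) in one pass, as the fix does.
        val (gradientSum, lossSum, miniBatchSize) = points.sample(false, 0.5, 42)
          .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))(
            seqOp = (c, v) => (c, v) match { case ((grad, loss, count), (label, features)) =>
              // Toy stand-ins for the real gradient/loss: sum features, sum squared labels.
              (grad += BDV(features), loss + label * label, count + 1)
            },
            combOp = (c1, c2) => (c1, c2) match { case ((g1, l1, n1), (g2, l2, n2)) =>
              (g1 += g2, l1 + l2, n1 + n2)
            })

        // Normalize by the ACTUAL sample size; the sample may even be empty.
        if (miniBatchSize > 0) {
          println(s"avg loss = ${lossSum / miniBatchSize}")
          println(s"avg grad = ${gradientSum / miniBatchSize.toDouble}")
        } else {
          println("empty batch; skipping this update")
        }
        sc.stop()
      }
    }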