diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index a6912056395d7e4ae3fe3d188236fef5ebeffba6..0857877951c82dc557e8990f6f3f19445e04ac4f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -160,14 +160,15 @@ object GradientDescent extends Logging { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) val numExamples = data.count() - val miniBatchSize = numExamples * miniBatchFraction // if no data, return initial weights to avoid NaNs if (numExamples == 0) { - - logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found") + logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found") return (initialWeights, stochasticLossHistory.toArray) + } + if (numExamples * miniBatchFraction < 1) { + logWarning("The miniBatchFraction is too small") } // Initialize weights as a column vector @@ -185,25 +186,31 @@ object GradientDescent extends Logging { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .treeAggregate((BDV.zeros[Double](n), 0.0))( - seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) - (grad, loss + l) + val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) + .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))( + seqOp = (c, v) => { + // c: (grad, loss, count), v: (label, features) + val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1)) + (c._1, c._2 + l, c._3 + 1) }, - combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + combOp = (c1, c2) => { + // c: (grad, loss, count) + (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) }) - /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 + if (miniBatchSize > 0) { + /** + * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. + */ + stochasticLossHistory.append(lossSum / miniBatchSize + regVal) + val update = updater.compute( + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights = update._1 + regVal = update._2 + } else { + logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") + } } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format(