Commit 589cce93 authored by Gio Borje, committed by Sean Owen

Log warnings for numIterations * miniBatchFraction < 1.0

## What changes were proposed in this pull request?

Add a warning log for the case that `numIterations * miniBatchFraction < 1.0` during gradient descent. If the product of those two numbers is less than `1.0`, then not all training examples will be used during optimization. To make this concrete, suppose that `numExamples = 100`, `miniBatchFraction = 0.2` and `numIterations = 3`. Then 3 iterations will occur, each sampling approximately 20 examples. In the best case, every sampled example is distinct, so at most 60 of the 100 examples are ever used.
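As a back-of-the-envelope check (not part of the patch; all names below are illustrative), the best-case fraction of distinct examples touched is bounded above by `numIterations * miniBatchFraction`:

```scala
// Illustrative arithmetic only; these vals are hypothetical, not from the patch.
val numExamples = 100
val miniBatchFraction = 0.2
val numIterations = 3

// Each iteration samples roughly numExamples * miniBatchFraction examples.
val examplesPerIteration = numExamples * miniBatchFraction // ~20

// Even if every sampled example were distinct, total coverage is capped at:
val bestCaseCoverage = numIterations * miniBatchFraction   // 0.6 < 1.0
println(f"best case: ${bestCaseCoverage * numExamples}%.0f of $numExamples examples used")
```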

This may be counter-intuitive to most users, and it led to an issue during the development of another Spark ML model: https://github.com/zhengruifeng/spark-libFM/issues/11. If a user actually does not need the full training data set, it would be easier and more intuitive to subsample it explicitly with `RDD.sample`, as sketched below.
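For comparison, here is a minimal sketch of the `RDD.sample` alternative mentioned above (the `data` RDD, the helper name, and the chosen fraction are assumptions for illustration):

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// Sketch only: `data` is assumed to be an RDD[LabeledPoint] prepared elsewhere.
def subsampleOnce(data: RDD[LabeledPoint]): RDD[LabeledPoint] =
  // Draw an explicit 60% sample up front, then train on it with
  // miniBatchFraction = 1.0 so every retained example is used each iteration.
  data.sample(withReplacement = false, fraction = 0.6, seed = 42L)
```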

## How was this patch tested?

`build/mvn -DskipTests clean package` build succeeds

Author: Gio Borje <gborje@linkedin.com>

Closes #13265 from Hydrotoast/master.
parent 9c297df3
@@ -197,6 +197,11 @@ object GradientDescent extends Logging {
         "< 1.0 can be unstable because of the stochasticity in sampling.")
     }
+    if (numIterations * miniBatchFraction < 1.0) {
+      logWarning("Not all examples will be used if numIterations * miniBatchFraction < 1.0: " +
+        s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction")
+    }
+
     val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
     // Record previous weight and current one to calculate solution vector difference
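For context, a sketch of how a caller could hit the new warning through the public `optimizer` handle of an mllib algorithm (assumed caller code, not from this commit; settings chosen so the product falls below 1.0):

```scala
import org.apache.spark.mllib.regression.LinearRegressionWithSGD

// Assumed caller code, not part of this commit: with these settings,
// numIterations * miniBatchFraction = 3 * 0.2 = 0.6 < 1.0, so
// runMiniBatchSGD would now log the "Not all examples will be used" warning.
val lr = new LinearRegressionWithSGD()
lr.optimizer
  .setNumIterations(3)
  .setMiniBatchFraction(0.2)
```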