Commit 22683560 authored by Yanbo Liang, committed by Joseph K. Bradley

[SPARK-7770] [ML] GBT validationTol change to compare with relative or absolute error

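GBT now compares the improvement in validation error against a tolerance that switches between relative and absolute. The relative tolerance is validationTol times the current loss on the validation set; once that loss is 0.01 or lower, the absolute tolerance validationTol * 0.01 is used instead.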

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8549 from yanboliang/spark-7770.
parent 0903c648
......@@ -262,7 +262,8 @@ object GradientBoostedTrees extends Logging {
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean()
- if (bestValidateError - currentValidateError < validationTol) {
+ if (bestValidateError - currentValidateError < validationTol * Math.max(
+ currentValidateError, 0.01)) {
doneLearning = true
} else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError
......
......@@ -34,9 +34,16 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* weak hypotheses used in the final model.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be between in the interval (0, 1]
- * @param validationTol Useful when runWithValidation is used. If the error rate on the
- * validation input between two iterations is less than the validationTol
- * then stop. Ignored when
+ * @param validationTol validationTol is a condition which decides iteration termination when
+ * runWithValidation is used.
+ * The end of iteration is decided based on below logic:
+ * If the current loss on the validation set is > 0.01, the diff
+ * of validation error is compared to relative tolerance which is
+ * validationTol * (current loss on the validation set).
+ * If the current loss on the validation set is <= 0.01, the diff
+ * of validation error is compared to absolute tolerance which is
+ * validationTol * 0.01.
+ * Ignored when
* [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/
@Since("1.2.0")
......@@ -48,7 +55,7 @@ case class BoostingStrategy @Since("1.4.0") (
// Optional boosting parameters
@Since("1.2.0") @BeanProperty var numIterations: Int = 100,
@Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
- @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
+ @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable {
/**
* Check validity of parameters.
......
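
The stopping rule introduced by this change can be sketched as a small standalone Scala object, shown here only to make the relative/absolute switch concrete. The names ValidationTolSketch, LossFloor, effectiveTol and shouldStop are hypothetical and do not appear in Spark; the 0.01 floor and the comparison itself are taken from the diff above.

```scala
// Illustrative sketch only: these names are hypothetical, not Spark API.
object ValidationTolSketch {
  // Floor on the validation loss, as hard-coded in the diff above.
  val LossFloor: Double = 0.01

  // Relative tolerance when the loss is above the floor,
  // absolute tolerance (validationTol * 0.01) otherwise.
  def effectiveTol(currentValidateError: Double, validationTol: Double): Double =
    validationTol * math.max(currentValidateError, LossFloor)

  // Mirrors the condition that sets doneLearning = true in the diff.
  def shouldStop(
      bestValidateError: Double,
      currentValidateError: Double,
      validationTol: Double): Boolean =
    bestValidateError - currentValidateError <
      effectiveTol(currentValidateError, validationTol)
}
```

With validationTol = 0.001 and a validation loss near 1.0, an improvement of 0.0005 falls below the relative tolerance (about 0.001) and boosting would stop, while an improvement of 0.1 would not. With a loss near 0.005 the tolerance is fixed at 0.001 * 0.01 = 1e-5. Because the left-hand side is negative whenever the validation error gets worse, the condition also holds in that case.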