Commit 22683560 authored by Yanbo Liang, committed by Joseph K. Bradley

[SPARK-7770] [ML] GBT validationTol change to compare with relative or absolute error

GBT now compares the change in validation error against a tolerance that switches between a relative and an absolute form: while the current loss on the validation set is above 0.01, the tolerance is relative to that loss (validationTol * current validation loss); once the loss is 0.01 or below, an absolute tolerance of validationTol * 0.01 is used instead.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8549 from yanboliang/spark-7770.
parent 0903c648
mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala

@@ -262,7 +262,8 @@ object GradientBoostedTrees extends Logging {
           validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
         validatePredErrorCheckpointer.update(validatePredError)
         val currentValidateError = validatePredError.values.mean()
-        if (bestValidateError - currentValidateError < validationTol) {
+        if (bestValidateError - currentValidateError < validationTol * Math.max(
+          currentValidateError, 0.01)) {
           doneLearning = true
         } else if (currentValidateError < bestValidateError) {
           bestValidateError = currentValidateError
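For readers following the change, here is a minimal standalone sketch of the new stopping rule, with worked numbers assuming the new default validationTol = 0.001. The object and method names (ValidationTolSketch, shouldStop) are hypothetical and exist only for illustration; the actual check lives inline in GradientBoostedTrees.boost as shown in the hunk above.

    // A minimal, self-contained sketch of the stopping rule added above; the object
    // and method names are illustrative only and do not exist in Spark.
    object ValidationTolSketch {

      // Stop once the improvement over the best validation error so far drops below
      // validationTol * max(currentValidateError, 0.01): relative to the current
      // validation loss while it is above 0.01, an absolute floor of
      // validationTol * 0.01 below that.
      def shouldStop(
          bestValidateError: Double,
          currentValidateError: Double,
          validationTol: Double): Boolean = {
        bestValidateError - currentValidateError <
          validationTol * math.max(currentValidateError, 0.01)
      }

      def main(args: Array[String]): Unit = {
        val tol = 0.001 // new default in BoostingStrategy
        // Large validation loss (0.5): threshold is 0.001 * 0.5 = 5e-4.
        println(shouldStop(0.5010, 0.5, tol))    // false: improved by 1e-3, keep boosting
        println(shouldStop(0.5004, 0.5, tol))    // true:  improved by only 4e-4, stop
        // Small validation loss (0.005 <= 0.01): threshold is the absolute 0.001 * 0.01 = 1e-5.
        println(shouldStop(0.00502, 0.005, tol)) // false: improved by 2e-5, keep boosting
      }
    }

The effect of Math.max(currentValidateError, 0.01) is to keep the relative test from becoming vanishingly strict: once the validation loss itself is at or below 0.01, the threshold bottoms out at an absolute validationTol * 0.01.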
mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala

@@ -34,9 +34,16 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
  *                      weak hypotheses used in the final model.
  * @param learningRate Learning rate for shrinking the contribution of each estimator. The
  *                     learning rate should be between in the interval (0, 1]
- * @param validationTol Useful when runWithValidation is used. If the error rate on the
- *                      validation input between two iterations is less than the validationTol
- *                      then stop. Ignored when
+ * @param validationTol validationTol is a condition which decides iteration termination when
+ *                      runWithValidation is used.
+ *                      The end of iteration is decided based on below logic:
+ *                      If the current loss on the validation set is > 0.01, the diff
+ *                      of validation error is compared to relative tolerance which is
+ *                      validationTol * (current loss on the validation set).
+ *                      If the current loss on the validation set is <= 0.01, the diff
+ *                      of validation error is compared to absolute tolerance which is
+ *                      validationTol * 0.01.
+ *                      Ignored when
  *                      [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
  */
 @Since("1.2.0")
@@ -48,7 +55,7 @@ case class BoostingStrategy @Since("1.4.0") (
     // Optional boosting parameters
     @Since("1.2.0") @BeanProperty var numIterations: Int = 100,
     @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
-    @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
+    @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable {
 
   /**
    * Check validity of parameters.
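To show the change from a caller's perspective, below is a hedged sketch of driver code using the public MLlib API. The application name, data path, split ratios, and object name are placeholders, and none of this code is part of the patch; BoostingStrategy, validationTol, and runWithValidation are the only pieces the commit actually touches.

    // Sketch of driver code (not from the patch) that exercises runWithValidation with
    // the new validationTol default; dataset path and split ratios are placeholders.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.util.MLUtils

    object GbtValidationTolExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("GbtValidationTolExample"))

        // Any RDD[LabeledPoint] split into training and validation sets works here.
        val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
        val Array(train, validation) = data.randomSplit(Array(0.8, 0.2), seed = 42L)

        val boostingStrategy = BoostingStrategy.defaultParams("Regression")
        boostingStrategy.numIterations = 200
        // Defaults to 0.001 after this change; interpreted relative to the current
        // validation loss, with an absolute floor of validationTol * 0.01.
        boostingStrategy.validationTol = 0.001

        // runWithValidation stops adding trees once the improvement on `validation`
        // falls below the tolerance described in the scaladoc above.
        val model = new GradientBoostedTrees(boostingStrategy)
          .runWithValidation(train, validation)

        println(s"Ensemble size after early stopping: ${model.numTrees}")
        sc.stop()
      }
    }

Compared to the old behavior, the default is now a relative 0.001 rather than a fixed absolute 1e-5, so early stopping scales with the magnitude of the validation loss instead of applying the same tiny threshold at every scale.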