diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 95ed48cea67167601d2694c40df2a06c6bd597e8..66a07e31360d897c0e60d2c1a1f4c090aeadd40c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -262,7 +262,8 @@ object GradientBoostedTrees extends Logging { validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() - if (bestValidateError - currentValidateError < validationTol) { + if (bestValidateError - currentValidateError < validationTol * Math.max( + currentValidateError, 0.01)) { doneLearning = true } else if (currentValidateError < bestValidateError) { bestValidateError = currentValidateError diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index b5c72fba3ede1fa1e42b8dc97b57d4a26ec65ba8..fc13bcfd8e998719267b307951dacf60404095b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,9 +34,16 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] - * @param validationTol Useful when runWithValidation is used. If the error rate on the - * validation input between two iterations is less than the validationTol - * then stop. Ignored when + * @param validationTol validationTol is a condition which decides iteration termination when + * runWithValidation is used. + * The end of iteration is decided based on below logic: + * If the current loss on the validation set is > 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is <= 0.01, the diff + * of validation error is compared to absolute tolerance which is + * validationTol * 0.01. + * Ignored when * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Since("1.2.0") @@ -48,7 +55,7 @@ case class BoostingStrategy @Since("1.4.0") ( // Optional boosting parameters @Since("1.2.0") @BeanProperty var numIterations: Int = 100, @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, - @Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable { + @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable { /** * Check validity of parameters.