Commit 64743870 authored by Yanbo Liang, committed by Joseph K. Bradley

[SPARK-10394] [ML] Make GBTParams use shared stepSize

```GBTParams``` currently defines its own ```stepSize``` param as the learning rate.
Spark ML already has the shared param trait ```HasStepSize```, so ```GBTParams``` can extend it instead of duplicating the implementation.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8552 from yanboliang/spark-10394.
parent aad644fb
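
For context, Spark ML's shared params live in ```org.apache.spark.ml.param.shared``` and are code-generated; each trait bundles a ```Param``` definition with its getter. Roughly, ```HasStepSize``` provides the following (a paraphrased sketch of the generated trait, not the exact source):

```scala
import org.apache.spark.ml.param.{DoubleParam, Params}

// Paraphrased sketch: the real trait is generated by SharedParamsCodeGen
// into org.apache.spark.ml.param.shared.
private[ml] trait HasStepSize extends Params {

  /**
   * Param for step size to be used for each iteration of optimization.
   * @group param
   */
  final val stepSize: DoubleParam =
    new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization")

  /** @group getParam */
  final def getStepSize: Double = $(stepSize)
}
```

Because the trait is shared across algorithms with different valid ranges, it does not bake in GBT's (0, 1] constraint; the diff below therefore moves that range check into an overridden ```validateParams()```.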
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tree
 import org.apache.spark.ml.classification.ClassifierParams
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds}
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
@@ -365,17 +365,7 @@ private[ml] object RandomForestParams {
  *
  * Note: Marked as private and DeveloperApi since this may be made public in the future.
  */
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
-
-  /**
-   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
-   * estimator.
-   * (default = 0.1)
-   * @group param
-   */
-  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
-    " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
-    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
 
   /* TODO: Add this doc when we add this param. SPARK-7132
    * Threshold for stopping early when runWithValidation is used.
@@ -393,11 +383,19 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
   /** @group setParam */
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
-  /** @group setParam */
+  /**
+   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+   * estimator.
+   * (default = 0.1)
+   * @group setParam
+   */
   def setStepSize(value: Double): this.type = set(stepSize, value)
 
-  /** @group getParam */
-  final def getStepSize: Double = $(stepSize)
+  override def validateParams(): Unit = {
+    require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
+      getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
+      s"but it was given invalid value $getStepSize.")
+  }
 
   /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
   private[ml] def getOldBoostingStrategy(
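
One behavioral note, illustrated with a hedged sketch (the ```training``` DataFrame is hypothetical): the shared ```stepSize``` param no longer carries the (0, 1] validator on the param itself, so an out-of-range value is rejected by ```validateParams()``` when fitting rather than at set time.

```scala
import org.apache.spark.ml.regression.GBTRegressor

// Typical configuration: stepSize is the learning rate in (0, 1].
val gbt = new GBTRegressor()
  .setMaxIter(10)
  .setStepSize(0.1)

// Sketch of the validation change (`training` is a hypothetical DataFrame):
// setting stepSize = 1.5 now succeeds, but fit() fails with an
// IllegalArgumentException raised by the require() in validateParams().
// new GBTRegressor().setStepSize(1.5).fit(training)
```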