Skip to content
Snippets Groups Projects
Commit fcb3e186 authored by Basin's avatar Basin Committed by Xiangrui Meng
Browse files

[SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration...

[SPARK-5317]Set BoostingStrategy.defaultParams With Enumeration Algo.Classification or Algo.Regression

JIRA Issue: https://issues.apache.org/jira/browse/SPARK-5317
When setting the BoostingStrategy.defaultParams("Classification"), It's more straightforward to set it with the Enumeration Algo.Classification, just like BoostingStragety.defaultParams(Algo.Classification).
I overload the method BoostingStragety.defaultParams().

Author: Basin <jpsachilles@gmail.com>

Closes #4103 from Peishen-Jia/stragetyAlgo and squashes the following commits:

87bab1c [Basin] Docs and Code documentations updated.
3b72875 [Basin] defaultParams(algoStr: String) call defaultParams(algo: Algo).
7c1e6ee [Basin] Doc of Java updated. algo -> algoStr instead.
d5c8a2e [Basin] Merge branch 'stragetyAlgo' of github.com:Peishen-Jia/spark into stragetyAlgo
65f96ce [Basin] mllib-ensembles doc modified.
e04a5aa [Basin] boostingstrategy.defaultParam string algo to enumeration.
68cf544 [Basin] mllib-ensembles doc modified.
a4aea51 [Basin] boostingstrategy.defaultParam string algo to enumeration.
parent ca7910d6
No related branches found
No related tags found
No related merge requests found
......@@ -68,6 +68,15 @@ case class BoostingStrategy(
@Experimental
object BoostingStrategy {
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported: "Classification" or "Regression"
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
defaultParams(Algo.fromString(algo))
}
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
......@@ -75,15 +84,15 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: String): BoostingStrategy = {
val treeStrategy = Strategy.defaultStrategy(algo)
treeStrategy.maxDepth = 3
def defaultParams(algo: Algo): BoostingStrategy = {
val treeStragtegy = Strategy.defaultStategy(algo)
treeStragtegy.maxDepth = 3
algo match {
case "Classification" =>
treeStrategy.numClasses = 2
new BoostingStrategy(treeStrategy, LogLoss)
case "Regression" =>
new BoostingStrategy(treeStrategy, SquaredError)
case Algo.Classification =>
treeStragtegy.numClasses = 2
new BoostingStrategy(treeStragtegy, LogLoss)
case Algo.Regression =>
new BoostingStrategy(treeStragtegy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
......
......@@ -173,11 +173,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
def defaultStrategy(algo: String): Strategy = algo match {
case "Classification" =>
def defaultStrategy(algo: String): Strategy = {
defaultStategy(Algo.fromString(algo))
}
/**
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo Algo.Classification or Algo.Regression
*/
def defaultStategy(algo: Algo): Strategy = algo match {
case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
case "Regression" =>
case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment