Skip to content
Snippets Groups Projects
Commit 8daf10e3 authored by Yanbo Liang
Browse files

[SPARK-19155][ML] MLlib GeneralizedLinearRegression family and link should be case insensitive

## What changes were proposed in this pull request?
MLlib ```GeneralizedLinearRegression``` ```family``` and ```link``` should be case insensitive. This is consistent with some other MLlib params such as [```featureSubsetStrategy```](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala#L415).

## How was this patch tested?
Update corresponding tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #16516 from yanboliang/spark-19133.

(cherry picked from commit 3dcad9fa)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
parent 6f0ad575
No related branches found
No related tags found
No related merge requests found
...@@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam ...@@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
final val family: Param[String] = new Param(this, "family", final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the error distribution to be used in the " + "The name of family which is a description of the error distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
ParamValidators.inArray[String](supportedFamilyNames.toArray)) (value: String) => supportedFamilyNames.contains(value.toLowerCase))
/** @group getParam */ /** @group getParam */
@Since("2.0.0") @Since("2.0.0")
...@@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam ...@@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
final val link: Param[String] = new Param(this, "link", "The name of link function " + final val link: Param[String] = new Param(this, "link", "The name of link function " +
"which provides the relationship between the linear predictor and the mean of the " + "which provides the relationship between the linear predictor and the mean of the " +
s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}",
ParamValidators.inArray[String](supportedLinkNames.toArray)) (value: String) => supportedLinkNames.contains(value.toLowerCase))
/** @group getParam */ /** @group getParam */
@Since("2.0.0") @Since("2.0.0")
...@@ -405,7 +405,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine ...@@ -405,7 +405,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* @param name family name: "gaussian", "binomial", "poisson" or "gamma". * @param name family name: "gaussian", "binomial", "poisson" or "gamma".
*/ */
def fromName(name: String): Family = { def fromName(name: String): Family = {
name match { name.toLowerCase match {
case Gaussian.name => Gaussian case Gaussian.name => Gaussian
case Binomial.name => Binomial case Binomial.name => Binomial
case Poisson.name => Poisson case Poisson.name => Poisson
...@@ -609,7 +609,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine ...@@ -609,7 +609,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* "inverse", "probit", "cloglog" or "sqrt". * "inverse", "probit", "cloglog" or "sqrt".
*/ */
def fromName(name: String): Link = { def fromName(name: String): Link = {
name match { name.toLowerCase match {
case Identity.name => Identity case Identity.name => Identity
case Logit.name => Logit case Logit.name => Logit
case Log.name => Log case Log.name => Log
......
...@@ -553,7 +553,7 @@ class GeneralizedLinearRegressionSuite ...@@ -553,7 +553,7 @@ class GeneralizedLinearRegressionSuite
for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), for ((link, dataset) <- Seq(("inverse", datasetGammaInverse),
("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) {
for (fitIntercept <- Seq(false, true)) { for (fitIntercept <- Seq(false, true)) {
val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) val trainer = new GeneralizedLinearRegression().setFamily("Gamma").setLink(link)
.setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
val model = trainer.fit(dataset) val model = trainer.fit(dataset)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
...@@ -989,7 +989,7 @@ class GeneralizedLinearRegressionSuite ...@@ -989,7 +989,7 @@ class GeneralizedLinearRegressionSuite
-0.6344390 0.3172195 0.2114797 -0.1586097 -0.6344390 0.3172195 0.2114797 -0.1586097
*/ */
val trainer = new GeneralizedLinearRegression() val trainer = new GeneralizedLinearRegression()
.setFamily("gamma") .setFamily("Gamma")
.setWeightCol("weight") .setWeightCol("weight")
val model = trainer.fit(datasetWithWeight) val model = trainer.fit(datasetWithWeight)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment