diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index f137c8cb418944283705d577b4571a425a8a5335..1e7ba91e01989cc340810f61a3ccb5c0aba6834c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -57,7 +57,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val family: Param[String] = new Param(this, "family", "The name of family which is a description of the error distribution to be used in the " + s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.", - ParamValidators.inArray[String](supportedFamilyNames.toArray)) + (value: String) => supportedFamilyNames.contains(value.toLowerCase)) /** @group getParam */ @Since("2.0.0") @@ -74,7 +74,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val link: Param[String] = new Param(this, "link", "The name of link function " + "which provides the relationship between the linear predictor and the mean of the " + s"distribution function. Supported options: ${supportedLinkNames.mkString(", ")}", - ParamValidators.inArray[String](supportedLinkNames.toArray)) + (value: String) => supportedLinkNames.contains(value.toLowerCase)) /** @group getParam */ @Since("2.0.0") @@ -405,7 +405,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * @param name family name: "gaussian", "binomial", "poisson" or "gamma". */ def fromName(name: String): Family = { - name match { + name.toLowerCase match { case Gaussian.name => Gaussian case Binomial.name => Binomial case Poisson.name => Poisson @@ -609,7 +609,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * "inverse", "probit", "cloglog" or "sqrt". */ def fromName(name: String): Link = { - name match { + name.toLowerCase match { case Identity.name => Identity case Logit.name => Logit case Log.name => Log diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 3e9e1fced8ec4d1428c46da03ee8f39bb205a0ae..415d426af3c12462c7dcafd24c7ff2685b4f2c5b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -553,7 +553,7 @@ class GeneralizedLinearRegressionSuite for ((link, dataset) <- Seq(("inverse", datasetGammaInverse), ("identity", datasetGammaIdentity), ("log", datasetGammaLog))) { for (fitIntercept <- Seq(false, true)) { - val trainer = new GeneralizedLinearRegression().setFamily("gamma").setLink(link) + val trainer = new GeneralizedLinearRegression().setFamily("Gamma").setLink(link) .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction") val model = trainer.fit(dataset) val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) @@ -989,7 +989,7 @@ class GeneralizedLinearRegressionSuite -0.6344390 0.3172195 0.2114797 -0.1586097 */ val trainer = new GeneralizedLinearRegression() - .setFamily("gamma") + .setFamily("Gamma") .setWeightCol("weight") val model = trainer.fit(datasetWithWeight)