Skip to content
Snippets Groups Projects
Commit 9970aa09 authored by Zheng RuiFeng's avatar Zheng RuiFeng Committed by Yanbo Liang
Browse files

[SPARK-20669][ML] LoR.family and LDA.optimizer should be case insensitive

## What changes were proposed in this pull request?
Make the `family` param in LogisticRegression (LoR) and the `optimizer` param in LDA case-insensitive.

## How was this patch tested?
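Updated the test suites: added a "string params should be case-insensitive" test to both `LogisticRegressionSuite` and `LDASuite`.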

yanboliang

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #17910 from zhengruifeng/lr_family_lowercase.
parent b0888d1a
No related branches found
No related tags found
No related merge requests found
......@@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the label distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
ParamValidators.inArray[String](supportedFamilyNames))
(value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))
/** @group getParam */
@Since("2.1.0")
......@@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") (
case None => histogram.length
}
val isMultinomial = $(family) match {
val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
s"outcome classes but found $numClasses.")
......
......@@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
@Since("1.6.0")
final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" +
" algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "),
(o: String) =>
ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT)))
(value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT)))
/** @group getParam */
@Since("1.6.0")
......@@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" +
s" length either 1 (scalar) or k (num topics).")
}
getOptimizer match {
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
require(getDocConcentration.forall(_ >= 0),
"For Online LDA optimizer, docConcentration values must be >= 0. Found values: " +
......@@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
}
}
if (isSet(topicConcentration)) {
getOptimizer match {
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" +
s" must be >= 0. Found value: $getTopicConcentration")
......@@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}
/**
 * Creates the OldLDAOptimizer selected by the `optimizer` param.
 * The value returned by getOptimizer is matched as-is (case-sensitively) against
 * "online" and "em"; there is no default case for other values.
 */
private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
case "online" =>
// Online optimizer: configured from the learningOffset, learningDecay,
// subsamplingRate and optimizeDocConcentration params.
new OldOnlineLDAOptimizer()
.setTau0($(learningOffset))
.setKappa($(learningDecay))
.setMiniBatchFraction($(subsamplingRate))
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
// EM optimizer: only the keepLastCheckpoint param is forwarded.
new OldEMLDAOptimizer()
.setKeepLastCheckpoint($(keepLastCheckpoint))
}
/**
 * Creates the OldLDAOptimizer selected by the `optimizer` param.
 * The param value is lower-cased with Locale.ROOT before matching, so any casing of
 * "online" or "em" selects the corresponding optimizer. There is no default case for
 * other values.
 */
private[clustering] def getOldOptimizer: OldLDAOptimizer =
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
// Online optimizer: configured from the learningOffset, learningDecay,
// subsamplingRate and optimizeDocConcentration params.
new OldOnlineLDAOptimizer()
.setTau0($(learningOffset))
.setKappa($(learningDecay))
.setMiniBatchFraction($(subsamplingRate))
.setOptimizeDocConcentration($(optimizeDocConcentration))
case "em" =>
// EM optimizer: only the keepLastCheckpoint param is forwarded.
new OldEMLDAOptimizer()
.setKeepLastCheckpoint($(keepLastCheckpoint))
}
}
private object LDAParams {
......
......@@ -2582,6 +2582,17 @@ class LogisticRegressionSuite
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
}
// Checks that `family` accepts mixed-case names. For each (family, dataset) pair the
// estimator's getter must return the string exactly as it was set, and the model
// produced by fit must report that same string.
test("string params should be case-insensitive") {
// A single estimator instance is reused; setFamily overwrites the value each iteration.
val lr = new LogisticRegression()
Seq(("AuTo", smallBinaryDataset), ("biNoMial", smallBinaryDataset),
("mulTinomIAl", smallMultinomialDataset)).foreach { case (family, data) =>
lr.setFamily(family)
// The original casing is preserved by the getter (the value is not normalized on set).
assert(lr.getFamily === family)
val model = lr.fit(data)
assert(model.getFamily === family)
}
}
}
object LogisticRegressionSuite {
......
......@@ -313,4 +313,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
assert(model.getCheckpointFiles.isEmpty)
}
// Checks that `optimizer` accepts mixed-case names. For each name the estimator's
// getter must return the string exactly as it was set, and the model produced by fit
// must report that same string.
test("string params should be case-insensitive") {
// A single estimator instance is reused; setOptimizer overwrites the value each iteration.
val lda = new LDA()
Seq("eM", "oNLinE").foreach { optimizer =>
lda.setOptimizer(optimizer)
// The original casing is preserved by the getter (the value is not normalized on set).
assert(lda.getOptimizer === optimizer)
val model = lda.fit(dataset)
assert(model.getOptimizer === optimizer)
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment