Skip to content
Snippets Groups Projects
Commit d138aa8e authored by Omede Firouz's avatar Omede Firouz Committed by Joseph K. Bradley
Browse files

[SPARK-6705][MLLIB] Add fit intercept api to ml logisticregression

I have the fit intercept enabled by default for logistic regression, I
wonder what others think here. I understand that it enables allocation
by default which is undesirable, but one needs to have a very strong
reason for not having an intercept term enabled so it is the safer
default from a statistical sense.

Explicitly modeling the intercept by adding a column of all 1s does not
work. I believe the reason is that since the API for
LogisticRegressionWithLBFGS forces column normalization, and a column of all
1s has 0 variance so dividing by 0 kills it.

Author: Omede Firouz <ofirouz@palantir.com>

Closes #5301 from oefirouz/addIntercept and squashes the following commits:

9f1286b [Omede Firouz] [SPARK-6705][MLLIB] Add fitInterceptTerm to LogisticRegression
1d6bd6f [Omede Firouz] [SPARK-6705][MLLIB] Add a fit intercept term to ML LogisticRegression
9963509 [Omede Firouz] [MLLIB] Add fitIntercept to LogisticRegression
2257fca [Omede Firouz] [MLLIB] Add fitIntercept param to logistic regression
329c1e2 [Omede Firouz] [MLLIB] Add fit intercept term
bd9663c [Omede Firouz] [MLLIB] Add fit intercept api to ml logisticregression
parent c83e0394
No related branches found
No related tags found
No related merge requests found
...@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel ...@@ -31,7 +31,7 @@ import org.apache.spark.storage.StorageLevel
* Params for logistic regression. * Params for logistic regression.
*/ */
private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
with HasRegParam with HasMaxIter with HasThreshold with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
/** /**
...@@ -55,6 +55,9 @@ class LogisticRegression ...@@ -55,6 +55,9 @@ class LogisticRegression
/** @group setParam */ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value) def setMaxIter(value: Int): this.type = set(maxIter, value)
/** @group setParam */
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
/** @group setParam */ /** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value) def setThreshold(value: Double): this.type = set(threshold, value)
...@@ -67,7 +70,8 @@ class LogisticRegression ...@@ -67,7 +70,8 @@ class LogisticRegression
} }
// Train model // Train model
val lr = new LogisticRegressionWithLBFGS val lr = new LogisticRegressionWithLBFGS()
.setIntercept(paramMap(fitIntercept))
lr.optimizer lr.optimizer
.setRegParam(paramMap(regParam)) .setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter)) .setNumIterations(paramMap(maxIter))
......
...@@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params { ...@@ -106,6 +106,18 @@ private[ml] trait HasProbabilityCol extends Params {
def getProbabilityCol: String = get(probabilityCol) def getProbabilityCol: String = get(probabilityCol)
} }
private[ml] trait HasFitIntercept extends Params {
/**
* param for fitting the intercept term, defaults to true
* @group param
*/
val fitIntercept: BooleanParam =
new BooleanParam(this, "fitIntercept", "indicates whether to fit an intercept term", Some(true))
/** @group getParam */
def getFitIntercept: Boolean = get(fitIntercept)
}
private[ml] trait HasThreshold extends Params { private[ml] trait HasThreshold extends Params {
/** /**
* param for threshold in (binary) prediction * param for threshold in (binary) prediction
......
...@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { ...@@ -46,6 +46,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(lr.getPredictionCol == "prediction") assert(lr.getPredictionCol == "prediction")
assert(lr.getRawPredictionCol == "rawPrediction") assert(lr.getRawPredictionCol == "rawPrediction")
assert(lr.getProbabilityCol == "probability") assert(lr.getProbabilityCol == "probability")
assert(lr.getFitIntercept == true)
val model = lr.fit(dataset) val model = lr.fit(dataset)
model.transform(dataset) model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction") .select("label", "probability", "prediction", "rawPrediction")
...@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { ...@@ -55,6 +56,14 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.getPredictionCol == "prediction") assert(model.getPredictionCol == "prediction")
assert(model.getRawPredictionCol == "rawPrediction") assert(model.getRawPredictionCol == "rawPrediction")
assert(model.getProbabilityCol == "probability") assert(model.getProbabilityCol == "probability")
assert(model.intercept !== 0.0)
}
test("logistic regression doesn't fit intercept when fitIntercept is off") {
val lr = new LogisticRegression
lr.setFitIntercept(false)
val model = lr.fit(dataset)
assert(model.intercept === 0.0)
} }
test("logistic regression with setters") { test("logistic regression with setters") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment