Commit 2b1111dd authored by Holden Karau, committed by DB Tsai

[SPARK-7888] Be able to disable intercept in linear regression in ml package

Author: Holden Karau <holden@pigscanfly.ca>

Closes #6927 from holdenk/SPARK-7888-Be-able-to-disable-intercept-in-Linear-Regression-in-ML-package and squashes the following commits:

0ad384c [Holden Karau] Add MiMa excludes
4016fac [Holden Karau] Switch to wild card import, remove extra blank lines
ae5baa8 [Holden Karau] CR feedback, move the fitIntercept down rather than changing ymean and etc above
f34971c [Holden Karau] Fix some more long lines
319bd3f [Holden Karau] Fix long lines
3bb9ee1 [Holden Karau] Update the regression suite tests
7015b9f [Holden Karau] Our code performs the same with R, except we need more than one data point but that seems reasonable
0b0c8c0 [Holden Karau] fix the issue with the sample R code
e2140ba [Holden Karau] Add a test, it fails!
5e84a0b [Holden Karau] Write out thoughts and use the correct trait
91ffc0a [Holden Karau] more murh
006246c [Holden Karau] murp?
parent 6f4cadf5
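For reference, the new parameter is used like any other LinearRegression setter. A minimal usage sketch (editorial, not part of the commit; `training` is an assumed DataFrame with "label" and "features" columns):

    import org.apache.spark.ml.regression.LinearRegression

    // Fit a model whose regression line is forced through the origin.
    val lr = new LinearRegression()
      .setFitIntercept(false)  // new in this patch; the default remains true

    val model = lr.fit(training)
    // With fitIntercept disabled, model.intercept is always 0.0.
    println(s"weights: ${model.weights}, intercept: ${model.intercept}")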
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
@@ -41,7 +41,8 @@ import org.apache.spark.util.StatCounter
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends PredictorParams
-  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+  with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
+  with HasFitIntercept

 /**
  * :: Experimental ::
@@ -72,6 +73,14 @@ class LinearRegression(override val uid: String)
   def setRegParam(value: Double): this.type = set(regParam, value)
   setDefault(regParam -> 0.0)

+  /**
+   * Set whether to fit an intercept term.
+   * Default is true.
+   * @group setParam
+   */
+  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+  setDefault(fitIntercept -> true)
+
   /**
    * Set the ElasticNet mixing parameter.
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
@@ -123,6 +132,7 @@ class LinearRegression(override val uid: String)
     val numFeatures = summarizer.mean.size
     val yMean = statCounter.mean
     val yStd = math.sqrt(statCounter.variance)
+    // Look at glmnet5.m L761; it may have relevant info.

     // If the yStd is zero, then the intercept is yMean with zero weights;
     // as a result, training is not needed.
@@ -142,7 +152,7 @@ class LinearRegression(override val uid: String)
     val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam

-    val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+    val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
       featuresStd, featuresMean, effectiveL2RegParam)

     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
@@ -180,7 +190,7 @@ class LinearRegression(override val uid: String)
     // The intercept in R's GLMNET is computed using closed form after the coefficients are
     // converged. See the following discussion for detail.
     // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
-    val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
+    val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0

     if (handlePersistence) instances.unpersist()

     // TODO: Converts to sparse format based on the storage, but may base on the scoring speed.
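For intuition, the closed-form intercept above is just the label mean minus the dot product of the learned weights with the feature means. A standalone sketch of that arithmetic in plain Scala (the numbers are illustrative, not the real training statistics):

    // Closed-form intercept, as in GLMNET: intercept = yMean - weights . featuresMean.
    // All values below are made up for illustration.
    val weights = Array(4.7, 7.2)
    val featuresMean = Array(0.9, -1.3)
    val yMean = 6.3
    val fitIntercept = true

    val intercept =
      if (fitIntercept) yMean - weights.zip(featuresMean).map { case (w, m) => w * m }.sum
      else 0.0  // force the fit through the origin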
@@ -234,6 +244,7 @@ class LinearRegressionModel private[ml] (
  * See this discussion for detail.
  * http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
  *
+ * When training with intercept enabled,
  * The objective function in the scaled space is given by
  * {{{
  * L = 1/2n ||\sum_i w_i(x_i - \bar{x_i}) / \hat{x_i} - (y - \bar{y}) / \hat{y}||^2,
@@ -241,6 +252,10 @@ class LinearRegressionModel private[ml] (
  * where \bar{x_i} is the mean of x_i, \hat{x_i} is the standard deviation of x_i,
  * \bar{y} is the mean of label, and \hat{y} is the standard deviation of label.
  *
+ * If training with the intercept disabled (that is, the fit is forced through 0.0),
+ * we can use the same equation, except that \bar{y} and \bar{x_i} are set to 0 instead
+ * of the respective means.
+ *
  * This can be rewritten as
  * {{{
  * L = 1/2n ||\sum_i (w_i/\hat{x_i})x_i - \sum_i (w_i/\hat{x_i})\bar{x_i} - y / \hat{y}
@@ -255,6 +270,7 @@ class LinearRegressionModel private[ml] (
  *     \sum_i w_i^\prime x_i - y / \hat{y} + offset
  * }}}
  *
+ *
  * Note that the effective weights and offset don't depend on training dataset,
  * so they can be precomputed.
  *
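For reference, the two cases of the scaled objective can be stated side by side (an editorial restatement of the doc comment above, in LaTeX; it adds nothing beyond the commit's own derivation):

    % With intercept: features and label are centered by their means.
    L = \frac{1}{2n} \Big\| \sum_i \frac{w_i (x_i - \bar{x}_i)}{\hat{x}_i}
        - \frac{y - \bar{y}}{\hat{y}} \Big\|^2

    % Without intercept: \bar{x}_i and \bar{y} are replaced by 0, so the mean terms vanish.
    L = \frac{1}{2n} \Big\| \sum_i \frac{w_i x_i}{\hat{x}_i} - \frac{y}{\hat{y}} \Big\|^2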
@@ -301,6 +317,7 @@ private class LeastSquaresAggregator(
     weights: Vector,
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double]) extends Serializable {
@@ -321,7 +338,7 @@ private class LeastSquaresAggregator(
       }
       i += 1
     }
-    (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
   }

   private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
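For intuition, here is a self-contained sketch of the precomputation this hunk touches: each weight is divided by its feature's standard deviation, and the offset absorbs the mean terms only when an intercept is fit (names are illustrative and do not mirror the private class exactly):

    // Hedged sketch of the effective weights and offset used for scoring in the
    // standardized space; zero-variance features get zero weight.
    def effectiveWeightsAndOffset(
        weights: Array[Double],
        labelMean: Double,
        labelStd: Double,
        featuresMean: Array[Double],
        featuresStd: Array[Double],
        fitIntercept: Boolean): (Array[Double], Double) = {
      val effective = weights.indices.map { i =>
        if (featuresStd(i) != 0.0) weights(i) / featuresStd(i) else 0.0
      }.toArray
      // With an intercept the offset is labelMean/labelStd minus the weighted feature
      // means; without one the means are treated as 0, so the offset is 0.
      val offset =
        if (fitIntercept) {
          labelMean / labelStd - effective.zip(featuresMean).map { case (w, m) => w * m }.sum
        } else 0.0
      (effective, offset)
    }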
@@ -404,6 +421,7 @@ private class LeastSquaresCostFun(
     data: RDD[(Double, Vector)],
     labelStd: Double,
     labelMean: Double,
+    fitIntercept: Boolean,
     featuresStd: Array[Double],
     featuresMean: Array[Double],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
@@ -412,7 +430,7 @@ private class LeastSquaresCostFun(
     val w = Vectors.fromBreeze(weights)

     val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
-      labelMean, featuresStd, featuresMean))(
+      labelMean, fitIntercept, featuresStd, featuresMean))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, (label, features)) => aggregator.add(label, features)
       },
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.{DataFrame, Row}
 class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {

   @transient var dataset: DataFrame = _
+  @transient var datasetWithoutIntercept: DataFrame = _

   /**
    * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
@@ -34,14 +35,24 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
    *
    *   import org.apache.spark.mllib.util.LinearDataGenerator
    *   val data =
-   *     sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
-   *   data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
+   *     sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2),
+   *       Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)
+   *   data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1)
+   *     .saveAsTextFile("path")
    */
   override def beforeAll(): Unit = {
     super.beforeAll()
     dataset = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
+    /**
+     * datasetWithoutIntercept is not needed for correctness testing, but is useful for
+     * illustrating training a model without an intercept.
+     */
+    datasetWithoutIntercept = sqlContext.createDataFrame(
+      sc.parallelize(LinearDataGenerator.generateLinearInput(
+        0.0, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
   }

   test("linear regression with intercept without regularization") {
@@ -78,6 +89,42 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }

test("linear regression without intercept without regularization") {
val trainer = (new LinearRegression).setFitIntercept(false)
val model = trainer.fit(dataset)
val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept)
/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0,
* intercept = FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data.V2. 6.995908
* as.numeric.data.V3. 5.275131
*/
val weightsR = Array(6.995908, 5.275131)
assert(model.intercept ~== 0 relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
/**
* Then again with the data with no intercept:
* > weightsWithoutIntercept
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data3.V2. 4.70011
* as.numeric.data3.V3. 7.19943
*/
val weightsWithoutInterceptR = Array(4.70011, 7.19943)
assert(modelWithoutIntercept.intercept ~== 0 relTol 1E-3)
assert(modelWithoutIntercept.weights(0) ~== weightsWithoutInterceptR(0) relTol 1E-3)
assert(modelWithoutIntercept.weights(1) ~== weightsWithoutInterceptR(1) relTol 1E-3)
}
test("linear regression with intercept with L1 regularization") { test("linear regression with intercept with L1 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
val model = trainer.fit(dataset) val model = trainer.fit(dataset)
@@ -87,11 +134,11 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
      * > weights
      *  3 x 1 sparse Matrix of class "dgCMatrix"
      *                             s0
-     * (Intercept)          6.311546
-     * as.numeric.data.V2.  2.123522
-     * as.numeric.data.V3.  4.605651
+     * (Intercept)          6.24300
+     * as.numeric.data.V2.  4.024821
+     * as.numeric.data.V3.  6.679841
      */
-    val interceptR = 6.243000
+    val interceptR = 6.24300
     val weightsR = Array(4.024821, 6.679841)

     assert(model.intercept ~== interceptR relTol 1E-3)
@@ -106,6 +153,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }

test("linear regression without intercept with L1 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
.setFitIntercept(false)
val model = trainer.fit(dataset)
/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57,
* intercept=FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data.V2. 6.299752
* as.numeric.data.V3. 4.772913
*/
val interceptR = 0.0
val weightsR = Array(6.299752, 4.772913)
assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression with intercept with L2 regularization") { test("linear regression with intercept with L2 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
val model = trainer.fit(dataset) val model = trainer.fit(dataset)
...@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -134,6 +211,36 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
} }
test("linear regression without intercept with L2 regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
.setFitIntercept(false)
val model = trainer.fit(dataset)
/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3,
* intercept = FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.data.V2. 5.522875
* as.numeric.data.V3. 4.214502
*/
val interceptR = 0.0
val weightsR = Array(5.522875, 4.214502)
assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
test("linear regression with intercept with ElasticNet regularization") { test("linear regression with intercept with ElasticNet regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
val model = trainer.fit(dataset) val model = trainer.fit(dataset)
...@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -161,4 +268,34 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(prediction1 ~== prediction2 relTol 1E-5) assert(prediction1 ~== prediction2 relTol 1E-5)
} }
} }
test("linear regression without intercept with ElasticNet regularization") {
val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
.setFitIntercept(false)
val model = trainer.fit(dataset)
/**
* weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6,
* intercept=FALSE))
* > weights
* 3 x 1 sparse Matrix of class "dgCMatrix"
* s0
* (Intercept) .
* as.numeric.dataM.V2. 5.673348
* as.numeric.dataM.V3. 4.322251
*/
val interceptR = 0.0
val weightsR = Array(5.673348, 4.322251)
assert(model.intercept ~== interceptR relTol 1E-3)
assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
model.transform(dataset).select("features", "prediction").collect().foreach {
case Row(features: DenseVector, prediction1: Double) =>
val prediction2 =
features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
assert(prediction1 ~== prediction2 relTol 1E-5)
}
}
} }
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -53,6 +53,11 @@ object MimaExcludes {
             // Removing a testing method from a private class
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.streaming.kafka.KafkaTestUtils.waitUntilLeaderOffset"),
+            // These classes are private, but MiMa still flags their changed constructors.
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.ml.regression.LeastSquaresAggregator.this"),
+            ProblemFilters.exclude[MissingMethodProblem](
+              "org.apache.spark.ml.regression.LeastSquaresCostFun.this"),
             // SQL execution is considered private.
             excludePackage("org.apache.spark.sql.execution"),
             // NanoTime and CatalystTimestampConverter is only used inside catalyst,