Commit d679843a authored by Xiangrui Meng, committed by Tathagata Das

[SPARK-1327] GLM needs to check addIntercept for intercept and weights

GLM needs to check addIntercept before splitting the optimizer output into an intercept and weights. The current implementation always treats the first weight as the intercept, even when addIntercept is false. Added a test for training without adding an intercept.

JIRA: https://spark-project.atlassian.net/browse/SPARK-1327
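
The essence of the fix, mirrored by the GeneralizedLinearAlgorithm change below, is to split the optimizer output into an intercept and feature weights only when addIntercept is set. A minimal sketch of that logic, using a hypothetical helper name for illustration:

  // Hypothetical helper (not part of the patch): split the optimizer output into
  // (intercept, weights), treating the first entry as the intercept only when an
  // all-1.0 column was prepended to the features.
  def splitWeights(addIntercept: Boolean, weightsWithIntercept: Array[Double]): (Double, Array[Double]) = {
    if (addIntercept) {
      // The first entry corresponds to the prepended 1.0 column, i.e. the intercept.
      (weightsWithIntercept(0), weightsWithIntercept.tail)
    } else {
      // Nothing was prepended, so every entry is a feature weight and the intercept stays 0.0.
      (0.0, weightsWithIntercept)
    }
  }

  // Example: with addIntercept = false, all entries remain feature weights.
  // splitWeights(addIntercept = false, Array(1.5, -0.3, 2.0)) yields (0.0, Array(1.5, -0.3, 2.0))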

Author: Xiangrui Meng <meng@databricks.com>

Closes #236 from mengxr/glm and squashes the following commits:

bcac1ac [Xiangrui Meng] add two tests to ensure {Lasso, Ridge}.setIntercept will throw an exception
a104072 [Xiangrui Meng] remove protected to be compatible with 0.9
0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM; mark createModel protected; mark predictPoint protected
d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used
parent 1fa48d94
@@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
     // Prepend an extra variable consisting of all 1.0's for the intercept.
     val data = if (addIntercept) {
-      input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
+      input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
     } else {
       input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
     }
     val initialWeightsWithIntercept = if (addIntercept) {
-      initialWeights.+:(1.0)
+      0.0 +: initialWeights
     } else {
       initialWeights
     }
-    val weights = optimizer.optimize(data, initialWeightsWithIntercept)
-    val intercept = weights(0)
-    val weightsScaled = weights.tail
-    val model = createModel(weightsScaled, intercept)
-    logInfo("Final model weights " + model.weights.mkString(","))
-    logInfo("Final model intercept " + model.intercept)
-    model
+    val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
+    val (intercept, weights) = if (addIntercept) {
+      (weightsWithIntercept(0), weightsWithIntercept.tail)
+    } else {
+      (0.0, weightsWithIntercept)
+    }
+    logInfo("Final weights " + weights.mkString(","))
+    logInfo("Final intercept " + intercept)
+    createModel(weights, intercept)
   }
 }
@@ -36,8 +36,10 @@ class LassoModel(
   extends GeneralizedLinearModel(weights, intercept)
   with RegressionModel with Serializable {
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-      intercept: Double) = {
+  override def predictPoint(
+      dataMatrix: DoubleMatrix,
+      weightMatrix: DoubleMatrix,
+      intercept: Double): Double = {
     dataMatrix.dot(weightMatrix) + intercept
   }
 }
@@ -66,7 +68,7 @@ class LassoWithSGD private (
     .setMiniBatchFraction(miniBatchFraction)
   // We don't want to penalize the intercept, so set this to false.
-  setIntercept(false)
+  super.setIntercept(false)
   var yMean = 0.0
   var xColMean: DoubleMatrix = _
@@ -77,10 +79,16 @@ class LassoWithSGD private (
    */
   def this() = this(1.0, 100, 1.0, 1.0)
-  def createModel(weights: Array[Double], intercept: Double) = {
-    val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
+  override def setIntercept(addIntercept: Boolean): this.type = {
+    // TODO: Support adding intercept.
+    if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+    this
+  }
+  override def createModel(weights: Array[Double], intercept: Double) = {
+    val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
     val weightsScaled = weightsMat.div(xColSd)
-    val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
+    val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
     new LassoModel(weightsScaled.data, interceptScaled)
   }
@@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
  * @param intercept Intercept computed for this model.
  */
 class LinearRegressionModel(
-    override val weights: Array[Double],
-    override val intercept: Double)
-  extends GeneralizedLinearModel(weights, intercept)
-  with RegressionModel with Serializable {
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-      intercept: Double) = {
+    override val weights: Array[Double],
+    override val intercept: Double)
+  extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {
+  override def predictPoint(
+      dataMatrix: DoubleMatrix,
+      weightMatrix: DoubleMatrix,
+      intercept: Double): Double = {
     dataMatrix.dot(weightMatrix) + intercept
   }
 }
@@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
     var stepSize: Double,
     var numIterations: Int,
     var miniBatchFraction: Double)
-  extends GeneralizedLinearAlgorithm[LinearRegressionModel]
-  with Serializable {
+  extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {
   val gradient = new LeastSquaresGradient()
   val updater = new SimpleUpdater()
@@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
    */
   def this() = this(1.0, 100, 1.0)
-  def createModel(weights: Array[Double], intercept: Double) = {
+  override def createModel(weights: Array[Double], intercept: Double) = {
     new LinearRegressionModel(weights, intercept)
   }
 }
@@ -36,8 +36,10 @@ class RidgeRegressionModel(
   extends GeneralizedLinearModel(weights, intercept)
   with RegressionModel with Serializable {
-  override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
-      intercept: Double) = {
+  override def predictPoint(
+      dataMatrix: DoubleMatrix,
+      weightMatrix: DoubleMatrix,
+      intercept: Double): Double = {
     dataMatrix.dot(weightMatrix) + intercept
   }
 }
@@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
     .setMiniBatchFraction(miniBatchFraction)
   // We don't want to penalize the intercept in RidgeRegression, so set this to false.
-  setIntercept(false)
+  super.setIntercept(false)
   var yMean = 0.0
   var xColMean: DoubleMatrix = _
@@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
    */
   def this() = this(1.0, 100, 1.0, 1.0)
-  def createModel(weights: Array[Double], intercept: Double) = {
-    val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
+  override def setIntercept(addIntercept: Boolean): this.type = {
+    // TODO: Support adding intercept.
+    if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.")
+    this
+  }
+  override def createModel(weights: Array[Double], intercept: Double) = {
+    val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
     val weightsScaled = weightsMat.div(xColSd)
     val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)
@@ -17,11 +17,8 @@
 package org.apache.spark.mllib.regression
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
-import org.apache.spark.SparkContext
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 class LassoSuite extends FunSuite with LocalSparkContext {
@@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+  test("do not support intercept") {
+    intercept[UnsupportedOperationException] {
+      new LassoWithSGD().setIntercept(true)
+    }
+  }
 }
@@ -17,7 +17,6 @@
 package org.apache.spark.mllib.regression
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
@@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+  // Test if we can correctly learn Y = 10*X1 + 10*X2
+  test("linear regression without intercept") {
+    val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
+      0.0, Array(10.0, 10.0), 100, 42), 2).cache()
+    val linReg = new LinearRegressionWithSGD().setIntercept(false)
+    linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+    val model = linReg.run(testRDD)
+    assert(model.intercept === 0.0)
+    assert(model.weights.length === 2)
+    assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
+    assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+    val validationData = LinearDataGenerator.generateLinearInput(
+      0.0, Array(10.0, 10.0), 100, 17)
+    val validationRDD = sc.parallelize(validationData, 2).cache()
+    // Test prediction on RDD.
+    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+    // Test prediction on Array.
+    validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+  }
 }
@@ -17,14 +17,11 @@
 package org.apache.spark.mllib.regression
 import org.jblas.DoubleMatrix
-import org.scalatest.BeforeAndAfterAll
 import org.scalatest.FunSuite
 import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
 class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
   def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
@@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
     assert(ridgeErr < linearErr,
       "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
   }
+  test("do not support intercept") {
+    intercept[UnsupportedOperationException] {
+      new RidgeRegressionWithSGD().setIntercept(true)
+    }
+  }
 }