Skip to content
Snippets Groups Projects
Commit 3b931281 authored by Christoph Sawade's avatar Christoph Sawade Committed by Xiangrui Meng
Browse files

[SPARK-3396][MLLIB] Use SquaredL2Updater in LogisticRegressionWithSGD

SimpleUpdater ignores the regularizer, which leads to an unregularized
LogReg. To enable the common L2 regularizer (and the corresponding
regularization parameter) for logistic regression the SquaredL2Updater
has to be used in SGD (see, e.g., [SVMWithSGD])

Author: Christoph Sawade <christoph@sawade.me>

Closes #2398 from BigCrunsh/fix-regparam-logreg and squashes the following commits:

0820c04 [Christoph Sawade] Use SquaredL2Updater in LogisticRegressionWithSGD
parent 37d92528
No related branches found
No related tags found
No related merge requests found
......@@ -84,7 +84,7 @@ class LogisticRegressionWithSGD private (
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
private val gradient = new LogisticGradient()
private val updater = new SimpleUpdater()
private val updater = new SquaredL2Updater()
override val optimizer = new GradientDescent(gradient, updater)
.setStepSize(stepSize)
.setNumIterations(numIterations)
......
......@@ -43,7 +43,7 @@ object LogisticRegressionSuite {
offset: Double,
scale: Double,
nPoints: Int,
seed: Int): Seq[LabeledPoint] = {
seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
......@@ -58,12 +58,15 @@ object LogisticRegressionSuite {
}
class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
def validatePrediction(
predictions: Seq[Double],
input: Seq[LabeledPoint],
expectedAcc: Double = 0.83) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
}
// At least 83% of the predictions should be on.
((input.length - numOffPredictions).toDouble / input.length) should be > 0.83
((input.length - numOffPredictions).toDouble / input.length) should be > expectedAcc
}
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
......@@ -155,6 +158,41 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
test("logistic regression with initial weights and non-default regularization parameter") {
val nPoints = 10000
val A = 2.0
val B = -1.5
val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0
val initialWeights = Vectors.dense(initialB)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
// Use half as many iterations as the previous test.
val lr = new LogisticRegressionWithSGD().setIntercept(true)
lr.optimizer.
setStepSize(10.0).
setNumIterations(10).
setRegParam(1.0)
val model = lr.run(testRDD, initialWeights)
// Test the weights
assert(model.weights(0) ~== -430000.0 relTol 20000.0)
assert(model.intercept ~== 370000.0 relTol 20000.0)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.8)
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData, 0.8)
}
test("logistic regression with initial weights with LBFGS") {
val nPoints = 10000
val A = 2.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment