Commit b72611f2 authored by Holden Karau, committed by DB Tsai

[SPARK-7780][MLLIB] Intercept in LogisticRegressionWithLBFGS should not be regularized

The intercept in Logistic Regression represents a prior on the categories and should not be regularized. In MLlib, regularization is handled through the Updater, and the Updater penalizes all components without excluding the intercept, which results in poor training accuracy when regularization is used.
The new implementation in the ML framework handles this properly, so we should call the ML implementation from MLlib, since the majority of users are still using the MLlib API.
Note that both implementations perform feature scaling to improve convergence; the only difference is that the ML version does not regularize the intercept. As a result, when lambda is zero, they converge to the same solution.
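For illustration, a minimal usage sketch of the affected MLlib API (the data RDD, regParam value, and iteration count are made up): with this change, training with the default SquaredL2Updater delegates to ml.LogisticRegression, so the L2 penalty is applied to the coefficients only, not the intercept.

    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    def trainWithL2(data: RDD[LabeledPoint]): Unit = {
      // L2-regularized binary logistic regression through the MLlib API
      val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
      lr.optimizer
        .setRegParam(0.1)       // penalty applied to the coefficients only, not the intercept
        .setNumIterations(100)
      val model = lr.run(data)  // delegates to ml.LogisticRegression for the default updater
      println(s"weights=${model.weights} intercept=${model.intercept}")
    }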

Previously partially reviewed at https://github.com/apache/spark/pull/6386#issuecomment-168781424; re-opening for dbtsai to review.

Author: Holden Karau <holden@us.ibm.com>
Author: Holden Karau <holden@pigscanfly.ca>

Closes #10788 from holdenk/SPARK-7780-intercept-in-logisticregressionwithLBFGS-should-not-be-regularized.
parent 55512738
@@ -247,15 +247,27 @@ class LogisticRegression @Since("1.2.0") (
   @Since("1.5.0")
   override def getThresholds: Array[Double] = super.getThresholds

-  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
-    // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
+  private var optInitialModel: Option[LogisticRegressionModel] = None
+
+  /** @group setParam */
+  private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = {
+    this.optInitialModel = Some(model)
+    this
+  }
+
+  override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    train(dataset, handlePersistence)
+  }
+
+  protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
+      LogisticRegressionModel = {
     val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
       case Row(label: Double, weight: Double, features: Vector) =>
         Instance(label, weight, features)
     }

-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)

     val (summarizer, labelSummarizer) = {
@@ -343,7 +355,21 @@ class LogisticRegression @Since("1.2.0") (
     val initialCoefficientsWithIntercept =
       Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)

-    if ($(fitIntercept)) {
+    if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
+      val vec = optInitialModel.get.coefficients
+      logWarning(
+        s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+    }
+
+    if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
+      val initialCoefficientsWithInterceptArray = initialCoefficientsWithIntercept.toArray
+      optInitialModel.get.coefficients.foreachActive { case (index, value) =>
+        initialCoefficientsWithInterceptArray(index) = value
+      }
+      if ($(fitIntercept)) {
+        initialCoefficientsWithInterceptArray(numFeatures) = optInitialModel.get.intercept
+      }
+    } else if ($(fitIntercept)) {
       /*
          For binary logistic regression, when we initialize the coefficients as zeros,
          it will converge faster if we initialize the intercept such that
@@ -434,7 +460,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
  */
 @Since("1.4.0")
 @Experimental
-class LogisticRegressionModel private[ml] (
+class LogisticRegressionModel private[spark] (
     @Since("1.4.0") override val uid: String,
     @Since("1.6.0") val coefficients: Vector,
     @Since("1.3.0") val intercept: Double)
@@ -19,15 +19,18 @@ package org.apache.spark.mllib.classification

 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.classification.impl.GLMClassificationModel
-import org.apache.spark.mllib.linalg.{DenseVector, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.pmml.PMMLExportable
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
+import org.apache.spark.mllib.util.MLUtils.appendBias
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.storage.StorageLevel

 /**
  * Classification model trained using Multinomial/Binary Logistic Regression.
@@ -332,6 +335,13 @@ object LogisticRegressionWithSGD {
  * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default.
  * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
  * for k classes multi-label classification problem.
+ *
+ * Earlier implementations of LogisticRegressionWithLBFGS applied a regularization
+ * penalty to all elements, including the intercept. If this is called with one of the
+ * standard updaters (L1Updater or SquaredL2Updater), the call is translated into a
+ * call to ml.LogisticRegression; otherwise the existing mllib GeneralizedLinearAlgorithm
+ * trainer is used, resulting in a regularization penalty on the intercept.
  */
 @Since("1.1.0")
 class LogisticRegressionWithLBFGS
@@ -374,4 +384,72 @@ class LogisticRegressionWithLBFGS
       new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
     }
   }
+
+  /**
+   * Run Logistic Regression with the configured parameters on an input RDD
+   * of LabeledPoint entries.
+   *
+   * If a known updater is used, this calls the ml implementation to avoid
+   * applying a regularization penalty to the intercept; otherwise it
+   * defaults to the mllib implementation. If there are more than two classes
+   * or feature scaling is disabled, the mllib implementation is always used.
+   * When the ml implementation is used, the ml code generates the initial weights.
+   */
+  override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = {
+    run(input, generateInitialWeights(input), userSuppliedWeights = false)
+  }
+
+  /**
+   * Run Logistic Regression with the configured parameters on an input RDD
+   * of LabeledPoint entries, starting from the initial weights provided.
+   *
+   * If a known updater is used, this calls the ml implementation to avoid
+   * applying a regularization penalty to the intercept; otherwise it
+   * defaults to the mllib implementation. If there are more than two classes
+   * or feature scaling is disabled, the mllib implementation is always used.
+   * Uses the user-provided weights.
+   */
+  override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = {
+    run(input, initialWeights, userSuppliedWeights = true)
+  }
+
+  private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean):
+      LogisticRegressionModel = {
+    // ml's Logistic regression only supports binary classification currently.
+    if (numOfLinearPredictor == 1) {
+      def runWithMlLogisticRegression(elasticNetParam: Double) = {
+        // Prepare the ml LogisticRegression based on our settings
+        val lr = new org.apache.spark.ml.classification.LogisticRegression()
+        lr.setRegParam(optimizer.getRegParam())
+        lr.setElasticNetParam(elasticNetParam)
+        lr.setStandardization(useFeatureScaling)
+        if (userSuppliedWeights) {
+          val uid = Identifiable.randomUID("logreg-static")
+          lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel(
+            uid, initialWeights, 1.0))
+        }
+        lr.setFitIntercept(addIntercept)
+        lr.setMaxIter(optimizer.getNumIterations())
+        lr.setTol(optimizer.getConvergenceTol())
+        // Convert our input into a DataFrame
+        val sqlContext = new SQLContext(input.context)
+        import sqlContext.implicits._
+        val df = input.toDF()
+        // Determine if we should cache the DataFrame
+        val handlePersistence = input.getStorageLevel == StorageLevel.NONE
+        // Train our model
+        val mlLogisticRegressionModel = lr.train(df, handlePersistence)
+        // Convert the ml model back into an mllib model
+        val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
+        createModel(weights, mlLogisticRegressionModel.intercept)
+      }
+      optimizer.getUpdater() match {
+        case x: SquaredL2Updater => runWithMlLogisticRegression(0.0)
+        case x: L1Updater => runWithMlLogisticRegression(1.0)
+        case _ => super.run(input, initialWeights)
+      }
+    } else {
+      super.run(input, initialWeights)
+    }
+  }
 }
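The class doc above notes that only the standard updaters are translated to the ml path. A standalone sketch of that mapping (the helper name elasticNetParamFor is illustrative; the values follow ml's convention that elasticNetParam = 0.0 is a pure L2 penalty and 1.0 is a pure L1 penalty):

    import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater, Updater}

    def elasticNetParamFor(updater: Updater): Option[Double] = updater match {
      case _: SquaredL2Updater => Some(0.0) // ridge penalty -> delegate to ml with alpha = 0
      case _: L1Updater        => Some(1.0) // lasso penalty -> delegate to ml with alpha = 1
      case _                   => None      // custom updater -> fall back to the mllib path
    }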
@@ -69,6 +69,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }

+  /**
+   * Get the convergence tolerance of iterations.
+   */
+  private[mllib] def getConvergenceTol(): Double = {
+    this.convergenceTol
+  }
+
   /**
    * Set the maximal number of iterations for L-BFGS. Default 100.
    * @deprecated use [[LBFGS#setNumIterations]] instead
@@ -86,6 +93,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }

+  /**
+   * Get the maximum number of iterations for L-BFGS. Defaults to 100.
+   */
+  private[mllib] def getNumIterations(): Int = {
+    this.maxNumIterations
+  }
+
   /**
    * Set the regularization parameter. Default 0.0.
    */
@@ -94,6 +108,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }

+  /**
+   * Get the regularization parameter.
+   */
+  private[mllib] def getRegParam(): Double = {
+    this.regParam
+  }
+
   /**
    * Set the gradient function (of the loss function of one single data example)
    * to be used for L-BFGS.
@@ -113,6 +134,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
     this
   }

+  /**
+   * Returns the updater, limited to internal use.
+   */
+  private[mllib] def getUpdater(): Updater = {
+    updater
+  }
+
   override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
     val (weights, _) = LBFGS.runLBFGS(
       data,
@@ -140,7 +140,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    * translated back to resulting model weights, so it's transparent to users.
    * Note: This technique is used in both libsvm and glmnet packages. Default false.
    */
-  private var useFeatureScaling = false
+  private[mllib] var useFeatureScaling = false

   /**
    * The dimension of training features.
@@ -196,12 +196,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
   }

   /**
-   * Run the algorithm with the configured parameters on an input
-   * RDD of LabeledPoint entries.
-   *
+   * Generate the initial weights when the user does not supply them
    */
-  @Since("0.8.0")
-  def run(input: RDD[LabeledPoint]): M = {
+  protected def generateInitialWeights(input: RDD[LabeledPoint]): Vector = {
     if (numFeatures < 0) {
       numFeatures = input.map(_.features.size).first()
     }
@@ -217,16 +214,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
      * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always
      * have the intercept as part of weights to have consistent design.
      */
-    val initialWeights = {
-      if (numOfLinearPredictor == 1) {
-        Vectors.zeros(numFeatures)
-      } else if (addIntercept) {
-        Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
-      } else {
-        Vectors.zeros(numFeatures * numOfLinearPredictor)
-      }
+    if (numOfLinearPredictor == 1) {
+      Vectors.zeros(numFeatures)
+    } else if (addIntercept) {
+      Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
+    } else {
+      Vectors.zeros(numFeatures * numOfLinearPredictor)
     }
-
-    run(input, initialWeights)
   }

+  /**
+   * Run the algorithm with the configured parameters on an input
+   * RDD of LabeledPoint entries.
+   */
+  @Since("0.8.0")
+  def run(input: RDD[LabeledPoint]): M = {
+    run(input, generateInitialWeights(input))
+  }

   /**
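The refactored generateInitialWeights above sizes the initial vector from numFeatures, numOfLinearPredictor, and addIntercept. A small standalone sketch of that sizing rule (the helper name and the example values are illustrative, not part of the patch):

    def initialWeightSize(numFeatures: Int, numOfLinearPredictor: Int, addIntercept: Boolean): Int =
      if (numOfLinearPredictor == 1) {
        numFeatures                              // binary case: the intercept is tracked separately
      } else if (addIntercept) {
        (numFeatures + 1) * numOfLinearPredictor // multinomial, intercepts folded into the weights
      } else {
        numFeatures * numOfLinearPredictor
      }

    // initialWeightSize(4, 1, addIntercept = true)  == 4
    // initialWeightSize(4, 2, addIntercept = true)  == 10
    // initialWeightSize(4, 2, addIntercept = false) == 8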
@@ -168,7 +168,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
   setMaxIter(1)

-  override protected def train(dataset: DataFrame): LogisticRegressionModel = {
+  override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
     val labelSchema = dataset.schema($(labelCol))
     // check for label attribute propagation.
     assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
@@ -25,6 +25,7 @@ import org.scalatest.Matchers

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.regression._
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
@@ -215,6 +216,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
   // Test if we can correctly learn A, B where Y = logistic(A + B*X)
   test("logistic regression with LBFGS") {
+    val updaters: List[Updater] = List(new SquaredL2Updater(), new L1Updater())
+    updaters.foreach(testLBFGS)
+  }
+
+  private def testLBFGS(myUpdater: Updater): Unit = {
     val nPoints = 10000
     val A = 2.0
     val B = -1.5
@@ -223,7 +229,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()

-    val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+    // Override the updater
+    class LogisticRegressionWithLBFGSCustomUpdater
+        extends LogisticRegressionWithLBFGS {
+      override val optimizer =
+        new LBFGS(new LogisticGradient, myUpdater)
+    }
+
+    val lr = new LogisticRegressionWithLBFGSCustomUpdater().setIntercept(true)

     val model = lr.run(testRDD)
@@ -396,10 +410,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
     assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01)

     // Training data with different scales without feature standardization
-    // will not yield the same result in the scaled space due to poor
-    // convergence rate.
-    assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1)
-    assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
+    // should still converge quickly since the model still uses standardization but
+    // simply modifies the regularization function. See regParamL1Fun and related
+    // code inside of LogisticRegression.
+    assert(modelB1.weights(0) ~== modelB2.weights(0) * 1.0E3 absTol 0.1)
+    assert(modelB1.weights(0) ~== modelB3.weights(0) * 1.0E6 absTol 0.1)
   }

   test("multinomial logistic regression with LBFGS") {