Skip to content
Snippets Groups Projects
Commit d2e29516 authored by Davies Liu's avatar Davies Liu Committed by Xiangrui Meng
Browse files

[SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS

```
class LogisticRegressionWithLBFGS
 |  train(cls, data, iterations=100, initialWeights=None, corrections=10, tolerance=0.0001, regParam=0.01, intercept=False)
 |      Train a logistic regression model on the given data.
 |
 |      :param data:           The training data, an RDD of LabeledPoint.
 |      :param iterations:     The number of iterations (default: 100).
 |      :param initialWeights: The initial weights (default: None).
 |      :param regParam:       The regularizer parameter (default: 0.01).
 |      :param regType:        The type of regularizer used for training
 |                             our model.
 |                             :Allowed values:
 |                               - "l1" for using L1 regularization
 |                               - "l2" for using L2 regularization
 |                               - None for no regularization
 |                               (default: "l2")
 |      :param intercept:      Boolean parameter which indicates the use
 |                             or not of the augmented representation for
 |                             training data (i.e. whether bias features
 |                             are activated or not).
 |      :param corrections:    The number of corrections used in the LBFGS update (default: 10).
 |      :param tolerance:      The convergence tolerance of iterations for L-BFGS (default: 1e-4).
 |
 |      >>> data = [
 |      ...     LabeledPoint(0.0, [0.0, 1.0]),
 |      ...     LabeledPoint(1.0, [1.0, 0.0]),
 |      ... ]
 |      >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
 |      >>> lrm.predict([1.0, 0.0])
 |      1
 |      >>> lrm.predict([0.0, 1.0])
 |      0
 |      >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
 |      [1, 0]
```

Author: Davies Liu <davies@databricks.com>

Closes #3307 from davies/lbfgs and squashes the following commits:

34bd986 [Davies Liu] Merge branch 'master' of http://git-wip-us.apache.org/repos/asf/spark into lbfgs
5a945a6 [Davies Liu] address comments
941061b [Davies Liu] Merge branch 'master' of github.com:apache/spark into lbfgs
03e5543 [Davies Liu] add it to docs
ed2f9a8 [Davies Liu] add regType
76cd1b6 [Davies Liu] reorder arguments
4429a74 [Davies Liu] Update classification.py
9252783 [Davies Liu] python api for LogisticRegressionWithLBFGS
parent 010bc86e
No related branches found
No related tags found
No related merge requests found
......@@ -229,6 +229,41 @@ class PythonMLLibAPI extends Serializable {
initialWeights)
}
/**
* Java stub for Python mllib LogisticRegressionWithLBFGS.train()
*/
def trainLogisticRegressionModelWithLBFGS(
data: JavaRDD[LabeledPoint],
numIterations: Int,
initialWeights: Vector,
regParam: Double,
regType: String,
intercept: Boolean,
corrections: Int,
tolerance: Double): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithLBFGS()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
.setNumCorrections(corrections)
.setConvergenceTol(tolerance)
if (regType == "l2") {
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
LogRegAlg.optimizer.setUpdater(new L1Updater)
} else if (regType == null) {
LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
} else {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
LogRegAlg,
data,
initialWeights)
}
/**
* Java stub for NaiveBayes.train()
*/
......
......@@ -26,8 +26,8 @@ from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
class LinearBinaryClassificationModel(LinearModel):
......@@ -151,7 +151,7 @@ class LogisticRegressionWithSGD(object):
(default: "l2")
@param intercept: Boolean parameter which indicates the use
:param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
......@@ -164,6 +164,55 @@ class LogisticRegressionWithSGD(object):
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
class LogisticRegressionWithLBFGS(object):
@classmethod
def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
intercept=False, corrections=10, tolerance=1e-4):
"""
Train a logistic regression model on the given data.
:param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param initialWeights: The initial weights (default: None).
:param regParam: The regularizer parameter (default: 0.01).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- "l1" for using L1 regularization
- "l2" for using L2 regularization
- None for no regularization
(default: "l2")
:param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
:param corrections: The number of corrections used in the LBFGS
update (default: 10).
:param tolerance: The convergence tolerance of iterations for
L-BFGS (default: 1e-4).
>>> data = [
... LabeledPoint(0.0, [0.0, 1.0]),
... LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
>>> lrm.predict([1.0, 0.0])
1
>>> lrm.predict([0.0, 1.0])
0
"""
def train(rdd, i):
return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
float(regParam), str(regType), bool(intercept), int(corrections),
float(tolerance))
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
class SVMModel(LinearBinaryClassificationModel):
"""A support vector machine.
......@@ -241,7 +290,7 @@ class SVMWithSGD(object):
(default: "l2")
@param intercept: Boolean parameter which indicates the use
:param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment