Commit 287bea13 authored by sethah, committed by DB Tsai

[SPARK-7159][ML] Add multiclass logistic regression to Spark ML

## What changes were proposed in this pull request?

This patch adds a new estimator/transformer `MultinomialLogisticRegression` to spark ML.

JIRA: [SPARK-7159](https://issues.apache.org/jira/browse/SPARK-7159)

## How was this patch tested?

Added new test suite `MultinomialLogisticRegressionSuite`.

## Approach

### Do not use a "pivot" class in the algorithm formulation

Many implementations of multinomial logistic regression treat the problem as K - 1 independent binary logistic regression models, where K is the number of possible outcomes of the output variable. One outcome is chosen as a "pivot" and the other K - 1 outcomes are regressed against it. This is somewhat undesirable since the returned coefficients differ depending on the choice of pivot. An alternative approach models the class conditional probabilities using the softmax function and returns uniquely identifiable coefficients (assuming regularization is applied). This second approach is used in R's glmnet and was also recommended by dbtsai.
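
For reference, the softmax formulation (reproduced from the derivation in this patch's scaladoc) models the probability of class k among K classes as:

```latex
% Softmax class-conditional probability for class k (0-indexed, K classes):
P(y_i = k \mid \vec{x}_i, \beta) =
  \frac{e^{\vec{x}_i^T \vec{\beta}_k}}{\sum_{k'=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_{k'}}}
```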

### Separate multinomial logistic regression and binary logistic regression

The initial design makes multinomial logistic regression a separate estimator/transformer from the existing LogisticRegression estimator/transformer. An alternative design would be to merge them into one.

**Arguments for:**

* The multinomial case without a pivot is distinctly different from the current binary case, since the binary case uses a pivot class.
* The current logistic regression model in ML uses a vector of coefficients and a scalar intercept. In the multinomial case, we require a matrix of coefficients and a vector of intercepts. There are potential workarounds for this issue if we were to merge the two estimators, but none are particularly elegant.

**Arguments against:**

* It may be inconvenient for users to have to switch the estimator class when transitioning between binary and multiclass (although the new multinomial estimator can be used for two-class outcomes).
* Some portions of the code are repeated.

This is a major design point and warrants more discussion.
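
To make the switch concrete, here is a rough usage sketch against the API proposed in this patch (the toy data and session setup are illustrative assumptions, not part of the patch):

```scala
import org.apache.spark.ml.classification.MultinomialLogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("mlor-sketch").getOrCreate()
import spark.implicits._

// Hypothetical three-class toy data: (label, features).
val df = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (2.0, Vectors.dense(1.0, 1.0))
).toDF("label", "features")

// Unlike LogisticRegression (Vector coefficients, Double intercept), the new
// estimator produces a Matrix of coefficients and a Vector of intercepts.
val model = new MultinomialLogisticRegression().setRegParam(0.1).fit(df)
println(model.coefficients) // numClasses x numFeatures matrix
println(model.intercepts)   // one intercept per class
```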

### Mean centering

When no regularization is applied, the coefficients will not be uniquely identifiable. This is not hard to show and is discussed in further detail [here](https://core.ac.uk/download/files/153/6287975.pdf). R's glmnet deals with this by choosing the minimum L2-regularized solution (i.e., mean centering). Additionally, the intercepts are never regularized, so they are always mean centered. This is the approach taken in this PR as well.
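
A minimal sketch of the centering step in plain Scala (not the patch's actual helper; the patch applies the same idea to the fitted coefficient matrix and intercept vector after training):

```scala
// With no regularization, adding any constant to every coefficient leaves the
// softmax probabilities unchanged; subtracting the mean picks the minimum-L2
// representative, matching glmnet's convention.
def meanCenter(values: Array[Double]): Array[Double] = {
  val mean = values.sum / values.length
  values.map(_ - mean)
}

// Example: intercepts proportional to log class counts, then centered.
val raw = Array(math.log(30.0), math.log(50.0), math.log(20.0))
val centered = meanCenter(raw) // centered.sum is ~0.0
```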

### Feature scaling

In current ML logistic regression, the features are always standardized when running the optimization algorithm, though the coefficients are always returned to the user in the original feature space. This patch maintains the same approach, but the implementation details differ. In ML logistic regression, the raw feature values are divided by the column standard deviation in every gradient update iteration. In contrast, MLlib transforms the entire input dataset to the scaled space _before_ optimization. In ML, this means that `numFeatures * numClasses` extra scalar divisions are required in every iteration. Performance testing shows that this causes significant slowdowns (4x in some cases) per iteration. This can be avoided by transforming the input to the scaled space once, as MLlib does, before iteration begins. This adds some initial overhead, but can yield significant time savings in some cases.

One issue with this approach is that if the input data is already cached, there may not be enough memory to cache the transformed data, which would make the algorithm _much_ slower. The tradeoffs here merit more discussion.
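
A hedged sketch of the one-time transformation (names and types here are illustrative; the patch operates on `Instance` RDDs):

```scala
// Scale each feature by its column standard deviation once, up front, instead
// of dividing inside every gradient update. Constant (std == 0) columns map to
// zero, mirroring how such columns are handled during training.
def scaleOnce(data: Seq[Array[Double]], featuresStd: Array[Double]): Seq[Array[Double]] =
  data.map { features =>
    features.zip(featuresStd).map { case (v, std) =>
      if (std != 0.0) v / std else 0.0
    }
  }
```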

### Specifying and inferring the number of outcome classes

The estimator checks the dataframe label column for metadata which specifies the number of classes. If the metadata is not present, the length of the `histogram` variable is used, which is one greater than the maximum label value found in the column. The assumption, then, is that the labels are zero-indexed when they are provided to the algorithm.
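
The corresponding logic from the new estimator's `train()` method (comments added):

```scala
// Prefer the label column metadata; otherwise fall back to the histogram
// length, which assumes labels are zero-indexed {0, 1, ..., maxLabel}.
val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
  case Some(n: Int) =>
    require(n >= histogram.length, s"Specified number of classes $n was " +
      s"less than the number of unique labels ${histogram.length}")
    n
  case None => histogram.length // maxLabel + 1 for zero-indexed labels
}
```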

## Performance

Below are some performance tests I have run so far. I am happy to add more cases or trials if we deem them necessary.

Test cluster: 4 bare metal nodes, 128 GB RAM each, 48 cores each

Notes:

* Time in units of seconds
* Metric is classification accuracy

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 0 | true           | 0.746415 |        30 |      100000 |            3 |        100000 | 327.923 | true              |          0 |
| mllib  |                 0 | true           | 0.743785 |        30 |      100000 |            3 |        100000 | 390.217 | true              |          0 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 0 | true           | 0.973238 |        30 |     2000000 |            3 |         10000 | 385.476 | true              |          0 |
| mllib  |                 0 | true           | 0.949828 |        30 |     2000000 |            3 |         10000 | 550.403 | true              |          0 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| mllib  |                 0 | true           | 0.864358 |        30 |     2000000 |            3 |         10000 | 543.359 | true              |        0.1 |
| ml     |                 0 | true           | 0.867418 |        30 |     2000000 |            3 |         10000 | 401.955 | true              |        0.1 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 1 | true           | 0.807449 |        30 |     2000000 |            3 |         10000 | 334.892 | true              |       0.05 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 0 | true           | 0.602006 |        30 |     2000000 |          500 |           100 | 112.319 | true              |          0 |
| mllib  |                 0 | true           | 0.567226 |        30 |     2000000 |          500 |           100 | 263.768 | true              |          0 |

## References

Friedman, et al. ["Regularization Paths for Generalized Linear Models via Coordinate Descent"](https://core.ac.uk/download/files/153/6287975.pdf)
[http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html](http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html)

## Follow up items
* Consider using level 2 BLAS routines in the gradient computations - [SPARK-17134](https://issues.apache.org/jira/browse/SPARK-17134)
* Add model summary for MLOR - [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139)
* Add initial model to MLOR and add test for intercept priors - [SPARK-17140](https://issues.apache.org/jira/browse/SPARK-17140)
* Python API - [SPARK-17138](https://issues.apache.org/jira/browse/SPARK-17138)
* Consider changing the tree aggregation level for MLOR/BLOR or making it user configurable to avoid memory problems with high dimensional data - [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090)
* Refactor helper classes out of `LogisticRegression.scala` - [SPARK-17135](https://issues.apache.org/jira/browse/SPARK-17135)
* Design optimizer interface for added flexibility in ML algos - [SPARK-17136](https://issues.apache.org/jira/browse/SPARK-17136)
* Support compressing the coefficients and intercepts for MLOR models - [SPARK-17137](https://issues.apache.org/jira/browse/SPARK-17137)

Author: sethah <seth.hendrickson16@gmail.com>

Closes #13796 from sethah/SPARK-7159_M.
parent b482c09f
@@ -63,6 +63,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
* equivalent.
*
* Default is 0.5.
*
* @group setParam
*/
def setThreshold(value: Double): this.type = {
@@ -131,6 +132,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
*
* @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
*/
protected def checkThresholdConsistency(): Unit = {
@@ -153,8 +155,8 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* Logistic regression.
* Currently, this class only supports binary classification. It will support multiclass
* in the future.
* Currently, this class only supports binary classification. For multiclass classification,
* use [[MultinomialLogisticRegression]]
*/
@Since("1.2.0")
class LogisticRegression @Since("1.2.0") (
@@ -168,6 +170,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Set the regularization parameter.
* Default is 0.0.
*
* @group setParam
*/
@Since("1.2.0")
@@ -179,6 +182,7 @@ class LogisticRegression @Since("1.2.0") (
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* For 0 < alpha < 1, the penalty is a combination of L1 and L2.
* Default is 0.0 which is an L2 penalty.
*
* @group setParam
*/
@Since("1.4.0")
@@ -188,6 +192,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Set the maximum number of iterations.
* Default is 100.
*
* @group setParam
*/
@Since("1.2.0")
@@ -198,6 +203,7 @@ class LogisticRegression @Since("1.2.0") (
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-6.
*
* @group setParam
*/
@Since("1.4.0")
@@ -207,6 +213,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Whether to fit an intercept term.
* Default is true.
*
* @group setParam
*/
@Since("1.4.0")
@@ -220,6 +227,7 @@ class LogisticRegression @Since("1.2.0") (
* the models should be always converged to the same solution when no regularization
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
*
* @group setParam
*/
@Since("1.5.0")
@@ -233,9 +241,10 @@ class LogisticRegression @Since("1.2.0") (
override def getThreshold: Double = super.getThreshold
/**
* Whether to over-/under-sample training instances according to the given weights in weightCol.
* If not set or empty String, all instances are treated equally (weight 1.0).
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("1.6.0")
@@ -310,12 +319,15 @@ class LogisticRegression @Since("1.2.0") (
throw new SparkException(msg)
}
val isConstantLabel = histogram.count(_ != 0) == 1
if (numClasses > 2) {
val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
s"binary classification. Found $numClasses in the input dataset."
val msg = s"LogisticRegression with ElasticNet in ML package only supports " +
s"binary classification. Found $numClasses in the input dataset. Consider using " +
s"MultinomialLogisticRegression instead."
logError(msg)
throw new SparkException(msg)
} else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
} else if ($(fitIntercept) && numClasses == 2 && isConstantLabel) {
logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " +
s"zeros and the intercept will be positive infinity; as a result, " +
s"training is not needed.")
@@ -326,12 +338,9 @@ class LogisticRegression @Since("1.2.0") (
s"training is not needed.")
(Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
} else {
if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " +
s"so the algorithm may not converge.")
} else if (!$(fitIntercept) && numClasses == 1) {
logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " +
s"so the algorithm may not converge.")
if (!$(fitIntercept) && isConstantLabel) {
logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
s"dangerous ground, so the algorithm may not converge.")
}
val featuresMean = summarizer.mean.toArray
@@ -349,7 +358,7 @@ class LogisticRegression @Since("1.2.0") (
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
$(standardization), bcFeaturesStd, regParamL2)
$(standardization), bcFeaturesStd, regParamL2, multinomial = false)
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -416,7 +425,7 @@ class LogisticRegression @Since("1.2.0") (
/*
Note that in Logistic Regression, the objective history (loss + regularization)
is log-likelihood which is invariance under feature standardization. As a result,
is log-likelihood which is invariant under feature standardization. As a result,
the objective history from optimizer is the same as the one in the original space.
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
@@ -559,6 +568,7 @@ class LogisticRegressionModel private[spark] (
/**
* Evaluates the model on a test dataset.
*
* @param dataset Test dataset to evaluate model on.
*/
@Since("2.0.0")
@@ -681,6 +691,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = sparkSession.read.format("parquet").load(dataPath)
// We will need numClasses, numFeatures in the future for multinomial logreg support.
// TODO: remove numClasses and numFeatures fields?
val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
.select("numClasses", "numFeatures", "intercept", "coefficients")
@@ -710,6 +721,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
/**
* Add a new label into this MultiClassSummarizer, and update the distinct map.
*
* @param label The label for this data point.
* @param weight The weight of this instance.
* @return This MultilabelSummarizer
@@ -933,32 +945,310 @@ class BinaryLogisticRegressionSummary private[classification] (
}
/**
* LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
* in binary classification for instances in sparse or dense vector in an online fashion.
*
* Note that multinomial logistic loss is not supported yet!
* LogisticAggregator computes the gradient and loss for binary or multinomial logistic (softmax)
* loss function, as used in classification for instances in sparse or dense vector in an online
* fashion.
*
* Two LogisticAggregator can be merged together to have a summary of loss and gradient of
* Two LogisticAggregators can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
*
* For improving the convergence rate during the optimization process and also to prevent against
* features with very large variances exerting an overly large influence during model training,
* packages like R's GLMNET perform the scaling to unit variance and remove the mean in order to
* reduce the condition number. The model is then trained in this scaled space, but returns the
* coefficients in the original scale. See page 9 in
* http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
*
* However, we don't want to apply the [[org.apache.spark.ml.feature.StandardScaler]] on the
* training dataset, and then cache the standardized dataset since it will create a lot of overhead.
* As a result, we perform the scaling implicitly when we compute the objective function (though
* we do not subtract the mean).
*
* Note that there is a difference between multinomial (softmax) and binary loss. The binary case
* uses one outcome class as a "pivot" and regresses the other class against the pivot. In the
* multinomial case, the softmax loss function is used to model each class probability
* independently. Using softmax loss produces `K` sets of coefficients, while using a pivot class
* produces `K - 1` sets of coefficients (a single coefficient vector in the binary case). In the
* binary case, we can say that the coefficients are shared between the positive and negative
* classes. When regularization is applied, multinomial (softmax) loss will produce a result
* different from binary loss, since the positive and negative classes do not share coefficients,
* while binary regression shares a single coefficient vector between them.
*
* The following is a mathematical derivation for the multinomial (softmax) loss.
*
* The probability of the multinomial outcome $y$ taking on any of the K possible outcomes is:
*
* <p><blockquote>
* $$
* P(y_i=0|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \vec{\beta}_k}} \\
* P(y_i=1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_1}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \vec{\beta}_k}}\\
* P(y_i=K-1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_{K-1}}\,}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \vec{\beta}_k}}
* $$
* </blockquote></p>
*
* The model coefficients $\beta = (\beta_0, \beta_1, \beta_2, ..., \beta_{K-1})$ become a matrix
* which has dimension of $K \times (N+1)$ if the intercepts are added. If the intercepts are not
* added, the dimension will be $K \times N$.
*
* Note that the coefficients in the model above lack identifiability. That is, the same constant
* vector can be added to each of the coefficient vectors and the probabilities remain the same.
*
* <p><blockquote>
* $$
* \begin{align}
* \frac{e^{\vec{x}_i^T \left(\vec{\beta}_0 + \vec{c}\right)}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \left(\vec{\beta}_k + \vec{c}\right)}}
* = \frac{e^{\vec{x}_i^T \vec{\beta}_0}e^{\vec{x}_i^T \vec{c}}\,}{e^{\vec{x}_i^T \vec{c}}
* \sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
* = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
* \end{align}
* $$
* </blockquote></p>
*
* However, when regularization is added to the loss function, the coefficients are indeed
* identifiable because there is only one set of coefficients which minimizes the regularization
* term. When no regularization is applied, we choose the coefficients with the minimum L2
* penalty for consistency and reproducibility. For further discussion see:
*
* Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent"
*
* The loss of the objective function for a single instance of data (we do not include the
* regularization term here for simplicity) can be written as
*
* <p><blockquote>
* $$
* \begin{align}
* \ell\left(\beta, x_i\right) &= -log{P\left(y_i \middle| \vec{x}_i, \beta\right)} \\
* &= log\left(\sum_{k=0}^{K-1}e^{\vec{x}_i^T \vec{\beta}_k}\right) - \vec{x}_i^T \vec{\beta}_y\\
* &= log\left(\sum_{k=0}^{K-1} e^{margins_k}\right) - margins_y
* \end{align}
* $$
* </blockquote></p>
*
* where ${margins}_k = \vec{x}_i^T \vec{\beta}_k$.
*
* For optimization, we have to calculate the first derivative of the loss function, and a simple
* calculation shows that
*
* <p><blockquote>
* $$
* \begin{align}
* \frac{\partial \ell(\beta, \vec{x}_i, w_i)}{\partial \beta_{j, k}}
* &= x_{i,j} \cdot w_i \cdot \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
* e^{\vec{x}_i \cdot \vec{\beta}_{k'}}\,} - I_{y=k}\right) \\
* &= x_{i, j} \cdot w_i \cdot multiplier_k
* \end{align}
* $$
* </blockquote></p>
*
* where $w_i$ is the sample weight, $I_{y=k}$ is an indicator function
*
* <p><blockquote>
* $$
* I_{y=k} = \begin{cases}
* 1 & y = k \\
* 0 & else
* \end{cases}
* $$
* </blockquote></p>
*
* and
*
* <p><blockquote>
* $$
* multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
* e^{\vec{x}_i \cdot \vec{\beta}_{k'}}} - I_{y=k}\right)
* $$
* </blockquote></p>
*
* If any of the margins is larger than 709.78, the numerical computation of the multiplier and the
* loss function will suffer from arithmetic overflow. This issue occurs when there are outliers in
* the data which are far away from the hyperplane, and training will fail once infinity is
* introduced. Note that this is only a concern when max(margins) > 0.
*
* Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can easily
* be rewritten into the following equivalent numerically stable formula.
*
* <p><blockquote>
* $$
* \ell\left(\beta, x\right) = log\left(\sum_{k=0}^{K-1} e^{margins_k - maxMargin}\right) -
* margins_{y} + maxMargin
* $$
* </blockquote></p>
*
* Note that each term, $(margins_k - maxMargin)$ in the exponential is no greater than zero; as a
* result, overflow will not happen with this formula.
*
* For $multiplier$, a similar trick can be applied as the following,
*
* <p><blockquote>
* $$
* multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k - maxMargin}}{\sum_{k'=0}^{K-1}
* e^{\vec{x}_i \cdot \vec{\beta}_{k'} - maxMargin}} - I_{y=k}\right)
* $$
* </blockquote></p>
*
* @param bcCoefficients The broadcast coefficients corresponding to the features.
* @param bcFeaturesStd The broadcast standard deviation values of the features.
* @param numClasses the number of possible outcomes for the K-class classification problem in
* Multinomial Logistic Regression.
* @param fitIntercept Whether to fit an intercept term.
* @param multinomial Whether to use multinomial (softmax) or binary loss
*/
private class LogisticAggregator(
val bcCoefficients: Broadcast[Vector],
val bcFeaturesStd: Broadcast[Array[Double]],
private val numFeatures: Int,
bcCoefficients: Broadcast[Vector],
bcFeaturesStd: Broadcast[Array[Double]],
numClasses: Int,
fitIntercept: Boolean) extends Serializable {
fitIntercept: Boolean,
multinomial: Boolean) extends Serializable with Logging {
private val numFeatures = bcFeaturesStd.value.length
private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
private val coefficientSize = bcCoefficients.value.size
if (multinomial) {
require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " +
s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
} else {
require(coefficientSize == numFeaturesPlusIntercept, s"Expected $numFeaturesPlusIntercept " +
s"coefficients but got $coefficientSize")
require(numClasses == 1 || numClasses == 2, s"Binary logistic aggregator requires numClasses " +
s"in {1, 2} but found $numClasses.")
}
private var weightSum = 0.0
private var lossSum = 0.0
private val gradientSumArray =
Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures)
private val gradientSumArray = Array.ofDim[Double](coefficientSize)
if (multinomial && numClasses <= 2) {
logInfo(s"Multinomial logistic regression for binary classification yields separate " +
s"coefficients for positive and negative classes. When no regularization is applied, the" +
s"result will be effectively the same as binary logistic regression. When regularization" +
s"is applied, multinomial loss will produce a result different from binary loss.")
}
/** Update gradient and loss using binary loss function. */
private def binaryUpdateInPlace(
features: Vector,
weight: Double,
label: Double): Unit = {
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = bcCoefficients.value
val localGradientArray = gradientSumArray
val margin = - {
var sum = 0.0
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
sum += localCoefficients(index) * value / localFeaturesStd(index)
}
}
if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
sum
}
val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientArray(index) += multiplier * value / localFeaturesStd(index)
}
}
if (fitIntercept) {
localGradientArray(numFeaturesPlusIntercept - 1) += multiplier
}
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
lossSum += weight * MLUtils.log1pExp(margin)
} else {
lossSum += weight * (MLUtils.log1pExp(margin) - margin)
}
}
/** Update gradient and loss using multinomial (softmax) loss function. */
private def multinomialUpdateInPlace(
features: Vector,
weight: Double,
label: Double): Unit = {
// TODO: use level 2 BLAS operations
/*
Note: this can still be used when numClasses = 2 for binary
logistic regression without pivoting.
*/
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = bcCoefficients.value
val localGradientArray = gradientSumArray
// marginOfLabel is margins(label) in the formula
var marginOfLabel = 0.0
var maxMargin = Double.NegativeInfinity
val margins = Array.tabulate(numClasses) { i =>
var margin = 0.0
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
margin += localCoefficients(i * numFeaturesPlusIntercept + index) *
value / localFeaturesStd(index)
}
}
if (fitIntercept) {
margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures)
}
if (i == label.toInt) marginOfLabel = margin
if (margin > maxMargin) {
maxMargin = margin
}
margin
}
/**
* When maxMargin > 0, the original formula could cause overflow.
* We address this by subtracting maxMargin from all the margins, so it's guaranteed
* that all of the new margins will be smaller than zero to prevent arithmetic overflow.
*/
val sum = {
var temp = 0.0
if (maxMargin > 0) {
for (i <- 0 until numClasses) {
margins(i) -= maxMargin
temp += math.exp(margins(i))
}
} else {
for (i <- 0 until numClasses) {
temp += math.exp(margins(i))
}
}
temp
}
for (i <- 0 until numClasses) {
val multiplier = math.exp(margins(i)) / sum - {
if (label == i) 1.0 else 0.0
}
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientArray(i * numFeaturesPlusIntercept + index) +=
weight * multiplier * value / localFeaturesStd(index)
}
}
if (fitIntercept) {
localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier
}
}
val loss = if (maxMargin > 0) {
math.log(sum) - marginOfLabel + maxMargin
} else {
math.log(sum) - marginOfLabel
}
lossSum += weight * loss
}
/**
* Add a new training instance to this LogisticAggregator, and update the loss and gradient
@@ -975,52 +1265,10 @@ private class LogisticAggregator(
if (weight == 0.0) return this
val coefficientsArray = bcCoefficients.value match {
case dv: DenseVector => dv.values
case _ =>
throw new IllegalArgumentException(
"coefficients only supports dense vector" +
s"but got type ${bcCoefficients.value.getClass}.")
}
val localGradientSumArray = gradientSumArray
val featuresStd = bcFeaturesStd.value
numClasses match {
case 2 =>
// For Binary Logistic Regression.
val margin = - {
var sum = 0.0
features.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) {
sum += coefficientsArray(index) * (value / featuresStd(index))
}
}
sum + {
if (fitIntercept) coefficientsArray(numFeatures) else 0.0
}
}
val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
features.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) {
localGradientSumArray(index) += multiplier * (value / featuresStd(index))
}
}
if (fitIntercept) {
localGradientSumArray(numFeatures) += multiplier
}
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
lossSum += weight * MLUtils.log1pExp(margin)
} else {
lossSum += weight * (MLUtils.log1pExp(margin) - margin)
}
case _ =>
new NotImplementedError("LogisticRegression with ElasticNet in ML package " +
"only supports binary classification for now.")
if (multinomial) {
multinomialUpdateInPlace(features, weight, label)
} else {
binaryUpdateInPlace(features, weight, label)
}
weightSum += weight
this
@@ -1071,8 +1319,8 @@ private class LogisticAggregator(
}
/**
* LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function,
* as used in multi-class classification (it is also used in binary logistic regression).
* LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial (softmax) logistic loss
* function, as used in multi-class classification (it is also used in binary logistic regression).
* It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
@@ -1082,36 +1330,36 @@ private class LogisticCostFun(
fitIntercept: Boolean,
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
regParamL2: Double) extends DiffFunction[BDV[Double]] {
regParamL2: Double,
multinomial: Boolean) extends DiffFunction[BDV[Double]] {
val featuresStd = bcFeaturesStd.value
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val numFeatures = featuresStd.length
val coeffs = Vectors.fromBreeze(coefficients)
val bcCoeffs = instances.context.broadcast(coeffs)
val n = coeffs.size
val featuresStd = bcFeaturesStd.value
val numFeatures = featuresStd.length
val logisticAggregator = {
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
instances.treeAggregate(
new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
new LogisticAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept,
multinomial)
)(seqOp, combOp)
}
val totalGradientArray = logisticAggregator.gradient.toArray
// regVal is the sum of coefficients squares excluding intercept for L2 regularization.
val regVal = if (regParamL2 == 0.0) {
0.0
} else {
var sum = 0.0
coeffs.foreachActive { (index, value) =>
// If `fitIntercept` is true, the last term which is intercept doesn't
// contribute to the regularization.
if (index != numFeatures) {
coeffs.foreachActive { case (index, value) =>
// We do not apply regularization to the intercepts
val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0)
if (!isIntercept) {
// The following code will compute the loss of the regularization; also
// the gradient of the regularization, and add back to totalGradientArray.
sum += {
@@ -1119,13 +1367,18 @@ private class LogisticCostFun(
totalGradientArray(index) += regParamL2 * value
value * value
} else {
if (featuresStd(index) != 0.0) {
val featureIndex = if (fitIntercept) {
index % (numFeatures + 1)
} else {
index % numFeatures
}
if (featuresStd(featureIndex) != 0.0) {
// If `standardization` is false, we still standardize the data
// to improve the rate of convergence; as a result, we have to
// perform this reverse standardization by penalizing each component
// differently to get effectively the same objective function when
// the training dataset is not standardized.
val temp = value / (featuresStd(index) * featuresStd(index))
val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))
totalGradientArray(index) += regParamL2 * temp
value * temp
} else {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.classification
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
/**
* Params for multinomial logistic (softmax) regression.
*/
private[classification] trait MultinomialLogisticRegressionParams
extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter
with HasFitIntercept with HasTol with HasStandardization with HasWeightCol {
/**
* Set thresholds in multiclass (or binary) classification to adjust the probability of
* predicting each class. Array must have length equal to the number of classes, with values >= 0.
* The class with largest value p/t is predicted, where p is the original probability of that
* class and t is the class' threshold.
*
* @group setParam
*/
def setThresholds(value: Array[Double]): this.type = {
set(thresholds, value)
}
/**
* Get thresholds for binary or multiclass classification.
*
* @group getParam
*/
override def getThresholds: Array[Double] = {
$(thresholds)
}
}
/**
* :: Experimental ::
* Multinomial Logistic (softmax) regression.
*/
@Since("2.1.0")
@Experimental
class MultinomialLogisticRegression @Since("2.1.0") (
@Since("2.1.0") override val uid: String)
extends ProbabilisticClassifier[Vector,
MultinomialLogisticRegression, MultinomialLogisticRegressionModel]
with MultinomialLogisticRegressionParams with DefaultParamsWritable with Logging {
@Since("2.1.0")
def this() = this(Identifiable.randomUID("mlogreg"))
/**
* Set the regularization parameter.
* Default is 0.0.
*
* @group setParam
*/
@Since("2.1.0")
def setRegParam(value: Double): this.type = set(regParam, value)
setDefault(regParam -> 0.0)
/**
* Set the ElasticNet mixing parameter.
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* For 0 < alpha < 1, the penalty is a combination of L1 and L2.
* Default is 0.0 which is an L2 penalty.
*
* @group setParam
*/
@Since("2.1.0")
def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
setDefault(elasticNetParam -> 0.0)
/**
* Set the maximum number of iterations.
* Default is 100.
*
* @group setParam
*/
@Since("2.1.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)
setDefault(maxIter -> 100)
/**
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-6.
*
* @group setParam
*/
@Since("2.1.0")
def setTol(value: Double): this.type = set(tol, value)
setDefault(tol -> 1E-6)
/**
* Whether to fit an intercept term.
* Default is true.
*
* @group setParam
*/
@Since("2.1.0")
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
setDefault(fitIntercept -> true)
/**
* Whether to standardize the training features before fitting the model.
* The coefficients of models will be always returned on the original scale,
* so it will be transparent for users. Note that with/without standardization,
* the models should always converge to the same solution when no regularization
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
*
* @group setParam
*/
@Since("2.1.0")
def setStandardization(value: Boolean): this.type = set(standardization, value)
setDefault(standardization -> true)
/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("2.1.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
@Since("2.1.0")
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
override protected[spark] def train(dataset: Dataset[_]): MultinomialLogisticRegressionModel = {
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val instr = Instrumentation.create(this, instances)
instr.logParams(regParam, elasticNetParam, standardization, thresholds,
maxIter, tol, fitIntercept)
val (summarizer, labelSummarizer) = {
val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
instance: Instance) =>
(c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
(c1._1.merge(c2._1), c1._2.merge(c2._2))
instances.treeAggregate(
new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp)
}
val histogram = labelSummarizer.histogram
val numInvalid = labelSummarizer.countInvalid
val numFeatures = summarizer.mean.size
val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
case Some(n: Int) =>
require(n >= histogram.length, s"Specified number of classes $n was " +
s"less than the number of unique labels ${histogram.length}")
n
case None => histogram.length
}
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
val (coefficients, intercepts, objectiveHistory) = {
if (numInvalid != 0) {
val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
s"Found $numInvalid invalid labels."
logError(msg)
throw new SparkException(msg)
}
val isConstantLabel = histogram.count(_ != 0) == 1
if ($(fitIntercept) && isConstantLabel) {
// we want to produce a model that will always predict the constant label so all the
// coefficients will be zero, and the constant label class intercept will be +inf
val constantLabelIndex = Vectors.dense(histogram).argmax
(Matrices.sparse(numClasses, numFeatures, Array.fill(numFeatures + 1)(0),
Array.empty[Int], Array.empty[Double]),
Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity))),
Array.empty[Double])
} else {
if (!$(fitIntercept) && isConstantLabel) {
logWarning(s"All labels belong to a single class and fitIntercept=false. It's" +
s"a dangerous ground, so the algorithm may not converge.")
}
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
val featuresMean = summarizer.mean.toArray
if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
logWarning("Fitting MultinomialLogisticRegressionModel without intercept on dataset " +
"with constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
"nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
}
val regParamL1 = $(elasticNetParam) * $(regParam)
val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
$(standardization), bcFeaturesStd, regParamL2, multinomial = true)
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
} else {
val standardizationParam = $(standardization)
def regParamL1Fun = (index: Int) => {
// Remove the L1 penalization on the intercept
val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0)
if (isIntercept) {
0.0
} else {
if (standardizationParam) {
regParamL1
} else {
val featureIndex = if ($(fitIntercept)) {
index % numFeaturesPlusIntercept
} else {
index % numFeatures
}
// If `standardization` is false, we still standardize the data
// to improve the rate of convergence; as a result, we have to
// perform this reverse standardization by penalizing each component
// differently to get effectively the same objective function when
// the training dataset is not standardized.
if (featuresStd(featureIndex) != 0.0) {
regParamL1 / featuresStd(featureIndex)
} else {
0.0
}
}
}
}
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
}
val initialCoefficientsWithIntercept = Vectors.zeros(numClasses * numFeaturesPlusIntercept)
if ($(fitIntercept)) {
/*
For multinomial logistic regression, when we initialize the coefficients as zeros,
it will converge faster if we initialize the intercepts such that
it follows the distribution of the labels.
{{{
P(1) = \exp(b_1) / Z
...
P(K) = \exp(b_K) / Z
where Z = \sum_{k=1}^{K} \exp(b_k)
}}}
Since this doesn't have a unique solution, one of the solutions that satisfies the
above equations is
{{{
\exp(b_k) = count_k * \exp(\lambda)
b_k = \log(count_k) + \lambda
}}}
\lambda is a free parameter, so choose the phase \lambda such that the
mean is centered. This yields
{{{
b_k = \log(count_k)
b_k' = b_k - \mean(b_k)
}}}
*/
val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing
val rawMean = rawIntercepts.sum / rawIntercepts.length
rawIntercepts.indices.foreach { i =>
initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) =
rawIntercepts(i) - rawMean
}
}
val states = optimizer.iterations(new CachedDiffFunction(costFun),
initialCoefficientsWithIntercept.asBreeze.toDenseVector)
/*
Note that in Multinomial Logistic Regression, the objective history
(loss + regularization) is log-likelihood which is invariant under feature
standardization. As a result, the objective history from optimizer is the same as the
one in the original space.
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
var state: optimizer.State = null
while (states.hasNext) {
state = states.next()
arrayBuilder += state.adjustedValue
}
if (state == null) {
val msg = s"${optimizer.getClass.getName} failed."
logError(msg)
throw new SparkException(msg)
}
bcFeaturesStd.destroy(blocking = false)
/*
The coefficients are trained in the scaled space; we're converting them back to
the original space.
Note that the intercept in scaled space and original space is the same;
as a result, no scaling is needed.
*/
val rawCoefficients = state.x.toArray
val interceptsArray: Array[Double] = if ($(fitIntercept)) {
Array.tabulate(numClasses) { i =>
val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1
rawCoefficients(coefIndex)
}
} else {
Array[Double]()
}
val coefficientArray: Array[Double] = Array.tabulate(numClasses * numFeatures) { i =>
// flatIndex will loop through rawCoefficients, and skip the intercept terms.
val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i
val featureIndex = i % numFeatures
if (featuresStd(featureIndex) != 0.0) {
rawCoefficients(flatIndex) / featuresStd(featureIndex)
} else {
0.0
}
}
val coefficientMatrix =
new DenseMatrix(numClasses, numFeatures, coefficientArray, isTransposed = true)
/*
When no regularization is applied, the coefficients lack identifiability because
we do not use a pivot class. We can add any constant value to the coefficients and
get the same likelihood. So here, we choose the mean centered coefficients for
reproducibility. This method follows the approach in glmnet, described here:
Friedman, et al. "Regularization Paths for Generalized Linear Models via
Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf
*/
if ($(regParam) == 0.0) {
val coefficientMean = coefficientMatrix.values.sum / (numClasses * numFeatures)
coefficientMatrix.update(_ - coefficientMean)
}
/*
The intercepts are never regularized, so we always center the mean.
*/
val interceptVector = if (interceptsArray.nonEmpty) {
val interceptMean = interceptsArray.sum / numClasses
interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean }
Vectors.dense(interceptsArray)
} else {
Vectors.sparse(numClasses, Seq())
}
(coefficientMatrix, interceptVector, arrayBuilder.result())
}
}
if (handlePersistence) instances.unpersist()
val model = copyValues(
new MultinomialLogisticRegressionModel(uid, coefficients, intercepts, numClasses))
instr.logSuccess(model)
model
}
@Since("2.1.0")
override def copy(extra: ParamMap): MultinomialLogisticRegression = defaultCopy(extra)
}
@Since("2.1.0")
object MultinomialLogisticRegression extends DefaultParamsReadable[MultinomialLogisticRegression] {
@Since("2.1.0")
override def load(path: String): MultinomialLogisticRegression = super.load(path)
}
/**
* :: Experimental ::
* Model produced by [[MultinomialLogisticRegression]].
*/
@Since("2.1.0")
@Experimental
class MultinomialLogisticRegressionModel private[spark] (
@Since("2.1.0") override val uid: String,
@Since("2.1.0") val coefficients: Matrix,
@Since("2.1.0") val intercepts: Vector,
@Since("2.1.0") val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, MultinomialLogisticRegressionModel]
with MultinomialLogisticRegressionParams with MLWritable {
@Since("2.1.0")
override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
@Since("2.1.0")
override def getThresholds: Array[Double] = super.getThresholds
@Since("2.1.0")
override val numFeatures: Int = coefficients.numCols
/** Margin (rawPrediction) for each class label. */
private val margins: Vector => Vector = (features) => {
val m = intercepts.toDense.copy
BLAS.gemv(1.0, coefficients, features, 1.0, m)
m
}
/** Score (probability) for each class label. */
private val scores: Vector => Vector = (features) => {
val m = margins(features)
val maxMarginIndex = m.argmax
val marginArray = m.toArray
val maxMargin = marginArray(maxMarginIndex)
// adjust margins for overflow
val sum = {
var temp = 0.0
var k = 0
while (k < numClasses) {
marginArray(k) = if (maxMargin > 0) {
math.exp(marginArray(k) - maxMargin)
} else {
math.exp(marginArray(k))
}
temp += marginArray(k)
k += 1
}
temp
}
val scores = Vectors.dense(marginArray)
BLAS.scal(1 / sum, scores)
scores
}
/**
* Predict label for the given feature vector.
* The behavior of this can be adjusted using [[thresholds]].
*/
override protected def predict(features: Vector): Double = {
if (isDefined(thresholds)) {
val thresholds: Array[Double] = getThresholds
val probabilities = scores(features).toArray
var argMax = 0
var max = Double.NegativeInfinity
var i = 0
while (i < numClasses) {
if (thresholds(i) == 0.0) {
max = Double.PositiveInfinity
argMax = i
} else {
val scaled = probabilities(i) / thresholds(i)
if (scaled > max) {
max = scaled
argMax = i
}
}
i += 1
}
argMax
} else {
scores(features).argmax
}
}
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
val size = dv.size
val values = dv.values
// get the maximum margin
val maxMarginIndex = rawPrediction.argmax
val maxMargin = rawPrediction(maxMarginIndex)
if (maxMargin == Double.PositiveInfinity) {
var k = 0
while (k < size) {
values(k) = if (k == maxMarginIndex) 1.0 else 0.0
k += 1
}
} else {
val sum = {
var temp = 0.0
var k = 0
while (k < numClasses) {
values(k) = if (maxMargin > 0) {
math.exp(values(k) - maxMargin)
} else {
math.exp(values(k))
}
temp += values(k)
k += 1
}
temp
}
BLAS.scal(1 / sum, dv)
}
dv
case sv: SparseVector =>
throw new RuntimeException("Unexpected error in MultinomialLogisticRegressionModel:" +
" raw2probabilitiesInPlace encountered SparseVector")
}
}
override protected def predictRaw(features: Vector): Vector = margins(features)
@Since("2.1.0")
override def copy(extra: ParamMap): MultinomialLogisticRegressionModel = {
val newModel =
copyValues(
new MultinomialLogisticRegressionModel(uid, coefficients, intercepts, numClasses), extra)
newModel.setParent(parent)
}
/**
* Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
*
* This does not save the [[parent]] currently.
*/
@Since("2.1.0")
override def write: MLWriter =
new MultinomialLogisticRegressionModel.MultinomialLogisticRegressionModelWriter(this)
}
@Since("2.1.0")
object MultinomialLogisticRegressionModel extends MLReadable[MultinomialLogisticRegressionModel] {
@Since("2.1.0")
override def read: MLReader[MultinomialLogisticRegressionModel] =
new MultinomialLogisticRegressionModelReader
@Since("2.1.0")
override def load(path: String): MultinomialLogisticRegressionModel = super.load(path)
/** [[MLWriter]] instance for [[MultinomialLogisticRegressionModel]] */
private[MultinomialLogisticRegressionModel]
class MultinomialLogisticRegressionModelWriter(instance: MultinomialLogisticRegressionModel)
extends MLWriter with Logging {
private case class Data(
numClasses: Int,
numFeatures: Int,
intercepts: Vector,
coefficients: Matrix)
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: numClasses, numFeatures, intercept, coefficients
val data = Data(instance.numClasses, instance.numFeatures, instance.intercepts,
instance.coefficients)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class MultinomialLogisticRegressionModelReader
extends MLReader[MultinomialLogisticRegressionModel] {
/** Checked against metadata when loading model */
private val className = classOf[MultinomialLogisticRegressionModel].getName
override def load(path: String): MultinomialLogisticRegressionModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.format("parquet").load(dataPath)
.select("numClasses", "numFeatures", "intercepts", "coefficients").head()
val numClasses = data.getAs[Int](data.fieldIndex("numClasses"))
val intercepts = data.getAs[Vector](data.fieldIndex("intercepts"))
val coefficients = data.getAs[Matrix](data.fieldIndex("coefficients"))
val model =
new MultinomialLogisticRegressionModel(metadata.uid, coefficients, intercepts, numClasses)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.classification
import scala.language.existentials
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row}
class MultinomialLogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var dataset: Dataset[_] = _
@transient var multinomialDataset: DataFrame = _
private val eps: Double = 1e-5
override def beforeAll(): Unit = {
super.beforeAll()
dataset = {
val nPoints = 100
val coefficients = Array(
-0.57997, 0.912083, -0.371077,
-0.16624, -0.84355, -0.048509)
val xMean = Array(5.843, 3.057)
val xVariance = Array(0.6856, 0.1899)
val testData = generateMultinomialLogisticInput(
coefficients, xMean, xVariance, addIntercept = true, nPoints, 42)
val df = spark.createDataFrame(sc.parallelize(testData, 4))
df.cache()
df
}
multinomialDataset = {
val nPoints = 10000
val coefficients = Array(
-0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
-0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
val xMean = Array(5.843, 3.057, 3.758, 1.199)
val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
val testData = generateMultinomialLogisticInput(
coefficients, xMean, xVariance, addIntercept = true, nPoints, 42)
val df = spark.createDataFrame(sc.parallelize(testData, 4))
df.cache()
df
}
}
/**
* Enable the ignored test to export the dataset into CSV format,
* so we can validate the training accuracy compared with R's glmnet package.
*/
ignore("export test data into CSV format") {
val rdd = multinomialDataset.rdd.map { case Row(label: Double, features: Vector) =>
label + "," + features.toArray.mkString(",")
}.repartition(1)
rdd.saveAsTextFile("target/tmp/MultinomialLogisticRegressionSuite/multinomialDataset")
}
test("params") {
ParamsSuite.checkParams(new MultinomialLogisticRegression)
val model = new MultinomialLogisticRegressionModel("mLogReg",
Matrices.dense(2, 1, Array(0.0, 0.0)), Vectors.dense(0.0, 0.0), 2)
ParamsSuite.checkParams(model)
}
test("multinomial logistic regression: default params") {
val mlr = new MultinomialLogisticRegression
assert(mlr.getLabelCol === "label")
assert(mlr.getFeaturesCol === "features")
assert(mlr.getPredictionCol === "prediction")
assert(mlr.getRawPredictionCol === "rawPrediction")
assert(mlr.getProbabilityCol === "probability")
assert(!mlr.isDefined(mlr.weightCol))
assert(!mlr.isDefined(mlr.thresholds))
assert(mlr.getFitIntercept)
assert(mlr.getStandardization)
val model = mlr.fit(dataset)
model.transform(dataset)
.select("label", "probability", "prediction", "rawPrediction")
.collect()
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
assert(model.getRawPredictionCol === "rawPrediction")
assert(model.getProbabilityCol === "probability")
assert(model.intercepts !== Vectors.dense(0.0, 0.0))
assert(model.hasParent)
}
test("multinomial logistic regression with intercept without regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setMaxIter(100)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Using the following R code to load the data and train the model using glmnet package.
> library("glmnet")
> data <- read.csv("path", header=FALSE)
> label = as.factor(data$V1)
> features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
> coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0))
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-2.24493379
V2 0.25096771
V3 -0.03915938
V4 0.14766639
V5 0.36810817
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.3778931
V2 -0.3327489
V3 0.8893666
V4 -0.2306948
V5 -0.4442330
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
1.86704066
V2 0.08178121
V3 -0.85020722
V4 0.08302840
V5 0.07612480
*/
val coefficientsR = new DenseMatrix(3, 4, Array(
0.2509677, -0.0391594, 0.1476664, 0.3681082,
-0.3327489, 0.8893666, -0.2306948, -0.4442330,
0.0817812, -0.8502072, 0.0830284, 0.0761248), isTransposed = true)
val interceptsR = Vectors.dense(-2.2449338, 0.3778931, 1.8670407)
assert(model1.coefficients ~== coefficientsR relTol 0.05)
assert(model1.coefficients.toArray.sum ~== 0.0 absTol eps)
assert(model1.intercepts ~== interceptsR relTol 0.05)
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR relTol 0.05)
assert(model2.coefficients.toArray.sum ~== 0.0 absTol eps)
assert(model2.intercepts ~== interceptsR relTol 0.05)
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
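/*
  A note on the sum-to-zero assertions above. Under the softmax formulation,
    P(y = k | x) = exp(b_k . x + a_k) / sum_i exp(b_i . x + a_i)
  is unchanged when a common constant is added to every class's coefficient for a given
  feature (or to every intercept), so the coefficients are only identifiable up to such
  shifts. The minimum-L2 (mean-centered) solution that this implementation and glmnet
  both pick makes each column of the coefficient matrix, and the intercept vector, sum
  to zero, which is what the `toArray.sum ~== 0.0 absTol eps` checks verify. A sketch of
  the per-column form of that property (column 0 shown; the test asserts the total sum):

    val column0 = Array.tabulate(3)(k => model1.coefficients(k, 0))
    assert(math.abs(column0.sum) < 1e-4)
*/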
test("multinomial logistic regression without intercept without regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(false)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0, lambda = 0,
intercept=F))
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.06992464
V3 -0.36562784
V4 0.12142680
V5 0.32052211
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 -0.3036269
V3 0.9449630
V4 -0.2271038
V5 -0.4364839
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.2337022
V3 -0.5793351
V4 0.1056770
V5 0.1159618
*/
val coefficientsR = new DenseMatrix(3, 4, Array(
0.0699246, -0.3656278, 0.1214268, 0.3205221,
-0.3036269, 0.9449630, -0.2271038, -0.4364839,
0.2337022, -0.5793351, 0.1056770, 0.1159618), isTransposed = true)
assert(model1.coefficients ~== coefficientsR relTol 0.05)
assert(model1.coefficients.toArray.sum ~== 0.0 absTol eps)
assert(model1.intercepts.toArray === Array.fill(3)(0.0))
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR relTol 0.05)
assert(model2.coefficients.toArray.sum ~== 0.0 absTol eps)
assert(model2.intercepts.toArray === Array.fill(3)(0.0))
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
test("multinomial logistic regression with intercept with L1 regularization") {
// use tighter convergence constraints because the OWL-QN solver takes longer to converge
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true)
.setMaxIter(300).setTol(1e-10)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false)
.setMaxIter(300).setTol(1e-10)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1,
lambda = 0.05, standardize=T))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05,
standardize=F))
> coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.68988825
V2 .
V3 .
V4 .
V5 0.09404023
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.2303499
V2 -0.1232443
V3 0.3258380
V4 -0.1564688
V5 -0.2053965
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.9202381
V2 .
V3 -0.4803856
V4 .
V5 .
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.44893320
V2 .
V3 .
V4 0.01933812
V5 0.03666044
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.7376760
V2 -0.0577182
V3 .
V4 -0.2081718
V5 -0.1304592
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.2887428
V2 .
V3 .
V4 .
V5 .
*/
val coefficientsRStd = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.0, 0.09404023,
-0.1232443, 0.3258380, -0.1564688, -0.2053965,
0.0, -0.4803856, 0.0, 0.0), isTransposed = true)
val interceptsRStd = Vectors.dense(-0.68988825, -0.2303499, 0.9202381)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.01933812, 0.03666044,
-0.0577182, 0.0, -0.2081718, -0.1304592,
0.0, 0.0, 0.0, 0.0), isTransposed = true)
val interceptsR = Vectors.dense(-0.44893320, 0.7376760, -0.2887428)
assert(model1.coefficients ~== coefficientsRStd absTol 0.02)
assert(model1.intercepts ~== interceptsRStd relTol 0.1)
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR absTol 0.02)
assert(model2.intercepts ~== interceptsR relTol 0.1)
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
test("multinomial logistic regression without intercept with L1 regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(1.0).setRegParam(0.05).setStandardization(true)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(1.0).setRegParam(0.05).setStandardization(false)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 1,
lambda = 0.05, intercept=F, standardize=T))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 1, lambda = 0.05,
intercept=F, standardize=F))
> coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 .
V4 .
V5 0.01525105
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 -0.1502410
V3 0.5134658
V4 -0.1601146
V5 -0.2500232
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.003301875
V3 .
V4 .
V5 .
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 .
V4 .
V5 .
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 0.1943624
V4 -0.1902577
V5 -0.1028789
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 .
V4 .
V5 .
*/
val coefficientsRStd = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.0, 0.01525105,
-0.1502410, 0.5134658, -0.1601146, -0.2500232,
0.003301875, 0.0, 0.0, 0.0), isTransposed = true)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.0, 0.0,
0.0, 0.1943624, -0.1902577, -0.1028789,
0.0, 0.0, 0.0, 0.0), isTransposed = true)
assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
assert(model1.intercepts.toArray === Array.fill(3)(0.0))
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR absTol 0.01)
assert(model2.intercepts.toArray === Array.fill(3)(0.0))
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
test("multinomial logistic regression with intercept with L2 regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0,
lambda = 0.1, intercept=T, standardize=T))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0,
lambda = 0.1, intercept=T, standardize=F))
> coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-1.70040424
V2 0.17576070
V3 0.01527894
V4 0.10216108
V5 0.26099531
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.2438590
V2 -0.2238875
V3 0.5967610
V4 -0.1555496
V5 -0.3010479
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
1.45654525
V2 0.04812679
V3 -0.61203992
V4 0.05338850
V5 0.04005258
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-1.65488543
V2 0.15715048
V3 0.01992903
V4 0.12428858
V5 0.22130317
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
1.1297533
V2 -0.1974768
V3 0.2776373
V4 -0.1869445
V5 -0.2510320
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.52513212
V2 0.04032627
V3 -0.29756637
V4 0.06265594
V5 0.02972883
*/
val coefficientsRStd = new DenseMatrix(3, 4, Array(
0.17576070, 0.01527894, 0.10216108, 0.26099531,
-0.2238875, 0.5967610, -0.1555496, -0.3010479,
0.04812679, -0.61203992, 0.05338850, 0.04005258), isTransposed = true)
val interceptsRStd = Vectors.dense(-1.70040424, 0.2438590, 1.45654525)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.15715048, 0.01992903, 0.12428858, 0.22130317,
-0.1974768, 0.2776373, -0.1869445, -0.2510320,
0.04032627, -0.29756637, 0.06265594, 0.02972883), isTransposed = true)
val interceptsR = Vectors.dense(-1.65488543, 1.1297533, 0.52513212)
assert(model1.coefficients ~== coefficientsRStd relTol 0.05)
assert(model1.intercepts ~== interceptsRStd relTol 0.05)
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR relTol 0.05)
assert(model2.intercepts ~== interceptsR relTol 0.05)
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
test("multinomial logistic regression without intercept with L2 regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.0).setRegParam(0.1).setStandardization(false)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0,
lambda = 0.1, intercept=F, standardize=T))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0,
lambda = 0.1, intercept=F, standardize=F))
> coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.03904171
V3 -0.23354322
V4 0.08288096
V5 0.22706393
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 -0.2061848
V3 0.6341398
V4 -0.1530059
V5 -0.2958455
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.16714312
V3 -0.40059658
V4 0.07012496
V5 0.06878158
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 -0.005704542
V3 -0.144466409
V4 0.092080736
V5 0.182927657
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 -0.08469036
V3 0.38996748
V4 -0.16468436
V5 -0.22522976
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.09039490
V3 -0.24550107
V4 0.07260362
V5 0.04230210
*/
val coefficientsRStd = new DenseMatrix(3, 4, Array(
0.03904171, -0.23354322, 0.08288096, 0.2270639,
-0.2061848, 0.6341398, -0.1530059, -0.2958455,
0.16714312, -0.40059658, 0.07012496, 0.06878158), isTransposed = true)
val coefficientsR = new DenseMatrix(3, 4, Array(
-0.005704542, -0.144466409, 0.092080736, 0.182927657,
-0.08469036, 0.38996748, -0.16468436, -0.22522976,
0.0903949, -0.24550107, 0.07260362, 0.0423021), isTransposed = true)
assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
assert(model1.intercepts.toArray === Array.fill(3)(0.0))
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR absTol 0.01)
assert(model2.intercepts.toArray === Array.fill(3)(0.0))
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
test("multinomial logistic regression with intercept with elasticnet regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true)
.setMaxIter(300).setTol(1e-10)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false)
.setMaxIter(300).setTol(1e-10)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
lambda = 0.1, intercept=T, standardize=T))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
lambda = 0.1, intercept=T, standardize=F))
> coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.5521819483
V2 0.0003092611
V3 .
V4 .
V5 0.0913818490
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.27531989
V2 -0.09790029
V3 0.28502034
V4 -0.12416487
V5 -0.16513373
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.8275018
V2 .
V3 -0.4044859
V4 .
V5 .
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.39876213
V2 .
V3 .
V4 0.02547520
V5 0.03893991
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
0.61089869
V2 -0.04224269
V3 .
V4 -0.18923970
V5 -0.09104249
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
-0.2121366
V2 .
V3 .
V4 .
V5 .
*/
val coefficientsRStd = new DenseMatrix(3, 4, Array(
0.0003092611, 0.0, 0.0, 0.091381849,
-0.09790029, 0.28502034, -0.12416487, -0.16513373,
0.0, -0.4044859, 0.0, 0.0), isTransposed = true)
val interceptsRStd = Vectors.dense(-0.5521819483, -0.27531989, 0.8275018)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.0254752, 0.03893991,
-0.04224269, 0.0, -0.1892397, -0.09104249,
0.0, 0.0, 0.0, 0.0), isTransposed = true)
val interceptsR = Vectors.dense(-0.39876213, 0.61089869, -0.2121366)
assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
assert(model1.intercepts ~== interceptsRStd absTol 0.01)
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR absTol 0.01)
assert(model2.intercepts ~== interceptsR absTol 0.01)
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
test("multinomial logistic regression without intercept with elasticnet regularization") {
val trainer1 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true)
.setMaxIter(300).setTol(1e-10)
val trainer2 = (new MultinomialLogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.5).setRegParam(0.1).setStandardization(false)
.setMaxIter(300).setTol(1e-10)
val model1 = trainer1.fit(multinomialDataset)
val model2 = trainer2.fit(multinomialDataset)
/*
Use the following R code to load the data and train the model using the glmnet package.
library("glmnet")
data <- read.csv("path", header=FALSE)
label = as.factor(data$V1)
features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
coefficientsStd = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
lambda = 0.1, intercept=F, standardize=T))
coefficients = coef(glmnet(features, label, family="multinomial", alpha = 0.5,
lambda = 0.1, intercept=F, standardize=F))
> coefficientsStd
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 .
V4 .
V5 0.03543706
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 -0.1187387
V3 0.4025482
V4 -0.1270969
V5 -0.1918386
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 0.00774365
V3 .
V4 .
V5 .
> coefficients
$`0`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 .
V4 .
V5 .
$`1`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 0.14666497
V4 -0.16570638
V5 -0.05982875
$`2`
5 x 1 sparse Matrix of class "dgCMatrix"
s0
.
V2 .
V3 .
V4 .
V5 .
*/
val coefficientsRStd = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.0, 0.03543706,
-0.1187387, 0.4025482, -0.1270969, -0.1918386,
0.0, 0.0, 0.0, 0.00774365), isTransposed = true)
val coefficientsR = new DenseMatrix(3, 4, Array(
0.0, 0.0, 0.0, 0.0,
0.0, 0.14666497, -0.16570638, -0.05982875,
0.0, 0.0, 0.0, 0.0), isTransposed = true)
assert(model1.coefficients ~== coefficientsRStd absTol 0.01)
assert(model1.intercepts.toArray === Array.fill(3)(0.0))
assert(model1.intercepts.toArray.sum ~== 0.0 absTol eps)
assert(model2.coefficients ~== coefficientsR absTol 0.01)
assert(model2.intercepts.toArray === Array.fill(3)(0.0))
assert(model2.intercepts.toArray.sum ~== 0.0 absTol eps)
}
/*
test("multinomial logistic regression with intercept with strong L1 regularization") {
// TODO: implement this test to check that the priors on the intercepts are correct
// once setting an initial model becomes available
}
*/
test("prediction") {
val model = new MultinomialLogisticRegressionModel("mLogReg",
Matrices.dense(3, 2, Array(0.0, 0.0, 0.0, 1.0, 2.0, 3.0)),
Vectors.dense(0.0, 0.0, 0.0), 3)
val overFlowData = spark.createDataFrame(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1000.0)),
LabeledPoint(1.0, Vectors.dense(0.0, -1.0))
))
val results = model.transform(overFlowData).select("rawPrediction", "probability").collect()
// probabilities are correct when margins have to be adjusted
val raw1 = results(0).getAs[Vector](0)
val prob1 = results(0).getAs[Vector](1)
assert(raw1 === Vectors.dense(1000.0, 2000.0, 3000.0))
assert(prob1 ~== Vectors.dense(0.0, 0.0, 1.0) absTol eps)
// probabilities are correct when margins don't have to be adjusted
val raw2 = results(1).getAs[Vector](0)
val prob2 = results(1).getAs[Vector](1)
assert(raw2 === Vectors.dense(-1.0, -2.0, -3.0))
assert(prob2 ~== Vectors.dense(0.66524096, 0.24472847, 0.09003057) relTol eps)
}
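/*
  The overflow case above exercises the standard log-sum-exp stabilization: when the
  largest margin is positive, exp(margin) can overflow to Infinity, so the probabilities
  are computed from exp(margin - maxMargin) instead, which is mathematically identical.
  A minimal sketch of the stable computation (stableSoftmax is a hypothetical helper,
  not the model's internal method):

    def stableSoftmax(margins: Array[Double]): Array[Double] = {
      val maxMargin = margins.max
      // only shift when the max is positive, mirroring the check in the test below
      val shift = if (maxMargin > 0) maxMargin else 0.0
      val exps = margins.map(m => math.exp(m - shift))
      val total = exps.sum
      exps.map(_ / total)
    }

  e.g. stableSoftmax(Array(1000.0, 2000.0, 3000.0)) gives (0.0, 0.0, 1.0) instead of
  the NaNs that Infinity / Infinity would produce without the shift.
*/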
test("multinomial logistic regression: Predictor, Classifier methods") {
val mlr = new MultinomialLogisticRegression
val model = mlr.fit(dataset)
assert(model.numClasses === 3)
val numFeatures = dataset.select("features").first().getAs[Vector](0).size
assert(model.numFeatures === numFeatures)
val results = model.transform(dataset)
// check that raw prediction is coefficients dot features + intercept
results.select("rawPrediction", "features").collect().foreach {
case Row(raw: Vector, features: Vector) =>
assert(raw.size === 3)
val margins = Array.tabulate(3) { k =>
var margin = 0.0
features.foreachActive { (index, value) =>
margin += value * model.coefficients(k, index)
}
margin += model.intercepts(k)
margin
}
assert(raw ~== Vectors.dense(margins) relTol eps)
}
// Compare rawPrediction with probability
results.select("rawPrediction", "probability").collect().foreach {
case Row(raw: Vector, prob: Vector) =>
assert(raw.size === 3)
assert(prob.size === 3)
val max = raw.toArray.max
val subtract = if (max > 0) max else 0.0
val sum = raw.toArray.map(x => math.exp(x - subtract)).sum
val probFromRaw0 = math.exp(raw(0) - subtract) / sum
val probFromRaw1 = math.exp(raw(1) - subtract) / sum
assert(prob(0) ~== probFromRaw0 relTol eps)
assert(prob(1) ~== probFromRaw1 relTol eps)
assert(prob(2) ~== 1.0 - probFromRaw1 - probFromRaw0 relTol eps)
}
// Compare prediction with probability
results.select("prediction", "probability").collect().foreach {
case Row(pred: Double, prob: Vector) =>
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
assert(pred == predFromProb)
}
}
test("multinomial logistic regression coefficients should be centered") {
val mlr = new MultinomialLogisticRegression().setMaxIter(1)
val model = mlr.fit(dataset)
assert(model.intercepts.toArray.sum ~== 0.0 absTol 1e-6)
assert(model.coefficients.toArray.sum ~== 0.0 absTol 1e-6)
}
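/*
  A sketch of the centering this test exercises (assumed form, following the minimum-L2
  identifiability argument noted earlier): subtracting the per-column mean from the class
  coefficients yields the unique mean-centered representative, so both sums asserted
  above are zero by construction.

    def centerColumn(column: Array[Double]): Array[Double] = {
      val mean = column.sum / column.length
      column.map(_ - mean)
    }
*/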
test("numClasses specified in metadata/inferred") {
val mlr = new MultinomialLogisticRegression().setMaxIter(1)
// specify more classes than unique label values
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(4).toMetadata()
val df = dataset.select(dataset("label").as("label", labelMeta), dataset("features"))
val model1 = mlr.fit(df)
assert(model1.numClasses === 4)
assert(model1.intercepts.size === 4)
// specify two classes when there are really three
val labelMeta1 = NominalAttribute.defaultAttr.withName("label").withNumValues(2).toMetadata()
val df1 = dataset.select(dataset("label").as("label", labelMeta1), dataset("features"))
val thrown = intercept[IllegalArgumentException] {
mlr.fit(df1)
}
assert(thrown.getMessage.contains("less than the number of unique labels"))
// mlr should infer the number of classes if not specified
val model3 = mlr.fit(dataset)
assert(model3.numClasses === 3)
}
test("all labels the same") {
val constantData = spark.createDataFrame(Seq(
LabeledPoint(4.0, Vectors.dense(0.0)),
LabeledPoint(4.0, Vectors.dense(1.0)),
LabeledPoint(4.0, Vectors.dense(2.0)))
)
val mlr = new MultinomialLogisticRegression
val model = mlr.fit(constantData)
val results = model.transform(constantData)
results.select("rawPrediction", "probability", "prediction").collect().foreach {
case Row(raw: Vector, prob: Vector, pred: Double) =>
assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity)))
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
assert(pred === 4.0)
}
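// Note: with a single observed label the maximum-likelihood solution is degenerate
// (the observed class's margin grows without bound), hence the +Infinity raw margin
// and the exactly one-hot probability vector asserted above.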
// force the model to be trained with only one class
val constantZeroData = spark.createDataFrame(Seq(
LabeledPoint(0.0, Vectors.dense(0.0)),
LabeledPoint(0.0, Vectors.dense(1.0)),
LabeledPoint(0.0, Vectors.dense(2.0)))
)
val modelZeroLabel = mlr.setFitIntercept(false).fit(constantZeroData)
val resultsZero = modelZeroLabel.transform(constantZeroData)
resultsZero.select("rawPrediction", "probability", "prediction").collect().foreach {
case Row(raw: Vector, prob: Vector, pred: Double) =>
assert(prob === Vectors.dense(Array(1.0)))
assert(pred === 0.0)
}
// ensure that the correct value is predicted when numClasses is passed through metadata
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
val constantDataWithMetadata = constantData
.select(constantData("label").as("label", labelMeta), constantData("features"))
val modelWithMetadata = mlr.setFitIntercept(true).fit(constantDataWithMetadata)
val resultsWithMetadata = modelWithMetadata.transform(constantDataWithMetadata)
resultsWithMetadata.select("rawPrediction", "probability", "prediction").collect().foreach {
case Row(raw: Vector, prob: Vector, pred: Double) =>
assert(raw === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, Double.PositiveInfinity, 0.0)))
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))
assert(pred === 4.0)
}
// TODO: check num iters is zero when it becomes available in the model
}
test("weighted data") {
val numClasses = 5
val numPoints = 40
val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark,
numClasses, numPoints)
val testData = spark.createDataFrame(Array.tabulate[LabeledPoint](numClasses) { i =>
LabeledPoint(i.toDouble, Vectors.dense(i.toDouble))
})
val mlr = new MultinomialLogisticRegression().setWeightCol("weight")
val model = mlr.fit(outlierData)
val results = model.transform(testData).select("label", "prediction").collect()
// check that the model learned the one-to-one feature-to-label mapping
results.foreach { case Row(label: Double, pred: Double) =>
assert(label === pred)
}
val (overSampledData, weightedData) =
MLTestingUtils.genEquivalentOversampledAndWeightedInstances(outlierData, "label", "features",
42L)
val weightedModel = mlr.fit(weightedData)
val overSampledModel = mlr.setWeightCol("").fit(overSampledData)
assert(weightedModel.coefficients ~== overSampledModel.coefficients relTol 0.01)
}
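/*
  The equivalence exercised here: an instance with weight w contributes w * log P(y | x)
  to the weighted log-likelihood, which is exactly what w unweighted copies contribute,
  so the weighted and oversampled fits should agree up to optimization tolerance. A
  trivial numeric sketch of that identity:

    val weighted = 3.0 * math.log(0.7)
    val duplicated = math.log(0.7) + math.log(0.7) + math.log(0.7)
    assert(math.abs(weighted - duplicated) < 1e-12)
*/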
test("thresholds prediction") {
val mlr = new MultinomialLogisticRegression
val model = mlr.fit(dataset)
val basePredictions = model.transform(dataset).select("prediction").collect()
// should predict all zeros
model.setThresholds(Array(1, 1000, 1000))
val zeroPredictions = model.transform(dataset).select("prediction").collect()
assert(zeroPredictions.forall(_.getDouble(0) === 0.0))
// should predict all ones
model.setThresholds(Array(1000, 1, 1000))
val onePredictions = model.transform(dataset).select("prediction").collect()
assert(onePredictions.forall(_.getDouble(0) === 1.0))
// should predict all twos
model.setThresholds(Array(1000, 1000, 1))
val twoPredictions = model.transform(dataset).select("prediction").collect()
assert(twoPredictions.forall(_.getDouble(0) === 2.0))
// constant threshold scaling is the same as no thresholds
model.setThresholds(Array(1000, 1000, 1000))
val scaledPredictions = model.transform(dataset).select("prediction").collect()
assert(scaledPredictions.zip(basePredictions).forall { case (scaled, base) =>
scaled.getDouble(0) === base.getDouble(0)
})
}
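/*
  Thresholds act as per-class rescaling of the probabilities: the prediction is the
  argmax over k of probability(k) / thresholds(k), which is why scaling all thresholds
  by the same constant changes nothing. A minimal sketch of that rule (hypothetical
  helper, not the model's code path):

    def predictWithThresholds(probs: Array[Double], thresholds: Array[Double]): Int =
      probs.zip(thresholds).map { case (p, t) => p / t }.zipWithIndex.maxBy(_._1)._2

    // e.g. predictWithThresholds(Array(0.2, 0.5, 0.3), Array(1.0, 1000.0, 1000.0)) == 0
*/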
test("read/write") {
def checkModelData(
model: MultinomialLogisticRegressionModel,
model2: MultinomialLogisticRegressionModel): Unit = {
assert(model.intercepts === model2.intercepts)
assert(model.coefficients.toArray === model2.coefficients.toArray)
assert(model.numClasses === model2.numClasses)
assert(model.numFeatures === model2.numFeatures)
}
val mlr = new MultinomialLogisticRegression()
testEstimatorAndModelReadWrite(mlr, dataset,
MultinomialLogisticRegressionSuite.allParamSettings,
checkModelData)
}
test("should support all NumericType labels and not support other types") {
val mlr = new MultinomialLogisticRegression().setMaxIter(1)
MLTestingUtils
.checkNumericTypes[MultinomialLogisticRegressionModel, MultinomialLogisticRegression](
mlr, spark) { (expected, actual) =>
assert(expected.intercepts === actual.intercepts)
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
}
}
object MultinomialLogisticRegressionSuite {
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allParamSettings: Map[String, Any] = ProbabilisticClassifierSuite.allParamSettings ++ Map(
"probabilityCol" -> "myProbability",
"thresholds" -> Array(0.4, 0.6),
"regParam" -> 0.01,
"elasticNetParam" -> 0.1,
"maxIter" -> 2, // intentionally small
"fitIntercept" -> true,
"tol" -> 0.8,
"standardization" -> false
)
}
@@ -19,12 +19,14 @@ package org.apache.spark.ml.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -179,4 +181,47 @@ object MLTestingUtils extends SparkFunSuite {
.map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
.toMap
}
def genClassificationInstancesWithWeightedOutliers(
spark: SparkSession,
numClasses: Int,
numInstances: Int): DataFrame = {
val data = Array.tabulate[Instance](numInstances) { i =>
val feature = i % numClasses
if (i < numInstances / 3) {
// give large weights to the minority of points whose feature maps one-to-one to the label
Instance(feature, 1.0, Vectors.dense(feature))
} else {
// give small weights to the majority of points, whose feature-to-label mapping is reversed
Instance(numClasses - feature - 1, 0.01, Vectors.dense(feature))
}
}
val labelMeta =
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses).toMetadata()
spark.createDataFrame(data).select(col("label").as("label", labelMeta), col("weight"),
col("features"))
}
def genEquivalentOversampledAndWeightedInstances(
data: DataFrame,
labelCol: String,
featuresCol: String,
seed: Long): (DataFrame, DataFrame) = {
import data.sparkSession.implicits._
val rng = scala.util.Random
rng.setSeed(seed)
val sample: () => Int = () => rng.nextInt(10) + 1
val sampleUDF = udf(sample)
val rawData = data.select(labelCol, featuresCol).withColumn("samples", sampleUDF())
val overSampledData = rawData.rdd.flatMap {
case Row(label: Double, features: Vector, n: Int) =>
Iterator.fill(n)(Instance(label, 1.0, features))
}.toDF()
rng.setSeed(seed)
val weightedData = rawData.rdd.map {
case Row(label: Double, features: Vector, n: Int) =>
Instance(label, n.toDouble, features)
}.toDF()
(overSampledData, weightedData)
}
}