Commit 287bea13 authored by sethah, committed by DB Tsai

[SPARK-7159][ML] Add multiclass logistic regression to Spark ML

## What changes were proposed in this pull request?

This patch adds a new estimator/transformer `MultinomialLogisticRegression` to Spark ML.

JIRA: [SPARK-7159](https://issues.apache.org/jira/browse/SPARK-7159)

## How was this patch tested?

Added new test suite `MultinomialLogisticRegressionSuite`.

## Approach

### Do not use a "pivot" class in the algorithm formulation

Many implementations of multinomial logistic regression treat the problem as K - 1 independent binary logistic regression models, where K is the number of possible outcomes of the output variable. In this formulation, one outcome is chosen as a "pivot" and the other K - 1 outcomes are regressed against the pivot. This is somewhat undesirable since the returned coefficients will differ depending on which outcome is chosen as the pivot. An alternative approach models the class conditional probabilities using the softmax function and returns uniquely identifiable coefficients (assuming regularization is applied). This second approach is used in R's glmnet and was also recommended by dbtsai.
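
To make the formulation concrete, here is a minimal sketch (using Breeze, not the code in this patch) of how the softmax formulation turns one margin per class into class conditional probabilities; the max-margin subtraction mirrors the numerical-stability trick used in the implementation:

```scala
import breeze.linalg.{max, sum, DenseMatrix, DenseVector}
import breeze.numerics.exp

// beta is a K x N matrix with one row of coefficients per class; x is a feature vector.
def classProbabilities(beta: DenseMatrix[Double], x: DenseVector[Double]): DenseVector[Double] = {
  val margins = beta * x                 // margins(k) = x^T beta_k
  val stable = margins - max(margins)    // subtract the max margin to avoid overflow in exp
  val expMargins = exp(stable)
  expMargins / sum(expMargins)           // P(y = k | x) = e^{margins(k)} / sum_k' e^{margins(k')}
}
```

Because every class gets its own coefficient vector, the fitted model carries a matrix of coefficients rather than a single vector, which motivates the design discussion below.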

### Separate multinomial logistic regression and binary logistic regression

The initial design makes multinomial logistic regression a separate estimator/transformer from the existing `LogisticRegression` estimator/transformer. An alternative design would be to merge them into one.

**Arguments for:**

* The multinomial case without a pivot is distinctly different from the current binary case, since the binary case uses a pivot class.
* The current logistic regression model in ML uses a vector of coefficients and a scalar intercept. In the multinomial case, we require a matrix of coefficients and a vector of intercepts. There are potential workarounds for this issue if we were to merge the two estimators, but none are particularly elegant.

**Arguments against:**

* It may be inconvenient for users to have to switch estimator classes when transitioning between binary and multiclass problems (although the new multinomial estimator can be used for two-class outcomes).
* Some portions of the code are repeated.

This is a major design point and warrants more discussion.
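
For context, a hedged sketch of what the user-facing split would look like under the separate-estimator design (the DataFrames here are placeholders, and the setter names simply mirror the shared logistic regression params):

```scala
import org.apache.spark.ml.classification.{LogisticRegression, MultinomialLogisticRegression}
import org.apache.spark.sql.DataFrame

// Binary problems keep using the existing estimator; multiclass problems switch to the
// new one, which exposes a coefficient matrix and an intercept vector instead of a
// single coefficient vector and scalar intercept.
def fitBinary(training: DataFrame) =
  new LogisticRegression()
    .setRegParam(0.1)
    .setElasticNetParam(0.0)
    .fit(training)            // model.coefficients: Vector, model.intercept: Double

def fitMultinomial(training: DataFrame) =
  new MultinomialLogisticRegression()
    .setRegParam(0.1)
    .setElasticNetParam(0.0)
    .fit(training)            // model.coefficients: Matrix, model.intercepts: Vector
```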

### Mean centering

When no regularization is applied, the coefficients will not be uniquely identifiable. This is not hard to show and is discussed in further detail [here](https://core.ac.uk/download/files/153/6287975.pdf). R's glmnet deals with this by choosing the minimum L2-regularized solution (i.e., mean centering the coefficients). Additionally, the intercepts are never regularized, so they are always mean centered. This is the approach taken in this PR as well.
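
As a toy illustration (not the patch's code) of why mean centering pins down a unique solution: adding any constant to one feature's coefficients across all classes leaves the softmax probabilities unchanged, so among the equivalent solutions we keep the one whose per-feature coefficients have zero mean, which is also the minimum-L2 member of that equivalence class:

```scala
import breeze.linalg.DenseMatrix

// coefficients is K x N (one row per class). Subtracting the per-feature mean across
// classes does not change the predicted probabilities, but it selects the unique
// minimum-L2 member of the equivalence class.
def meanCenter(coefficients: DenseMatrix[Double]): DenseMatrix[Double] = {
  val numClasses = coefficients.rows
  val centered = coefficients.copy
  for (j <- 0 until coefficients.cols) {
    var mean = 0.0
    for (i <- 0 until numClasses) mean += coefficients(i, j)
    mean /= numClasses
    for (i <- 0 until numClasses) centered(i, j) -= mean
  }
  centered
}
```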

### Feature scaling

In the current ML logistic regression, the features are always standardized when running the optimization algorithm. The coefficients, however, are always returned to the user in the original feature space. This patch maintains the same approach, but the implementation details are different. In ML logistic regression, the unregularized feature values are divided by the column standard deviation in every gradient update iteration. In contrast, MLlib transforms the entire input dataset to the scaled space _before_ optimization. In ML, this means that `numFeatures * numClasses` extra scalar divisions are required in every iteration. Performance testing shows that this causes significant slowdowns (4x in some cases) in each iteration. This can be avoided by transforming the input to the scaled space once, as MLlib does, before iteration begins. This adds some overhead up front, but can yield significant time savings in some cases.

One issue with this approach is that if the input data is already cached, there may not be enough memory to cache the transformed data, which would make the algorithm _much_ slower. The tradeoffs here merit more discussion.
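
A rough sketch of the "transform once, before optimization" approach described above (illustrative only; names like `standardizeOnce` are not from the patch):

```scala
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Divide every feature by its column standard deviation once, up front, rather than
// repeating the division inside every gradient update iteration.
def standardizeOnce(instances: RDD[Instance], featuresStd: Array[Double]): RDD[Instance] = {
  val scaled = instances.map { case Instance(label, weight, features) =>
    val values = features.toArray.clone()   // defensive copy so the original vector is untouched
    var j = 0
    while (j < values.length) {
      values(j) = if (featuresStd(j) != 0.0) values(j) / featuresStd(j) else 0.0
      j += 1
    }
    Instance(label, weight, Vectors.dense(values))
  }
  // Cached for the duration of the L-BFGS/OWL-QN iterations; this second cached copy
  // is exactly the memory cost discussed above.
  scaled.persist(StorageLevel.MEMORY_AND_DISK)
}
```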

### Specifying and inferring the number of outcome classes

The estimator checks the DataFrame's label column for metadata that specifies the number of label values. If the metadata is not present, the length of the `histogram` variable is used, which is determined by the maximum label value found in the column. The assumption, then, is that the labels are zero-indexed when they are provided to the algorithm.
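
A simplified sketch of that lookup (the helper name here is hypothetical, and this is not the exact code in the patch):

```scala
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.sql.DataFrame

// Prefer the class count recorded in the label column's metadata; otherwise fall back
// to the label histogram computed in the first pass over the data.
def inferNumClasses(dataset: DataFrame, labelCol: String, histogram: Array[Double]): Int = {
  Attribute.fromStructField(dataset.schema(labelCol)) match {
    case nominal: NominalAttribute if nominal.getNumValues.isDefined =>
      nominal.getNumValues.get
    case _ =>
      // Labels are assumed to be zero-indexed {0, 1, ..., K - 1}, so the histogram
      // length (max label + 1) is taken as the number of classes.
      histogram.length
  }
}
```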

## Performance

Below are some performance tests I have run so far. I am happy to add more cases or trials if we deem them necessary.

Test cluster: 4 bare metal nodes, 128 GB RAM each, 48 cores each

Notes:

* Time in units of seconds
* Metric is classification accuracy

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 0 | true           | 0.746415 |        30 |      100000 |            3 |        100000 | 327.923 | true              |          0 |
| mllib  |                 0 | true           | 0.743785 |        30 |      100000 |            3 |        100000 | 390.217 | true              |          0 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 0 | true           | 0.973238 |        30 |     2000000 |            3 |         10000 | 385.476 | true              |          0 |
| mllib  |                 0 | true           | 0.949828 |        30 |     2000000 |            3 |         10000 | 550.403 | true              |          0 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| mllib  |                 0 | true           | 0.864358 |        30 |     2000000 |            3 |         10000 | 543.359 | true              |        0.1 |
| ml     |                 0 | true           | 0.867418 |        30 |     2000000 |            3 |         10000 | 401.955 | true              |        0.1 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 1 | true           | 0.807449 |        30 |     2000000 |            3 |         10000 | 334.892 | true              |       0.05 |

| algo   |   elasticNetParam | fitIntercept   |   metric |   maxIter |   numPoints |   numClasses |   numFeatures |    time | standardization   |   regParam |
|--------|-------------------|----------------|----------|-----------|-------------|--------------|---------------|---------|-------------------|------------|
| ml     |                 0 | true           | 0.602006 |        30 |     2000000 |          500 |           100 | 112.319 | true              |          0 |
| mllib  |                 0 | true           | 0.567226 |        30 |     2000000 |          500 |           100 | 263.768 | true              |          0 |

## References

Friedman, et al. ["Regularization Paths for Generalized Linear Models via Coordinate Descent"](https://core.ac.uk/download/files/153/6287975.pdf)
[http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html](http://web.stanford.edu/~hastie/glmnet/glmnet_alpha.html)

## Follow up items
* Consider using level 2 BLAS routines in the gradient computations - [SPARK-17134](https://issues.apache.org/jira/browse/SPARK-17134)
* Add model summary for MLOR - [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139)
* Add initial model to MLOR and add test for intercept priors - [SPARK-17140](https://issues.apache.org/jira/browse/SPARK-17140)
* Python API - [SPARK-17138](https://issues.apache.org/jira/browse/SPARK-17138)
* Consider changing the tree aggregation level for MLOR/BLOR or making it user configurable to avoid memory problems with high dimensional data - [SPARK-17090](https://issues.apache.org/jira/browse/SPARK-17090)
* Refactor helper classes out of `LogisticRegression.scala` - [SPARK-17135](https://issues.apache.org/jira/browse/SPARK-17135)
* Design optimizer interface for added flexibility in ML algos - [SPARK-17136](https://issues.apache.org/jira/browse/SPARK-17136)
* Support compressing the coefficients and intercepts for MLOR models - [SPARK-17137](https://issues.apache.org/jira/browse/SPARK-17137)

Author: sethah <seth.hendrickson16@gmail.com>

Closes #13796 from sethah/SPARK-7159_M.
@@ -63,6 +63,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
* equivalent.
*
* Default is 0.5.
+ *
* @group setParam
*/
def setThreshold(value: Double): this.type = {
@@ -131,6 +132,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* If [[threshold]] and [[thresholds]] are both set, ensures they are consistent.
+ *
* @throws IllegalArgumentException if [[threshold]] and [[thresholds]] are not equivalent
*/
protected def checkThresholdConsistency(): Unit = {
@@ -153,8 +155,8 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
/**
* Logistic regression.
- * Currently, this class only supports binary classification. It will support multiclass
- * in the future.
+ * Currently, this class only supports binary classification. For multiclass classification,
+ * use [[MultinomialLogisticRegression]]
*/
@Since("1.2.0")
class LogisticRegression @Since("1.2.0") (
@@ -168,6 +170,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Set the regularization parameter.
* Default is 0.0.
+ *
* @group setParam
*/
@Since("1.2.0")
@@ -179,6 +182,7 @@ class LogisticRegression @Since("1.2.0") (
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
* For 0 < alpha < 1, the penalty is a combination of L1 and L2.
* Default is 0.0 which is an L2 penalty.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -188,6 +192,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Set the maximum number of iterations.
* Default is 100.
+ *
* @group setParam
*/
@Since("1.2.0")
@@ -198,6 +203,7 @@ class LogisticRegression @Since("1.2.0") (
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-6.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -207,6 +213,7 @@ class LogisticRegression @Since("1.2.0") (
/**
* Whether to fit an intercept term.
* Default is true.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -220,6 +227,7 @@ class LogisticRegression @Since("1.2.0") (
* the models should be always converged to the same solution when no regularization
* is applied. In R's GLMNET package, the default behavior is true as well.
* Default is true.
+ *
* @group setParam
*/
@Since("1.5.0")
@@ -233,9 +241,10 @@ class LogisticRegression @Since("1.2.0") (
override def getThreshold: Double = super.getThreshold
/**
- * Whether to over-/under-sample training instances according to the given weights in weightCol.
- * If not set or empty String, all instances are treated equally (weight 1.0).
+ * Sets the value of param [[weightCol]].
+ * If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
+ *
* @group setParam
*/
@Since("1.6.0")
@@ -310,12 +319,15 @@ class LogisticRegression @Since("1.2.0") (
throw new SparkException(msg)
}
+ val isConstantLabel = histogram.count(_ != 0) == 1
if (numClasses > 2) {
- val msg = s"Currently, LogisticRegression with ElasticNet in ML package only supports " +
- s"binary classification. Found $numClasses in the input dataset."
+ val msg = s"LogisticRegression with ElasticNet in ML package only supports " +
+ s"binary classification. Found $numClasses in the input dataset. Consider using " +
+ s"MultinomialLogisticRegression instead."
logError(msg)
throw new SparkException(msg)
- } else if ($(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
+ } else if ($(fitIntercept) && numClasses == 2 && isConstantLabel) {
logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " +
s"zeros and the intercept will be positive infinity; as a result, " +
s"training is not needed.")
@@ -326,12 +338,9 @@ class LogisticRegression @Since("1.2.0") (
s"training is not needed.")
(Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
} else {
- if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 0.0) {
- logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " +
- s"so the algorithm may not converge.")
- } else if (!$(fitIntercept) && numClasses == 1) {
- logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " +
- s"so the algorithm may not converge.")
+ if (!$(fitIntercept) && isConstantLabel) {
+ logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
+ s"dangerous ground, so the algorithm may not converge.")
}
val featuresMean = summarizer.mean.toArray
@@ -349,7 +358,7 @@ class LogisticRegression @Since("1.2.0") (
val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
- $(standardization), bcFeaturesStd, regParamL2)
+ $(standardization), bcFeaturesStd, regParamL2, multinomial = false)
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -416,7 +425,7 @@ class LogisticRegression @Since("1.2.0") (
/*
Note that in Logistic Regression, the objective history (loss + regularization)
- is log-likelihood which is invariance under feature standardization. As a result,
+ is log-likelihood which is invariant under feature standardization. As a result,
the objective history from optimizer is the same as the one in the original space.
*/
val arrayBuilder = mutable.ArrayBuilder.make[Double]
@@ -559,6 +568,7 @@ class LogisticRegressionModel private[spark] (
/**
* Evaluates the model on a test dataset.
+ *
* @param dataset Test dataset to evaluate model on.
*/
@Since("2.0.0")
@@ -681,6 +691,7 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
val data = sparkSession.read.format("parquet").load(dataPath)
// We will need numClasses, numFeatures in the future for multinomial logreg support.
+ // TODO: remove numClasses and numFeatures fields?
val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
MLUtils.convertVectorColumnsToML(data, "coefficients")
.select("numClasses", "numFeatures", "intercept", "coefficients")
@@ -710,6 +721,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
/**
* Add a new label into this MultilabelSummarizer, and update the distinct map.
+ *
* @param label The label for this data point.
* @param weight The weight of this instances.
* @return This MultilabelSummarizer
@@ -933,32 +945,310 @@ class BinaryLogisticRegressionSummary private[classification] (
}
/**
- * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
- * in binary classification for instances in sparse or dense vector in an online fashion.
- *
- * Note that multinomial logistic loss is not supported yet!
+ * LogisticAggregator computes the gradient and loss for binary or multinomial logistic (softmax)
+ * loss function, as used in classification for instances in sparse or dense vector in an online
+ * fashion.
*
- * Two LogisticAggregator can be merged together to have a summary of loss and gradient of
+ * Two LogisticAggregators can be merged together to have a summary of loss and gradient of
* the corresponding joint dataset.
*
* For improving the convergence rate during the optimization process and also to prevent against
* features with very large variances exerting an overly large influence during model training,
* packages like R's GLMNET perform the scaling to unit variance and remove the mean in order to
* reduce the condition number. The model is then trained in this scaled space, but returns the
* coefficients in the original scale. See page 9 in
* http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
*
* However, we don't want to apply the [[org.apache.spark.ml.feature.StandardScaler]] on the
* training dataset, and then cache the standardized dataset since it will create a lot of overhead.
* As a result, we perform the scaling implicitly when we compute the objective function (though
* we do not subtract the mean).
*
* Note that there is a difference between multinomial (softmax) and binary loss. The binary case
* uses one outcome class as a "pivot" and regresses the other class against the pivot. In the
* multinomial case, the softmax loss function is used to model each class probability
* independently. Using softmax loss produces `K` sets of coefficients, while using a pivot class
* produces `K - 1` sets of coefficients (a single coefficient vector in the binary case). In the
* binary case, we can say that the coefficients are shared between the positive and negative
* classes. When regularization is applied, multinomial (softmax) loss will produce a result
* different from binary loss since the positive and negative don't share the coefficients while the
* binary regression shares the coefficients between positive and negative.
*
* The following is a mathematical derivation for the multinomial (softmax) loss.
*
* The probability of the multinomial outcome $y$ taking on any of the K possible outcomes is:
*
* <p><blockquote>
* $$
* P(y_i=0|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \vec{\beta}_k}} \\
* P(y_i=1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_1}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \vec{\beta}_k}}\\
* P(y_i=K-1|\vec{x}_i, \beta) = \frac{e^{\vec{x}_i^T \vec{\beta}_{K-1}}\,}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \vec{\beta}_k}}
* $$
* </blockquote></p>
*
* The model coefficients $\beta = (\beta_0, \beta_1, \beta_2, ..., \beta_{K-1})$ become a matrix
* which has dimension of $K \times (N+1)$ if the intercepts are added. If the intercepts are not
* added, the dimension will be $K \times N$.
*
* Note that the coefficients in the model above lack identifiability. That is, any constant scalar
* can be added to all of the coefficients and the probabilities remain the same.
*
* <p><blockquote>
* $$
* \begin{align}
* \frac{e^{\vec{x}_i^T \left(\vec{\beta}_0 + \vec{c}\right)}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i^T \left(\vec{\beta}_k + \vec{c}\right)}}
* = \frac{e^{\vec{x}_i^T \vec{\beta}_0}e^{\vec{x}_i^T \vec{c}}\,}{e^{\vec{x}_i^T \vec{c}}
* \sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
* = \frac{e^{\vec{x}_i^T \vec{\beta}_0}}{\sum_{k=0}^{K-1} e^{\vec{x}_i^T \vec{\beta}_k}}
* \end{align}
* $$
* </blockquote></p>
*
* However, when regularization is added to the loss function, the coefficients are indeed
* identifiable because there is only one set of coefficients which minimizes the regularization
* term. When no regularization is applied, we choose the coefficients with the minimum L2
* penalty for consistency and reproducibility. For further discussion see:
*
* Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent"
*
* The loss of objective function for a single instance of data (we do not include the
* regularization term here for simplicity) can be written as
*
* <p><blockquote>
* $$
* \begin{align}
* \ell\left(\beta, x_i\right) &= -log{P\left(y_i \middle| \vec{x}_i, \beta\right)} \\
* &= log\left(\sum_{k=0}^{K-1}e^{\vec{x}_i^T \vec{\beta}_k}\right) - \vec{x}_i^T \vec{\beta}_y\\
* &= log\left(\sum_{k=0}^{K-1} e^{margins_k}\right) - margins_y
* \end{align}
* $$
* </blockquote></p>
*
* where ${margins}_k = \vec{x}_i^T \vec{\beta}_k$.
*
* For optimization, we have to calculate the first derivative of the loss function, and a simple
* calculation shows that
*
* <p><blockquote>
* $$
* \begin{align}
* \frac{\partial \ell(\beta, \vec{x}_i, w_i)}{\partial \beta_{j, k}}
* &= x_{i,j} \cdot w_i \cdot \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k'=0}^{K-1}
* e^{\vec{x}_i \cdot \vec{\beta}_{k'}}\,} - I_{y=k}\right) \\
* &= x_{i, j} \cdot w_i \cdot multiplier_k
* \end{align}
* $$
* </blockquote></p>
*
* where $w_i$ is the sample weight, $I_{y=k}$ is an indicator function
*
* <p><blockquote>
* $$
* I_{y=k} = \begin{cases}
* 1 & y = k \\
* 0 & else
* \end{cases}
* $$
* </blockquote></p>
*
* and
*
* <p><blockquote>
* $$
* multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k}}{\sum_{k=0}^{K-1}
* e^{\vec{x}_i \cdot \vec{\beta}_k}} - I_{y=k}\right)
* $$
* </blockquote></p>
*
* If any of margins is larger than 709.78, the numerical computation of multiplier and loss
* function will suffer from arithmetic overflow. This issue occurs when there are outliers in
* data which are far away from the hyperplane, and this will cause the failing of training once
* infinity is introduced. Note that this is only a concern when max(margins) > 0.
*
* Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can easily
* be rewritten into the following equivalent numerically stable formula.
*
* <p><blockquote>
* $$
* \ell\left(\beta, x\right) = log\left(\sum_{k=0}^{K-1} e^{margins_k - maxMargin}\right) -
* margins_{y} + maxMargin
* $$
* </blockquote></p>
*
* Note that each term, $(margins_k - maxMargin)$ in the exponential is no greater than zero; as a
* result, overflow will not happen with this formula.
*
* For $multiplier$, a similar trick can be applied as the following,
*
* <p><blockquote>
* $$
* multiplier_k = \left(\frac{e^{\vec{x}_i \cdot \vec{\beta}_k - maxMargin}}{\sum_{k'=0}^{K-1}
* e^{\vec{x}_i \cdot \vec{\beta}_{k'} - maxMargin}} - I_{y=k}\right)
* $$
* </blockquote></p>
*
* @param bcCoefficients The broadcast coefficients corresponding to the features.
* @param bcFeaturesStd The broadcast standard deviation values of the features.
* @param numClasses the number of possible outcomes for k classes classification problem in
* Multinomial Logistic Regression.
* @param fitIntercept Whether to fit an intercept term.
+ * @param multinomial Whether to use multinomial (softmax) or binary loss
*/
private class LogisticAggregator(
- val bcCoefficients: Broadcast[Vector],
- val bcFeaturesStd: Broadcast[Array[Double]],
- private val numFeatures: Int,
+ bcCoefficients: Broadcast[Vector],
+ bcFeaturesStd: Broadcast[Array[Double]],
numClasses: Int,
- fitIntercept: Boolean) extends Serializable {
+ fitIntercept: Boolean,
+ multinomial: Boolean) extends Serializable with Logging {
+ private val numFeatures = bcFeaturesStd.value.length
+ private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
+ private val coefficientSize = bcCoefficients.value.size
+ if (multinomial) {
+ require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " +
+ s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize")
+ } else {
+ require(coefficientSize == numFeaturesPlusIntercept, s"Expected $numFeaturesPlusIntercept " +
+ s"coefficients but got $coefficientSize")
+ require(numClasses == 1 || numClasses == 2, s"Binary logistic aggregator requires numClasses " +
+ s"in {1, 2} but found $numClasses.")
+ }
private var weightSum = 0.0
private var lossSum = 0.0
- private val gradientSumArray =
- Array.ofDim[Double](if (fitIntercept) numFeatures + 1 else numFeatures)
+ private val gradientSumArray = Array.ofDim[Double](coefficientSize)
if (multinomial && numClasses <= 2) {
logInfo(s"Multinomial logistic regression for binary classification yields separate " +
s"coefficients for positive and negative classes. When no regularization is applied, the" +
s"result will be effectively the same as binary logistic regression. When regularization" +
s"is applied, multinomial loss will produce a result different from binary loss.")
}
/** Update gradient and loss using binary loss function. */
private def binaryUpdateInPlace(
features: Vector,
weight: Double,
label: Double): Unit = {
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = bcCoefficients.value
val localGradientArray = gradientSumArray
val margin = - {
var sum = 0.0
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
sum += localCoefficients(index) * value / localFeaturesStd(index)
}
}
if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
sum
}
val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientArray(index) += multiplier * value / localFeaturesStd(index)
}
}
if (fitIntercept) {
localGradientArray(numFeaturesPlusIntercept - 1) += multiplier
}
if (label > 0) {
// The following is equivalent to log(1 + exp(margin)) but more numerically stable.
lossSum += weight * MLUtils.log1pExp(margin)
} else {
lossSum += weight * (MLUtils.log1pExp(margin) - margin)
}
}
/** Update gradient and loss using multinomial (softmax) loss function. */
private def multinomialUpdateInPlace(
features: Vector,
weight: Double,
label: Double): Unit = {
// TODO: use level 2 BLAS operations
/*
Note: this can still be used when numClasses = 2 for binary
logistic regression without pivoting.
*/
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = bcCoefficients.value
val localGradientArray = gradientSumArray
// marginOfLabel is margins(label) in the formula
var marginOfLabel = 0.0
var maxMargin = Double.NegativeInfinity
val margins = Array.tabulate(numClasses) { i =>
var margin = 0.0
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
margin += localCoefficients(i * numFeaturesPlusIntercept + index) *
value / localFeaturesStd(index)
}
}
if (fitIntercept) {
margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures)
}
if (i == label.toInt) marginOfLabel = margin
if (margin > maxMargin) {
maxMargin = margin
}
margin
}
/**
* When maxMargin > 0, the original formula could cause overflow.
* We address this by subtracting maxMargin from all the margins, so it's guaranteed
* that all of the new margins will be smaller than zero to prevent arithmetic overflow.
*/
val sum = {
var temp = 0.0
if (maxMargin > 0) {
for (i <- 0 until numClasses) {
margins(i) -= maxMargin
temp += math.exp(margins(i))
}
} else {
for (i <- 0 until numClasses) {
temp += math.exp(margins(i))
}
}
temp
}
for (i <- 0 until numClasses) {
val multiplier = math.exp(margins(i)) / sum - {
if (label == i) 1.0 else 0.0
}
features.foreachActive { (index, value) =>
if (localFeaturesStd(index) != 0.0 && value != 0.0) {
localGradientArray(i * numFeaturesPlusIntercept + index) +=
weight * multiplier * value / localFeaturesStd(index)
}
}
if (fitIntercept) {
localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier
}
}
val loss = if (maxMargin > 0) {
math.log(sum) - marginOfLabel + maxMargin
} else {
math.log(sum) - marginOfLabel
}
lossSum += weight * loss
}
/**
* Add a new training instance to this LogisticAggregator, and update the loss and gradient
@@ -975,52 +1265,10 @@ private class LogisticAggregator(
if (weight == 0.0) return this
- val coefficientsArray = bcCoefficients.value match {
- case dv: DenseVector => dv.values
- case _ =>
- throw new IllegalArgumentException(
- "coefficients only supports dense vector" +
- s"but got type ${bcCoefficients.value.getClass}.")
- }
- val localGradientSumArray = gradientSumArray
- val featuresStd = bcFeaturesStd.value
- numClasses match {
- case 2 =>
- // For Binary Logistic Regression.
- val margin = - {
- var sum = 0.0
- features.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- sum += coefficientsArray(index) * (value / featuresStd(index))
- }
- }
- sum + {
- if (fitIntercept) coefficientsArray(numFeatures) else 0.0
- }
- }
- val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
- features.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- localGradientSumArray(index) += multiplier * (value / featuresStd(index))
- }
- }
- if (fitIntercept) {
- localGradientSumArray(numFeatures) += multiplier
- }
- if (label > 0) {
- // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
- lossSum += weight * MLUtils.log1pExp(margin)
- } else {
- lossSum += weight * (MLUtils.log1pExp(margin) - margin)
- }
- case _ =>
- new NotImplementedError("LogisticRegression with ElasticNet in ML package " +
- "only supports binary classification for now.")
+ if (multinomial) {
+ multinomialUpdateInPlace(features, weight, label)
+ } else {
+ binaryUpdateInPlace(features, weight, label)
}
weightSum += weight
this
@@ -1071,8 +1319,8 @@ private class LogisticAggregator(
}
/**
- * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function,
- * as used in multi-class classification (it is also used in binary logistic regression).
+ * LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial (softmax) logistic loss
+ * function, as used in multi-class classification (it is also used in binary logistic regression).
* It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
@@ -1082,36 +1330,36 @@ private class LogisticCostFun(
fitIntercept: Boolean,
standardization: Boolean,
bcFeaturesStd: Broadcast[Array[Double]],
- regParamL2: Double) extends DiffFunction[BDV[Double]] {
+ regParamL2: Double,
+ multinomial: Boolean) extends DiffFunction[BDV[Double]] {
- val featuresStd = bcFeaturesStd.value
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
- val numFeatures = featuresStd.length
val coeffs = Vectors.fromBreeze(coefficients)
val bcCoeffs = instances.context.broadcast(coeffs)
- val n = coeffs.size
+ val featuresStd = bcFeaturesStd.value
+ val numFeatures = featuresStd.length
val logisticAggregator = {
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
instances.treeAggregate(
- new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
+ new LogisticAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept,
+ multinomial)
)(seqOp, combOp)
}
val totalGradientArray = logisticAggregator.gradient.toArray
// regVal is the sum of coefficients squares excluding intercept for L2 regularization.
val regVal = if (regParamL2 == 0.0) {
0.0
} else {
var sum = 0.0
- coeffs.foreachActive { (index, value) =>
- // If `fitIntercept` is true, the last term which is intercept doesn't
- // contribute to the regularization.
- if (index != numFeatures) {
+ coeffs.foreachActive { case (index, value) =>
+ // We do not apply regularization to the intercepts
+ val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0)
+ if (!isIntercept) {
// The following code will compute the loss of the regularization; also
// the gradient of the regularization, and add back to totalGradientArray.
sum += {
@@ -1119,13 +1367,18 @@ private class LogisticCostFun(
totalGradientArray(index) += regParamL2 * value
value * value
} else {
- if (featuresStd(index) != 0.0) {
+ val featureIndex = if (fitIntercept) {
+ index % (numFeatures + 1)
+ } else {
+ index % numFeatures
+ }
+ if (featuresStd(featureIndex) != 0.0) {
// If `standardization` is false, we still standardize the data
// to improve the rate of convergence; as a result, we have to
// perform this reverse standardization by penalizing each component
// differently to get effectively the same objective function when
// the training dataset is not standardized.
- val temp = value / (featuresStd(index) * featuresStd(index))
+ val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))
totalGradientArray(index) += regParamL2 * temp
value * temp
} else {
...
@@ -19,12 +19,14 @@ package org.apache.spark.ml.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
+ import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.evaluation.Evaluator
- import org.apache.spark.ml.linalg.Vectors
+ import org.apache.spark.ml.feature.Instance
+ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.recommendation.{ALS, ALSModel}
import org.apache.spark.ml.tree.impl.TreeTests
- import org.apache.spark.sql.{DataFrame, SparkSession}
+ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -179,4 +181,47 @@ object MLTestingUtils extends SparkFunSuite {
.map(t => t -> df.select(col(labelColName).cast(t), col(predictionColName)))
.toMap
}
def genClassificationInstancesWithWeightedOutliers(
spark: SparkSession,
numClasses: Int,
numInstances: Int): DataFrame = {
val data = Array.tabulate[Instance](numInstances) { i =>
val feature = i % numClasses
if (i < numInstances / 3) {
// give large weights to minority of data with 1 to 1 mapping feature to label
Instance(feature, 1.0, Vectors.dense(feature))
} else {
// give small weights to majority of data points with reverse mapping
Instance(numClasses - feature - 1, 0.01, Vectors.dense(feature))
}
}
val labelMeta =
NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses).toMetadata()
spark.createDataFrame(data).select(col("label").as("label", labelMeta), col("weight"),
col("features"))
}
def genEquivalentOversampledAndWeightedInstances(
data: DataFrame,
labelCol: String,
featuresCol: String,
seed: Long): (DataFrame, DataFrame) = {
import data.sparkSession.implicits._
val rng = scala.util.Random
rng.setSeed(seed)
val sample: () => Int = () => rng.nextInt(10) + 1
val sampleUDF = udf(sample)
val rawData = data.select(labelCol, featuresCol).withColumn("samples", sampleUDF())
val overSampledData = rawData.rdd.flatMap {
case Row(label: Double, features: Vector, n: Int) =>
Iterator.fill(n)(Instance(label, 1.0, features))
}.toDF()
rng.setSeed(seed)
val weightedData = rawData.rdd.map {
case Row(label: Double, features: Vector, n: Int) =>
Instance(label, n.toDouble, features)
}.toDF()
(overSampledData, weightedData)
}
}