Commit dd36ec6b authored by DB Tsai, committed by Xiangrui Meng

[SPARK-10738] [ML] Refactoring `Instance` out from LOR and LIR, and also cleaning up some code

Refactoring `Instance` case class out from LOR and LIR, and also cleaning up some code.

Author: DB Tsai <dbt@netflix.com>

Closes #8853 from dbtsai/refactoring.
parent 7e2e2682
......@@ -24,6 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
......@@ -146,17 +147,6 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
}
/**
* Class that represents an instance of weighted data point with label and features.
*
* TODO: Refactor this class to proper place.
*
* @param label Label for this data point.
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
private[classification] case class Instance(label: Double, weight: Double, features: Vector)
/**
* :: Experimental ::
* Logistic regression.
......@@ -322,7 +312,7 @@ class LogisticRegression(override val uid: String)
if ($(fitIntercept)) {
/*
For binary logistic regression, when we initialize the weights as zeros,
For binary logistic regression, when we initialize the coefficients as zeros,
it will converge faster if we initialize the intercept such that
it follows the distribution of the labels.
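The hunk cuts off before the assignment itself; a minimal sketch of the idea, assuming histogram(0) and histogram(1) hold the counts of negative and positive labels (illustrative names, not necessarily those used in the source):

  // Start b at the log-odds of the label distribution; the coefficients stay at zero.
  val initialIntercept = math.log(histogram(1) / histogram(0))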
......@@ -757,62 +747,63 @@ private class LogisticAggregator(
private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length)
  /**
   * Add a new training data to this LogisticAggregator, and update the loss and gradient
   * Add a new training instance to this LogisticAggregator, and update the loss and gradient
   * of the objective function.
   *
   * @param label The label for this data point.
   * @param data The features for one data point in dense/sparse vector format to be added
   *             into this aggregator.
   * @param weight The weight for over-/undersamples each of training instance. Default is one.
   * @param instance The instance of data point to be added.
   * @return This LogisticAggregator object.
   */
  def add(label: Double, data: Vector, weight: Double = 1.0): this.type = {
    require(dim == data.size, s"Dimensions mismatch when adding new instance." +
      s" Expecting $dim but got ${data.size}.")
    require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
    if (weight == 0.0) return this
  def add(instance: Instance): this.type = {
    instance match { case Instance(label, weight, features) =>
      require(dim == features.size, s"Dimensions mismatch when adding new instance." +
        s" Expecting $dim but got ${features.size}.")
      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")

      if (weight == 0.0) return this

      val localCoefficientsArray = coefficientsArray
      val localGradientSumArray = gradientSumArray

      numClasses match {
        case 2 =>
          // For Binary Logistic Regression.
          val margin = - {
            var sum = 0.0
            features.foreachActive { (index, value) =>
              if (featuresStd(index) != 0.0 && value != 0.0) {
                sum += localCoefficientsArray(index) * (value / featuresStd(index))
              }
            }
            sum + {
              if (fitIntercept) localCoefficientsArray(dim) else 0.0
            }
          }

          val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)

          features.foreachActive { (index, value) =>
            if (featuresStd(index) != 0.0 && value != 0.0) {
              localGradientSumArray(index) += multiplier * (value / featuresStd(index))
            }
          }

          if (fitIntercept) {
            localGradientSumArray(dim) += multiplier
          }

          if (label > 0) {
            // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
            lossSum += weight * MLUtils.log1pExp(margin)
          } else {
            lossSum += weight * (MLUtils.log1pExp(margin) - margin)
          }
        case _ =>
          new NotImplementedError("LogisticRegression with ElasticNet in ML package " +
            "only supports binary classification for now.")
      }
      weightSum += weight
      this
    }
  }
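For reference, the quantities computed above implement the weighted binary logistic loss on standardized features; restated compactly (a summary of the code above, in the \hat{x_i} notation used elsewhere in this file):

  margin = -(\sum_i w_i * x_i / \hat{x_i} + b)
  P(y = 1 | x) = 1 / (1 + \exp(margin))
  lossSum += weight * (\log(1 + \exp(margin)) - (1 - label) * margin)
  gradientSum_i += weight * (P(y = 1 | x) - label) * x_i / \hat{x_i}

where b is the intercept stored at index dim, and weight * (P(y = 1 | x) - label) is exactly the multiplier computed above.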
/**
......@@ -861,11 +852,11 @@ private class LogisticAggregator(
/**
* LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function,
* as used in multi-class classification (it is also used in binary logistic regression).
* It returns the loss and gradient with L2 regularization at a particular point (weights).
* It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
private class LogisticCostFun(
data: RDD[Instance],
instances: RDD[Instance],
numClasses: Int,
fitIntercept: Boolean,
standardization: Boolean,
......@@ -875,15 +866,14 @@ private class LogisticCostFun(
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val numFeatures = featuresStd.length
val w = Vectors.fromBreeze(coefficients)
val coeffs = Vectors.fromBreeze(coefficients)
val logisticAggregator = {
val seqOp = (c: LogisticAggregator, instance: Instance) =>
c.add(instance.label, instance.features, instance.weight)
val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
data.treeAggregate(
new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean)
instances.treeAggregate(
new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean)
)(seqOp, combOp)
}
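The seqOp/combOp pair is the standard treeAggregate pattern: each task folds its instances into a partition-local aggregator, and the partial aggregators are then merged pairwise up a tree. A self-contained toy sketch of the same pattern (SumAggregator and all names below are made up for illustration; this is not code from the patch):

  import org.apache.spark.{SparkConf, SparkContext}

  // Toy stand-in for LogisticAggregator: it only accumulates a weighted label sum,
  // but it is fed through seqOp and combined through combOp in exactly the same way.
  class SumAggregator extends Serializable {
    var weightSum = 0.0
    var labelSum = 0.0
    def add(label: Double, weight: Double): this.type = {
      labelSum += weight * label
      weightSum += weight
      this
    }
    def merge(other: SumAggregator): this.type = {
      labelSum += other.labelSum
      weightSum += other.weightSum
      this
    }
  }

  object TreeAggregateSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
      val data = sc.parallelize(Seq((1.0, 1.0), (0.0, 2.0), (1.0, 3.0)))  // (label, weight) pairs
      val agg = data.treeAggregate(new SumAggregator)(
        seqOp = (c, point) => c.add(point._1, point._2),  // fold one record into the local aggregator
        combOp = (c1, c2) => c1.merge(c2))                // merge partial aggregators across partitions
      println(s"weighted mean label = ${agg.labelSum / agg.weightSum}")
      sc.stop()
    }
  }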
......@@ -894,7 +884,7 @@ private class LogisticCostFun(
0.0
} else {
var sum = 0.0
w.foreachActive { (index, value) =>
coeffs.foreachActive { (index, value) =>
// If `fitIntercept` is true, the last term which is intercept doesn't
// contribute to the regularization.
if (index != numFeatures) {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.feature
import org.apache.spark.mllib.linalg.Vector
/**
* Class that represents an instance of weighted data point with label and features.
*
* @param label Label for this data point.
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
private[ml] case class Instance(label: Double, weight: Double, features: Vector)
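Since the class is private[ml], it is only constructed inside Spark's own ML code; roughly, both estimators build an RDD[Instance] from the input DataFrame along these lines (a paraphrase under assumed column names, not an exact excerpt from the patch):

  package org.apache.spark.ml.example  // must live under org.apache.spark.ml to see the private[ml] class

  import org.apache.spark.ml.feature.Instance
  import org.apache.spark.mllib.linalg.Vector
  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.{DataFrame, Row}
  import org.apache.spark.sql.functions.{col, lit}

  object InstanceExtraction {
    // "label", "weight" and "features" stand in for the estimator's labelCol/weightCol/featuresCol.
    def extractInstances(dataset: DataFrame): RDD[Instance] = {
      val w = if (dataset.columns.contains("weight")) col("weight") else lit(1.0)
      dataset.select(col("label"), w, col("features")).map {
        case Row(label: Double, weight: Double, features: Vector) =>
          Instance(label, weight, features)
      }
    }
  }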
......@@ -19,11 +19,12 @@ package org.apache.spark.ml.regression
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
......@@ -44,24 +45,13 @@ private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
with HasFitIntercept with HasStandardization with HasWeightCol
/**
* Class that represents an instance of weighted data point with label and features.
*
* TODO: Refactor this class to proper place.
*
* @param label Label for this data point.
* @param weight The weight of this instance.
* @param features The vector of features for this data point.
*/
private[regression] case class Instance(label: Double, weight: Double, features: Vector)
/**
* :: Experimental ::
* Linear regression.
*
* The learning objective is to minimize the squared error, with regularization.
* The specific squared error loss function used is:
* L = 1/2n ||A weights - y||^2^
* L = 1/2n ||A coefficients - y||^2^
*
* This supports multiple types of regularization:
* - none (a.k.a. ordinary least squares)
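Written out with the penalty included, the objective is the standard elastic-net form (a summary in the same notation, with \lambda = regParam and \alpha = elasticNetParam; \alpha = 0 gives ridge, \alpha = 1 the lasso, values in between the elastic net):

  L = 1/2n ||A coefficients - y||^2^ + \lambda * (\alpha ||coefficients||_1 + (1 - \alpha)/2 * ||coefficients||_2^2^)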
......@@ -172,13 +162,14 @@ class LinearRegression(override val uid: String)
// If the yStd is zero, then the intercept is yMean with zero weights;
// as a result, training is not needed.
if (yStd == 0.0) {
logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
s"and the intercept will be the mean of the label; as a result, training is not needed.")
logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
s"zeros and the intercept will be the mean of the label; as a result, " +
s"training is not needed.")
if (handlePersistence) instances.unpersist()
val weights = Vectors.sparse(numFeatures, Seq())
val coefficients = Vectors.sparse(numFeatures, Seq())
val intercept = yMean
val model = new LinearRegressionModel(uid, weights, intercept)
val model = new LinearRegressionModel(uid, coefficients, intercept)
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset),
$(predictionCol),
......@@ -218,11 +209,11 @@ class LinearRegression(override val uid: String)
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
}
val initialWeights = Vectors.zeros(numFeatures)
val initialCoefficients = Vectors.zeros(numFeatures)
val states = optimizer.iterations(new CachedDiffFunction(costFun),
initialWeights.toBreeze.toDenseVector)
initialCoefficients.toBreeze.toDenseVector)
val (weights, objectiveHistory) = {
val (coefficients, objectiveHistory) = {
/*
Note that in Linear Regression, the objective history (loss + regularization) returned
from optimizer is computed in the scaled space given by the following formula.
......@@ -243,18 +234,18 @@ class LinearRegression(override val uid: String)
}
/*
The weights are trained in the scaled space; we're converting them back to
The coefficients are trained in the scaled space; we're converting them back to
the original space.
*/
val rawWeights = state.x.toArray.clone()
val rawCoefficients = state.x.toArray.clone()
var i = 0
val len = rawWeights.length
val len = rawCoefficients.length
while (i < len) {
rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
rawCoefficients(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
i += 1
}
(Vectors.dense(rawWeights).compressed, arrayBuilder.result())
(Vectors.dense(rawCoefficients).compressed, arrayBuilder.result())
}
/*
......@@ -262,11 +253,15 @@ class LinearRegression(override val uid: String)
converged. See the following discussion for detail.
http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
*/
val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
val intercept = if ($(fitIntercept)) {
yMean - dot(coefficients, Vectors.dense(featuresMean))
} else {
0.0
}
if (handlePersistence) instances.unpersist()
val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset),
$(predictionCol),
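Putting the last two hunks together (a recap in the code's own variable names, not text from the patch): the optimizer works on standardized features and label, so each coefficient is mapped back to the original scale and the intercept is then chosen so the fit passes through the means:

  coefficients(i) = rawCoefficients(i) * yStd / featuresStd(i)       (the rescaling loop above)
  intercept = yMean - \sum_i coefficients(i) * featuresMean(i)       (the dot(coefficients, Vectors.dense(featuresMean)) term)

This matches how glmnet recovers the intercept, per the linked discussion.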
......@@ -425,7 +420,7 @@ class LinearRegressionSummary private[regression] (
* For improving the convergence rate during the optimization process, and also preventing against
* features with very large variances exerting an overly large influence during model training,
* package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce
* the condition number, and then trains the model in scaled space but returns the weights in
* the condition number, and then trains the model in scaled space but returns the coefficients in
* the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
*
* However, we don't want to apply the `StandardScaler` on the training dataset, and then cache
......@@ -456,7 +451,7 @@ class LinearRegressionSummary private[regression] (
* + \bar{y} / \hat{y}||^2
* = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
* }}}
* where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is
* where w_i^\prime^ is the effective coefficients defined by w_i/\hat{x_i}, offset is
* {{{
* - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}.
* }}}, and diff is
......@@ -465,7 +460,7 @@ class LinearRegressionSummary private[regression] (
* }}}
*
*
* Note that the effective weights and offset don't depend on training dataset,
* Note that the effective coefficients and offset don't depend on training dataset,
* so they can be precomputed.
*
* Now, the first derivative of the objective function in scaled space is
......@@ -543,13 +538,13 @@ private class LeastSquaresAggregator(
private val gradientSumArray = Array.ofDim[Double](dim)
/**
* Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
* Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient
* of the objective function.
*
* @param instance The data point instance to be added.
* @param instance The instance of data point to be added.
* @return This LeastSquaresAggregator object.
*/
def add(instance: Instance): this.type =
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${features.size}.")
......@@ -573,6 +568,7 @@ private class LeastSquaresAggregator(
weightSum += weight
this
}
}
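The body of add is elided by this hunk, but in the notation of the scaladoc earlier in this file the per-instance update it performs amounts to (a reading of that derivation, not code shown in the patch):

  diff = \sum_i w_i^\prime x_i - y / \hat{y} + offset
  lossSum += weight * diff^2^ / 2
  gradientSum_i += weight * diff * x_i / \hat{x_i}

with the effective coefficients w_i^\prime and the offset precomputed once, as noted above.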
/**
* Merge another LeastSquaresAggregator, and update the loss and gradient
......@@ -621,11 +617,11 @@ private class LeastSquaresAggregator(
/**
* LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost.
* It returns the loss and gradient with L2 regularization at a particular point (weights).
* It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
private class LeastSquaresCostFun(
data: RDD[Instance],
instances: RDD[Instance],
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
......@@ -635,12 +631,16 @@ private class LeastSquaresCostFun(
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val coeff = Vectors.fromBreeze(coefficients)
val coeffs = Vectors.fromBreeze(coefficients)
val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd,
labelMean, fitIntercept, featuresStd, featuresMean))(
seqOp = (aggregator, instance) => aggregator.add(instance),
combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
val leastSquaresAggregator = {
val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2)
instances.treeAggregate(
new LeastSquaresAggregator(coeffs, labelStd, labelMean, fitIntercept, featuresStd,
featuresMean))(seqOp, combOp)
}
val totalGradientArray = leastSquaresAggregator.gradient.toArray
......@@ -648,7 +648,7 @@ private class LeastSquaresCostFun(
0.0
} else {
var sum = 0.0
coeff.foreachActive { (index, value) =>
coeffs.foreachActive { (index, value) =>
// The following code will compute the loss of the regularization; also
// the gradient of the regularization, and add back to totalGradientArray.
sum += {
......
......@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
......
......@@ -20,6 +20,7 @@ package org.apache.spark.ml.regression
import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint
......