Commit 49d767d8 authored by actuaryzhang, committed by Yanbo Liang

[SPARK-18710][ML] Add offset in GLM

## What changes were proposed in this pull request?
Add support for offset in GLM. This is useful for at least two reasons:

1. Account for exposure: e.g., when modeling the number of accidents, we may need to use miles driven as an offset to assess factors affecting accident frequency (a usage sketch follows this list).
2. Test incremental effects of new variables: we can use predictions from the existing model as an offset and fit a much smaller model on only the new variables. This avoids re-estimating the large model with all variables (old + new) and can be very important for efficient large-scale analysis.
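
As an illustration of the exposure use case, here is a minimal sketch of the new API (assuming an active `SparkSession` named `spark`; the data and column names are invented for illustration):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// Hypothetical data: accident counts, log(miles driven) as the exposure offset,
// and two driver covariates as features.
val df = spark.createDataFrame(Seq(
  (2.0, math.log(10000.0), Vectors.dense(0.0, 5.0)),
  (8.0, math.log(30000.0), Vectors.dense(1.0, 7.0)),
  (3.0, math.log(5000.0), Vectors.dense(2.0, 11.0)),
  (9.0, math.log(20000.0), Vectors.dense(3.0, 13.0))
)).toDF("label", "offset", "features")

// Poisson GLM with log link: log E[accidents] = x'b + log(miles),
// i.e. the offset enters the linear predictor with a fixed coefficient of 1.
val model = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setOffsetCol("offset")
  .fit(df)
```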

## How was this patch tested?
New test.

yanboliang srowen felixcheung sethah

Author: actuaryzhang <actuaryzhang10@gmail.com>

Closes #16699 from actuaryzhang/offset.
parent 52981715
@@ -27,3 +27,24 @@ import org.apache.spark.ml.linalg.Vector
* @param features The vector of features for this data point.
*/
private[ml] case class Instance(label: Double, weight: Double, features: Vector)
/**
* Case class that represents a data point with
* label, weight, offset and features.
* This is mainly used in GeneralizedLinearRegression currently.
*
* @param label Label for this data point.
* @param weight The weight of this instance.
* @param offset The offset used for this data point.
* @param features The vector of features for this data point.
*/
private[ml] case class OffsetInstance(
label: Double,
weight: Double,
offset: Double,
features: Vector) {
/** Converts to an [[Instance]] object by leaving out the offset. */
def toInstance: Instance = Instance(label, weight, features)
}
@@ -18,7 +18,7 @@
package org.apache.spark.ml.optim
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD
@@ -43,7 +43,7 @@ private[ml] class IterativelyReweightedLeastSquaresModel(
* find M-estimator in robust regression and other optimization problems.
*
* @param initialModel the initial guess model.
* @param reweightFunc the reweight function which is used to update offsets and weights
* @param reweightFunc the reweight function which is used to update working labels and weights
* at each iteration.
* @param fitIntercept whether to fit intercept.
* @param regParam L2 regularization parameter used by WLS.
@@ -57,13 +57,13 @@ private[ml] class IterativelyReweightedLeastSquaresModel(
*/
private[ml] class IterativelyReweightedLeastSquares(
val initialModel: WeightedLeastSquaresModel,
val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double),
val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double),
val fitIntercept: Boolean,
val regParam: Double,
val maxIter: Int,
val tol: Double) extends Logging with Serializable {
def fit(instances: RDD[Instance]): IterativelyReweightedLeastSquaresModel = {
def fit(instances: RDD[OffsetInstance]): IterativelyReweightedLeastSquaresModel = {
var converged = false
var iter = 0
@@ -75,10 +75,10 @@ private[ml] class IterativelyReweightedLeastSquares(
oldModel = model
// Update offsets and weights using reweightFunc
// Update working labels and weights using reweightFunc
val newInstances = instances.map { instance =>
val (newOffset, newWeight) = reweightFunc(instance, oldModel)
Instance(newOffset, newWeight, instance.features)
val (newLabel, newWeight) = reweightFunc(instance, oldModel)
Instance(newLabel, newWeight, instance.features)
}
// Estimate new model
......
@@ -18,7 +18,7 @@
package org.apache.spark.ml.optim
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD
......
@@ -26,8 +26,8 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector}
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.optim._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -138,6 +138,27 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
@Since("2.0.0")
def getLinkPredictionCol: String = $(linkPredictionCol)
/**
* Param for offset column name. If this is not set or empty, we treat all instance offsets
* as 0.0. The feature specified as offset has a constant coefficient of 1.0.
* @group param
*/
@Since("2.3.0")
final val offsetCol: Param[String] = new Param[String](this, "offsetCol", "The offset " +
"column name. If this is not set or empty, we treat all instance offsets as 0.0")
/** @group getParam */
@Since("2.3.0")
def getOffsetCol: String = $(offsetCol)
/** Checks whether weight column is set and nonempty. */
private[regression] def hasWeightCol: Boolean =
isSet(weightCol) && $(weightCol).nonEmpty
/** Checks whether offset column is set and nonempty. */
private[regression] def hasOffsetCol: Boolean =
isSet(offsetCol) && $(offsetCol).nonEmpty
/** Checks whether we should output link prediction. */
private[regression] def hasLinkPredictionCol: Boolean = {
isDefined(linkPredictionCol) && $(linkPredictionCol).nonEmpty
@@ -172,6 +193,11 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
}
val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType)
if (hasOffsetCol) {
SchemaUtils.checkNumericType(schema, $(offsetCol))
}
if (hasLinkPredictionCol) {
SchemaUtils.appendColumn(newSchema, $(linkPredictionCol), DoubleType)
} else {
@@ -306,6 +332,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
@Since("2.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)
/**
* Sets the value of param [[offsetCol]].
* If this is not set or empty, we treat all instance offsets as 0.0.
* Default is not set, so all instances have offset 0.0.
*
* @group setParam
*/
@Since("2.3.0")
def setOffsetCol(value: String): this.type = set(offsetCol, value)
/**
* Sets the solver algorithm used for optimization.
* Currently only supports "irls" which is also the default solver.
@@ -329,7 +365,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, weightCol, predictionCol, linkPredictionCol,
instr.logParams(labelCol, featuresCol, weightCol, offsetCol, predictionCol, linkPredictionCol,
family, solver, fitIntercept, link, maxIter, regParam, tol)
instr.logNumFeatures(numFeatures)
@@ -343,15 +379,16 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
"GeneralizedLinearRegression was given data with 0 features, and with Param fitIntercept " +
"set to false. To fit a model with 0 features, fitIntercept must be set to true." )
val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
val w = if (!hasWeightCol) lit(1.0) else col($(weightCol))
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
val model = if (familyAndLink.family == Gaussian && familyAndLink.link == Identity) {
// TODO: Make standardizeFeatures and standardizeLabel configurable.
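// With the Gaussian family and identity link the mean is x'b + offset, so the
// offset can be subtracted from the label up front and the fit reduces to
// ordinary weighted least squares.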
val instances: RDD[Instance] =
dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
Instance(label - offset, weight, features)
}
val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), elasticNetParam = 0.0,
standardizeFeatures = true, standardizeLabel = true)
val wlsModel = optimizer.fit(instances)
@@ -362,6 +399,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
wlsModel.diagInvAtWA.toArray, 1, getSolver)
model.setSummary(Some(trainingSummary))
} else {
val instances: RDD[OffsetInstance] =
dataset.select(col($(labelCol)), w, offset, col($(featuresCol))).rdd.map {
case Row(label: Double, weight: Double, offset: Double, features: Vector) =>
OffsetInstance(label, weight, offset, features)
}
// Fit Generalized Linear Model by iteratively reweighted least squares (IRLS).
val initialModel = familyAndLink.initialize(instances, $(fitIntercept), $(regParam))
val optimizer = new IterativelyReweightedLeastSquares(initialModel,
@@ -425,12 +467,12 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
* Get the initial guess model for [[IterativelyReweightedLeastSquares]].
*/
def initialize(
instances: RDD[Instance],
instances: RDD[OffsetInstance],
fitIntercept: Boolean,
regParam: Double): WeightedLeastSquaresModel = {
val newInstances = instances.map { instance =>
val mu = family.initialize(instance.label, instance.weight)
val eta = predict(mu)
val eta = predict(mu) - instance.offset
Instance(eta, instance.weight, instance.features)
}
// TODO: Make standardizeFeatures and standardizeLabel configurable.
@@ -441,16 +483,16 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
}
/**
* The reweight function used to update offsets and weights
* The reweight function used to update working labels and weights
* at each iteration of [[IterativelyReweightedLeastSquares]].
*/
val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = {
(instance: Instance, model: WeightedLeastSquaresModel) => {
val eta = model.predict(instance.features)
val reweightFunc: (OffsetInstance, WeightedLeastSquaresModel) => (Double, Double) = {
(instance: OffsetInstance, model: WeightedLeastSquaresModel) => {
val eta = model.predict(instance.features) + instance.offset
val mu = fitted(eta)
val offset = eta + (instance.label - mu) * link.deriv(mu)
val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
(offset, weight)
val newLabel = eta - instance.offset + (instance.label - mu) * link.deriv(mu)
val newWeight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu))
(newLabel, newWeight)
}
}
}
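
For reference, the reweight step above is the standard IRLS working-response update, shifted by the offset. With linear predictor eta_i = x_i'beta + o_i and mean mu_i = g^{-1}(eta_i), the working label and weight handed to WLS are

```latex
z_i = \eta_i - o_i + (y_i - \mu_i)\, g'(\mu_i),
\qquad
w_i^{*} = \frac{w_i}{\left[g'(\mu_i)\right]^{2} V(\mu_i)}
```

Subtracting o_i keeps the offset out of the coefficients that WLS solves for; it is added back each time eta is recomputed.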
@@ -950,15 +992,22 @@ class GeneralizedLinearRegressionModel private[ml] (
private lazy val familyAndLink = FamilyAndLink(this)
override protected def predict(features: Vector): Double = {
val eta = predictLink(features)
predict(features, 0.0)
}
/**
* Calculates the predicted value when offset is set.
*/
private def predict(features: Vector, offset: Double): Double = {
val eta = predictLink(features, offset)
familyAndLink.fitted(eta)
}
/**
* Calculate the link prediction (linear predictor) of the given instance.
* Calculates the link prediction (linear predictor) of the given instance.
*/
private def predictLink(features: Vector): Double = {
BLAS.dot(features, coefficients) + intercept
private def predictLink(features: Vector, offset: Double): Double = {
BLAS.dot(features, coefficients) + intercept + offset
}
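
That is, with intercept beta_0 and per-instance offset o, the two methods above compute

```latex
\hat\eta = x^\top \hat\beta + \hat\beta_0 + o,
\qquad
\hat{y} = g^{-1}(\hat\eta)
```

with o = 0 when `offsetCol` is unset, as `transformImpl` below arranges via `lit(0.0)`.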
override def transform(dataset: Dataset[_]): DataFrame = {
@@ -967,14 +1016,16 @@ class GeneralizedLinearRegressionModel private[ml] (
}
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictLinkUDF = udf { (features: Vector) => predictLink(features) }
val predictUDF = udf { (features: Vector, offset: Double) => predict(features, offset) }
val predictLinkUDF = udf { (features: Vector, offset: Double) => predictLink(features, offset) }
val offset = if (!hasOffsetCol) lit(0.0) else col($(offsetCol)).cast(DoubleType)
var output = dataset
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)), offset))
}
if (hasLinkPredictionCol) {
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol))))
output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)), offset))
}
output.toDF()
}
@@ -1146,9 +1197,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
/** Degrees of freedom. */
@Since("2.0.0")
lazy val degreesOfFreedom: Long = {
numInstances - rank
}
lazy val degreesOfFreedom: Long = numInstances - rank
/** The residual degrees of freedom. */
@Since("2.0.0")
@@ -1156,18 +1205,20 @@ class GeneralizedLinearRegressionSummary private[regression] (
/** The residual degrees of freedom for the null model. */
@Since("2.0.0")
lazy val residualDegreeOfFreedomNull: Long = if (model.getFitIntercept) {
numInstances - 1
} else {
numInstances
lazy val residualDegreeOfFreedomNull: Long = {
if (model.getFitIntercept) numInstances - 1 else numInstances
}
private def weightCol: Column = {
if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) {
lit(1.0)
} else {
col(model.getWeightCol)
}
private def label: Column = col(model.getLabelCol).cast(DoubleType)
private def prediction: Column = col(predictionCol)
private def weight: Column = {
if (!model.hasWeightCol) lit(1.0) else col(model.getWeightCol)
}
private def offset: Column = {
if (!model.hasOffsetCol) lit(0.0) else col(model.getOffsetCol).cast(DoubleType)
}
private[regression] lazy val devianceResiduals: DataFrame = {
@@ -1175,25 +1226,23 @@ class GeneralizedLinearRegressionSummary private[regression] (
val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0))
if (y > mu) r else -1.0 * r
}
val w = weightCol
predictions.select(
drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals"))
drUDF(label, prediction, weight).as("devianceResiduals"))
}
private[regression] lazy val pearsonResiduals: DataFrame = {
val prUDF = udf { mu: Double => family.variance(mu) }
val w = weightCol
predictions.select(col(model.getLabelCol).minus(col(predictionCol))
.multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals"))
predictions.select(label.minus(prediction)
.multiply(sqrt(weight)).divide(sqrt(prUDF(prediction))).as("pearsonResiduals"))
}
private[regression] lazy val workingResiduals: DataFrame = {
val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) }
predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals"))
predictions.select(wrUDF(label, prediction).as("workingResiduals"))
}
private[regression] lazy val responseResiduals: DataFrame = {
predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals"))
predictions.select(label.minus(prediction).as("responseResiduals"))
}
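
In symbols, with prediction mu_i, weight w_i, variance function V, link derivative g', and unit deviance d, the four residual types computed above are

```latex
\begin{aligned}
r_i^{\mathrm{deviance}} &= \operatorname{sign}(y_i - \hat\mu_i)\,\sqrt{\max\bigl(d(y_i,\hat\mu_i,w_i),\,0\bigr)} \\
r_i^{\mathrm{pearson}}  &= \frac{(y_i - \hat\mu_i)\,\sqrt{w_i}}{\sqrt{V(\hat\mu_i)}} \\
r_i^{\mathrm{working}}  &= (y_i - \hat\mu_i)\, g'(\hat\mu_i) \\
r_i^{\mathrm{response}} &= y_i - \hat\mu_i
\end{aligned}
```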
/**
@@ -1225,16 +1274,35 @@ class GeneralizedLinearRegressionSummary private[regression] (
*/
@Since("2.0.0")
lazy val nullDeviance: Double = {
val w = weightCol
val wtdmu: Double = if (model.getFitIntercept) {
val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first()
agg.getDouble(0) / agg.getDouble(1)
val intercept: Double = if (!model.getFitIntercept) {
0.0
} else {
link.unlink(0.0)
/*
Estimate intercept analytically when there is no offset, or when there is offset but
the model is Gaussian family with identity link. Otherwise, fit an intercept only model.
*/
if (!model.hasOffsetCol ||
(model.hasOffsetCol && family == Gaussian && link == Identity)) {
val agg = predictions.agg(sum(weight.multiply(
label.minus(offset))), sum(weight)).first()
link.link(agg.getDouble(0) / agg.getDouble(1))
} else {
// Create empty feature column and fit intercept only model using param setting from model
val featureNull = "feature_" + java.util.UUID.randomUUID.toString
val paramMap = model.extractParamMap()
paramMap.put(model.featuresCol, featureNull)
if (family.name != "tweedie") {
paramMap.remove(model.variancePower)
}
val emptyVectorUDF = udf{ () => Vectors.zeros(0) }
model.parent.fit(
dataset.withColumn(featureNull, emptyVectorUDF()), paramMap
).intercept
}
}
predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
case Row(y: Double, weight: Double) =>
family.deviance(y, wtdmu, weight)
predictions.select(label, offset, weight).rdd.map {
case Row(y: Double, offset: Double, weight: Double) =>
family.deviance(y, link.unlink(intercept + offset), weight)
}.sum()
}
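
Spelled out: the null model keeps only the intercept plus the per-instance offset, g(mu_i) = beta_0 + o_i, so the code above computes

```latex
D_0 = \sum_i d\bigl(y_i,\; g^{-1}(\hat\beta_0 + o_i),\; w_i\bigr),
\qquad
\hat\beta_0 = g\!\left(\frac{\sum_i w_i\,(y_i - o_i)}{\sum_i w_i}\right)
```

where the closed form for beta_0 holds when there is no offset, or with an offset under the Gaussian family and identity link; in all other cases beta_0 comes from the intercept-only fit above.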
@@ -1243,8 +1311,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
*/
@Since("2.0.0")
lazy val deviance: Double = {
val w = weightCol
predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
predictions.select(label, prediction, weight).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
family.deviance(label, pred, weight)
}.sum()
@@ -1269,10 +1336,9 @@ class GeneralizedLinearRegressionSummary private[regression] (
/** Akaike Information Criterion (AIC) for the fitted model. */
@Since("2.0.0")
lazy val aic: Double = {
val w = weightCol
val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
val weightSum = predictions.select(weight).agg(sum(weight)).first().getDouble(0)
val t = predictions.select(
col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
label, prediction, weight).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
(label, pred, weight)
}
......
@@ -18,7 +18,7 @@
package org.apache.spark.ml.optim
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.feature.{Instance, OffsetInstance}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -26,8 +26,8 @@ import org.apache.spark.rdd.RDD
class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext {
private var instances1: RDD[Instance] = _
private var instances2: RDD[Instance] = _
private var instances1: RDD[OffsetInstance] = _
private var instances2: RDD[OffsetInstance] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -39,10 +39,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes
w <- c(1, 2, 3, 4)
*/
instances1 = sc.parallelize(Seq(
Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
Instance(0.0, 4.0, Vectors.dense(3.0, 3.0))
OffsetInstance(1.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse),
OffsetInstance(0.0, 2.0, 0.0, Vectors.dense(1.0, 2.0)),
OffsetInstance(1.0, 3.0, 0.0, Vectors.dense(2.0, 1.0)),
OffsetInstance(0.0, 4.0, 0.0, Vectors.dense(3.0, 3.0))
), 2)
/*
R code:
@@ -52,10 +52,10 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes
w <- c(1, 2, 3, 4)
*/
instances2 = sc.parallelize(Seq(
Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(9.0, 4.0, Vectors.dense(3.0, 13.0))
OffsetInstance(2.0, 1.0, 0.0, Vectors.dense(0.0, 5.0).toSparse),
OffsetInstance(8.0, 2.0, 0.0, Vectors.dense(1.0, 7.0)),
OffsetInstance(3.0, 3.0, 0.0, Vectors.dense(2.0, 11.0)),
OffsetInstance(9.0, 4.0, 0.0, Vectors.dense(3.0, 13.0))
), 2)
}
@@ -156,7 +156,7 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes
var idx = 0
for (fitIntercept <- Seq(false, true)) {
val initial = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0,
standardizeFeatures = false, standardizeLabel = false).fit(instances2)
standardizeFeatures = false, standardizeLabel = false).fit(instances2.map(_.toInstance))
val irls = new IterativelyReweightedLeastSquares(initial, L1RegressionReweightFunc,
fitIntercept, regParam = 0.0, maxIter = 200, tol = 1e-7).fit(instances2)
val actual = Vectors.dense(irls.intercept, irls.coefficients(0), irls.coefficients(1))
@@ -169,29 +169,29 @@ class IterativelyReweightedLeastSquaresSuite extends SparkFunSuite with MLlibTes
object IterativelyReweightedLeastSquaresSuite {
def BinomialReweightFunc(
instance: Instance,
instance: OffsetInstance,
model: WeightedLeastSquaresModel): (Double, Double) = {
val eta = model.predict(instance.features)
val eta = model.predict(instance.features) + instance.offset
val mu = 1.0 / (1.0 + math.exp(-1.0 * eta))
val z = eta + (instance.label - mu) / (mu * (1.0 - mu))
val z = eta - instance.offset + (instance.label - mu) / (mu * (1.0 - mu))
val w = mu * (1 - mu) * instance.weight
(z, w)
}
def PoissonReweightFunc(
instance: Instance,
instance: OffsetInstance,
model: WeightedLeastSquaresModel): (Double, Double) = {
val eta = model.predict(instance.features)
val eta = model.predict(instance.features) + instance.offset
val mu = math.exp(eta)
val z = eta + (instance.label - mu) / mu
val z = eta - instance.offset + (instance.label - mu) / mu
val w = mu * instance.weight
(z, w)
}
def L1RegressionReweightFunc(
instance: Instance,
instance: OffsetInstance,
model: WeightedLeastSquaresModel): (Double, Double) = {
val eta = model.predict(instance.features)
val eta = model.predict(instance.features) + instance.offset
val e = math.max(math.abs(eta - instance.label), 1e-7)
val w = 1 / e
val y = instance.label
......