Commit 101663f1 authored by Yanbo Liang, committed by Xiangrui Meng

[SPARK-13322][ML] AFTSurvivalRegression supports feature standardization

## What changes were proposed in this pull request?
AFTSurvivalRegression should support feature standardization; it improves the convergence rate. I tested the convergence rate on the [Ovarian](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/ovarian.html) data, a standard dataset that ships with the survival library in R:
* without standardization (before this PR) -> 74 iterations.
* with standardization (after this PR) -> 38 iterations.

After this fix, training converges to the same solution with or without ```standardization```; that is, ```standardization = false``` runs the same code path as ```standardization = true```, because fitting on entirely unstandardized features causes convergence problems when the features have very different scales. This matches the behavior of ML [```LinearRegression``` and ```LogisticRegression```](https://issues.apache.org/jira/browse/SPARK-8522). See #11247 for more discussion of this topic.
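The idea, as a minimal self-contained sketch (hypothetical helper names, not the actual Spark internals): divide each feature by its standard deviation before optimization, then divide the learned coefficients by the same factors so the reported model is in the original feature space.

```scala
// Hypothetical sketch of standardize-then-unscale; not Spark's actual code.
object StandardizationSketch {
  /** Column-wise sample standard deviations. */
  def columnStd(rows: Array[Array[Double]]): Array[Double] = {
    val n = rows.length.toDouble
    Array.tabulate(rows.head.length) { j =>
      val mean = rows.map(_(j)).sum / n
      math.sqrt(rows.map(r => (r(j) - mean) * (r(j) - mean)).sum / (n - 1))
    }
  }

  /** Scale each feature by 1 / std; constant columns (std == 0) are left untouched. */
  def standardize(rows: Array[Array[Double]], std: Array[Double]): Array[Array[Double]] =
    rows.map(_.zipWithIndex.map { case (v, j) => if (std(j) != 0.0) v / std(j) else v })

  /** Map coefficients learned on standardized features back to the original scale. */
  def unscale(raw: Array[Double], std: Array[Double]): Array[Double] =
    raw.zip(std).map { case (b, s) => if (s != 0.0) b / s else 0.0 }
}
```

In the patch below, the division happens on the fly inside the aggregator (see the `margin` computation) rather than by materializing a scaled copy of the data.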
cc mengxr
## How was this patch tested?
Unit tests.
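For context, a minimal end-to-end usage sketch, adapted from the standard Spark ML docs example for this estimator (toy data; exact fitted values depend on the Spark version):

```scala
import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.mllib.linalg.Vectors

// Columns: label (survival time), censor (1.0 = event observed), features.
val training = sqlContext.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

val model = new AFTSurvivalRegression().fit(training)
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept} Scale: ${model.scale}")
```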

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #11365 from yanboliang/spark-13322.
parent 75e05a5a
@@ -31,6 +31,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
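For readers unfamiliar with `MultivariateOnlineSummarizer`: it accumulates column statistics in one pass and supports merging partial results, which is what makes the `treeAggregate` call below work. A simplified, hypothetical single-machine analogue of that add/merge contract (not the actual mllib class):

```scala
// Hypothetical simplified analogue of the add/merge contract; NOT the mllib class.
class ColumnSummarizer(dim: Int) extends Serializable {
  private var n = 0L
  private val mean = Array.ofDim[Double](dim)
  private val m2 = Array.ofDim[Double](dim) // sum of squared deviations

  /** Welford's online update for one sample. */
  def add(x: Array[Double]): this.type = {
    n += 1
    var j = 0
    while (j < dim) {
      val delta = x(j) - mean(j)
      mean(j) += delta / n
      m2(j) += delta * (x(j) - mean(j))
      j += 1
    }
    this
  }

  /** Chan et al. parallel merge of two partial summaries. */
  def merge(other: ColumnSummarizer): this.type = {
    if (other.n > 0) {
      val total = n + other.n
      var j = 0
      while (j < dim) {
        val delta = other.mean(j) - mean(j)
        mean(j) += delta * other.n / total
        m2(j) += other.m2(j) + delta * delta * n * other.n / total
        j += 1
      }
      n = total
    }
    this
  }

  /** Unbiased sample variance per column. */
  def variance: Array[Double] = m2.map(v => if (n > 1) v / (n - 1) else 0.0)
}

// Usage: Seq(Array(1.0, 10.0), Array(2.0, 20.0)).foldLeft(new ColumnSummarizer(2))(_.add(_)).variance
```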
@@ -198,10 +199,20 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept))
+    val featuresSummarizer = {
+      val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
+      val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
+        c1.merge(c2)
+      }
+      instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+    }
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+
+    val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
-    val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+    val numFeatures = featuresStd.size
     /*
        The parameters vector has three parts:
        the first element: Double, log(sigma), the log of scale parameter
@@ -230,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
     if (handlePersistence) instances.unpersist()
 
-    val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
+    val rawCoefficients = parameters.slice(2, parameters.length)
+    var i = 0
+    while (i < numFeatures) {
+      rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+      i += 1
+    }
+    val coefficients = Vectors.dense(rawCoefficients)
     val intercept = parameters(1)
     val scale = math.exp(parameters(0))
     val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
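Why dividing by the per-feature standard deviation recovers the original-scale coefficients (a one-line derivation reconstructed from the code above, not text from the patch): optimization sees the scaled feature $x_j / \sigma_{x_j}$ in place of $x_j$, so for the learned coefficient $\beta_j'$,

\[ \beta_j' \cdot \frac{x_j}{\sigma_{x_j}} \;=\; \frac{\beta_j'}{\sigma_{x_j}} \cdot x_j \qquad\Longrightarrow\qquad \beta_j \;=\; \frac{\beta_j'}{\sigma_{x_j}}, \]

which is exactly the `rawCoefficients(i) *= 1.0 / featuresStd(i)` step; constant features (std = 0) are assigned coefficient 0.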
@@ -434,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
  * @param parameters including three parts: the log of scale parameter, the intercept and
  *                   regression coefficients corresponding to the features.
  * @param fitIntercept Whether to fit an intercept term.
+ * @param featuresStd The standard deviation values of the features.
  */
-private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
-  extends Serializable {
+private class AFTAggregator(
+    parameters: BDV[Double],
+    fitIntercept: Boolean,
+    featuresStd: Array[Double]) extends Serializable {
 
   // the regression coefficients to the covariates
   private val coefficients = parameters.slice(2, parameters.length)
-  private val intercept = parameters.valueAt(1)
+  private val intercept = parameters(1)
   // sigma is the scale parameter of the AFT model
   private val sigma = math.exp(parameters(0))
 
   private var totalCnt: Long = 0L
   private var lossSum = 0.0
-  private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
-  private var gradientInterceptSum = 0.0
-  private var gradientLogSigmaSum = 0.0
+  // Here we optimize loss function over log(sigma), intercept and coefficients
+  private val gradientSumArray = Array.ofDim[Double](parameters.length)
 
   def count: Long = totalCnt
 
-  def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
-
-  // Here we optimize loss function over coefficients, intercept and log(sigma)
-  def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
-    BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
+  def loss: Double = {
+    require(totalCnt > 0.0, s"The number of instances should be " +
+      s"greater than 0.0, but got $totalCnt.")
+    lossSum / totalCnt
+  }
+
+  def gradient: BDV[Double] = {
+    require(totalCnt > 0.0, s"The number of instances should be " +
+      s"greater than 0.0, but got $totalCnt.")
+    new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
+  }
 
   /**
    * Add a new training data point to this AFTAggregator, and update the loss and gradient
@@ -466,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
    * @return This AFTAggregator object.
    */
   def add(data: AFTPoint): this.type = {
-    val interceptFlag = if (fitIntercept) 1.0 else 0.0
-    val xi = data.features.toBreeze
+    val xi = data.features
     val ti = data.label
     val delta = data.censor
-    val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
 
-    lossSum += math.log(sigma) * delta
-    lossSum += (math.exp(epsilon) - delta * epsilon)
+    val margin = {
+      var sum = 0.0
+      xi.foreachActive { (index, value) =>
+        if (featuresStd(index) != 0.0 && value != 0.0) {
+          sum += coefficients(index) * (value / featuresStd(index))
+        }
+      }
+      sum + intercept
+    }
+    val epsilon = (math.log(ti) - margin) / sigma
+
+    lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon)
 
     // Sanity check (should never occur):
     assert(!lossSum.isInfinity,
       s"AFTAggregator loss sum is infinity. Error for unknown reason.")
 
-    val deltaMinusExpEps = delta - math.exp(epsilon)
-    gradientCoefficientSum += xi * deltaMinusExpEps / sigma
-    gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
-    gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
+    val multiplier = (delta - math.exp(epsilon)) / sigma
+
+    gradientSumArray(0) += delta + multiplier * sigma * epsilon
+    gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
+    xi.foreachActive { (index, value) =>
+      if (featuresStd(index) != 0.0 && value != 0.0) {
+        gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+      }
+    }
 
     totalCnt += 1
     this
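For reference, the quantities accumulated in `add()` match the AFT negative log-likelihood of one observation (derivation reconstructed from the code, not quoted from the patch). With censoring indicator $\delta$, label $t$, margin $m = \sum_j \beta_j (x_j/\sigma_{x_j}) + b$, and $\epsilon = (\log t - m)/\sigma$:

\[ \ell = \delta \log\sigma - \delta\,\epsilon + e^{\epsilon}, \qquad \frac{\partial \ell}{\partial \log\sigma} = \delta + (\delta - e^{\epsilon})\,\epsilon, \qquad \frac{\partial \ell}{\partial b} = \frac{\delta - e^{\epsilon}}{\sigma}, \qquad \frac{\partial \ell}{\partial \beta_j} = \frac{\delta - e^{\epsilon}}{\sigma} \cdot \frac{x_j}{\sigma_{x_j}}. \]

The code's `multiplier` is $(\delta - e^{\epsilon})/\sigma$, so the three `gradientSumArray` updates are exactly these partials.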
@@ -503,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
       totalCnt += other.totalCnt
       lossSum += other.lossSum
 
-      gradientCoefficientSum += other.gradientCoefficientSum
-      gradientInterceptSum += other.gradientInterceptSum
-      gradientLogSigmaSum += other.gradientLogSigmaSum
+      var i = 0
+      val len = this.gradientSumArray.length
+      while (i < len) {
+        this.gradientSumArray(i) += other.gradientSumArray(i)
+        i += 1
+      }
     }
     this
   }
@@ -516,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
  * It returns the loss and gradient at a particular point (parameters).
  * It's used in Breeze's convex optimization routines.
  */
-private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
-  extends DiffFunction[BDV[Double]] {
+private class AFTCostFun(
+    data: RDD[AFTPoint],
+    fitIntercept: Boolean,
+    featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
-    val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+    val aftAggregator = data.treeAggregate(
+      new AFTAggregator(parameters, fitIntercept, featuresStd))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, instance) => aggregator.add(instance)
       },
......
@@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite
   @transient var datasetUnivariate: DataFrame = _
   @transient var datasetMultivariate: DataFrame = _
+  @transient var datasetUnivariateScaled: DataFrame = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite
     datasetMultivariate = sqlContext.createDataFrame(
       sc.parallelize(generateAFTInput(
         2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+    datasetUnivariateScaled = sqlContext.createDataFrame(
+      sc.parallelize(generateAFTInput(
+        1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
+        AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
+      })
   }
 
   /**
@@ -356,6 +362,22 @@ class AFTSurvivalRegressionSuite
     }
   }
 
+  test("numerical stability of standardization") {
+    val trainer = new AFTSurvivalRegression()
+    val model1 = trainer.fit(datasetUnivariate)
+    val model2 = trainer.fit(datasetUnivariateScaled)
+
+    /**
+     * During training we standardize the dataset first, so no matter what scaling
+     * factor we multiply into the dataset, the convergence rate should be the same;
+     * the coefficients fit on scaled data, times the scaling factor, should equal
+     * the original coefficients. Scaling has no effect on the intercept and scale.
+     */
+    assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01)
+    assert(model1.intercept ~== model2.intercept absTol 0.01)
+    assert(model1.scale ~== model2.scale absTol 0.01)
+  }
+
   test("read/write") {
     def checkModelData(
         model: AFTSurvivalRegressionModel,
......