Commit 33ef3aa7 authored by Narine Kokhlikyan, committed by Xiangrui Meng

[SPARK-13295][ML][MLLIB] AFTSurvivalRegression.AFTAggregator improvements - avoid creating new instances of arrays/vectors for each record

As noted by the TODO in the AFTAggregator.add(data: AFTPoint) method, a new array is created for the intercept value on every record and concatenated with the array of betas; the resulting array is converted into a dense vector, which in turn is converted into a Breeze vector. This is expensive and unnecessary.

I've addressed this with a simple algebraic decomposition, keeping the intercept separate and treating it independently of the coefficients: since beta dot [1, x] equals intercept + (coefficients dot x), no augmented vector needs to be built per record. A minimal sketch follows below.
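To make the decomposition concrete, here is a minimal, self-contained Breeze sketch; the object name, variable names, and values are illustrative, not the actual Spark internals:

import breeze.linalg.{DenseVector => BDV}

object DecompositionSketch {
  def main(args: Array[String]): Unit = {
    val intercept = 0.5
    val coefficients = BDV(1.0, -2.0, 3.0)
    val features = BDV(0.2, 0.4, 0.6)

    // Old approach: build an augmented vector [1.0, x] and a combined
    // beta = [intercept, coefficients], i.e. fresh allocations on
    // every call to add().
    val beta = BDV.vertcat(BDV(intercept), coefficients)
    val augmented = BDV.vertcat(BDV(1.0), features)
    val oldDot = beta dot augmented

    // New approach: beta dot [1, x] == intercept + (coefficients dot x),
    // so the intercept is a scalar term and nothing is allocated.
    val newDot = intercept + (coefficients dot features)

    assert(math.abs(oldDot - newDot) < 1e-12)
  }
}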

Please let me know what you think and whether you have any questions.

Thanks,
Narine

Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>

Closes #11179 from NarineK/survivaloptim.
parent 02b1feff
@@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
 private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
   extends Serializable {
 
-  // beta is the intercept and regression coefficients to the covariates
-  private val beta = parameters.slice(1, parameters.length)
+  // the regression coefficients to the covariates
+  private val coefficients = parameters.slice(2, parameters.length)
+  private val intercept = parameters.valueAt(1)
   // sigma is the scale parameter of the AFT model
   private val sigma = math.exp(parameters(0))
 
   private var totalCnt: Long = 0L
   private var lossSum = 0.0
-  private var gradientBetaSum = BDV.zeros[Double](beta.length)
+  private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
+  private var gradientInterceptSum = 0.0
   private var gradientLogSigmaSum = 0.0
 
   def count: Long = totalCnt
 
   def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
 
-  // Here we optimize loss function over beta and log(sigma)
+  // Here we optimize loss function over coefficients, intercept and log(sigma)
   def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
-    gradientBetaSum/totalCnt.toDouble)
+    BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
 
   /**
    * Add a new training data to this AFTAggregator, and update the loss and gradient
@@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
    */
   def add(data: AFTPoint): this.type = {
-    // TODO: Don't create a new xi vector each time.
-    val xi = if (fitIntercept) {
-      Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze
-    } else {
-      Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze
-    }
+    val interceptFlag = if (fitIntercept) 1.0 else 0.0
+
+    val xi = data.features.toBreeze
     val ti = data.label
     val delta = data.censor
-    val epsilon = (math.log(ti) - beta.dot(xi)) / sigma
+    val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag) / sigma
 
     lossSum += math.log(sigma) * delta
     lossSum += (math.exp(epsilon) - delta * epsilon)
@@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
     assert(!lossSum.isInfinity,
       s"AFTAggregator loss sum is infinity. Error for unknown reason.")
 
-    gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma
-    gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon
+    val deltaMinusExpEps = delta - math.exp(epsilon)
+    gradientCoefficientSum += xi * deltaMinusExpEps / sigma
+    gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
+    gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
 
     totalCnt += 1
     this
@@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
     totalCnt += other.totalCnt
     lossSum += other.lossSum
 
-    gradientBetaSum += other.gradientBetaSum
+    gradientCoefficientSum += other.gradientCoefficientSum
+    gradientInterceptSum += other.gradientInterceptSum
     gradientLogSigmaSum += other.gradientLogSigmaSum
   }
   this
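For readers following the indexing in the diff: a tiny standalone sketch (dummy gradient values; LayoutSketch is a made-up name, not part of the patch) of the parameter and gradient layout the patched aggregator relies on. Index 0 holds log(sigma), index 1 the intercept, and indices 2..n the coefficients; gradient concatenates its pieces in the same order, so the two stay aligned.

import breeze.linalg.{DenseVector => BDV}

object LayoutSketch {
  def main(args: Array[String]): Unit = {
    val parameters = BDV(0.0, 0.5, 1.0, -2.0) // [log(sigma), intercept, coef1, coef2]
    val sigma = math.exp(parameters(0))
    val intercept = parameters.valueAt(1)
    val coefficients = parameters.slice(2, parameters.length)

    // Gradient assembled in the same order: [dLogSigma, dIntercept, dCoefficients...]
    val gradient = BDV.vertcat(BDV(0.1), BDV(0.2), BDV(0.3, 0.4))
    assert(gradient.length == parameters.length)
    println(s"sigma=$sigma, intercept=$intercept, coefficients=$coefficients")
  }
}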