diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e8a1ff2278a92a5f1e4c0ec32c79c59f3311616c..1e5b4cb83c652e7195f2c2689975500263f73d8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) extends Serializable { - // beta is the intercept and regression coefficients to the covariates - private val beta = parameters.slice(1, parameters.length) + // the regression coefficients to the covariates + private val coefficients = parameters.slice(2, parameters.length) + private val intercept = parameters.valueAt(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) + private var gradientInterceptSum = 0.0 private var gradientLogSigmaSum = 0.0 def count: Long = totalCnt def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - // Here we optimize loss function over beta and log(sigma) + // Here we optimize loss function over coefficients, intercept and log(sigma) def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - gradientBetaSum/totalCnt.toDouble) + BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) */ def add(data: AFTPoint): this.type = { - // TODO: Don't create a new xi vector each time. - val xi = if (fitIntercept) { - Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze - } else { - Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze - } + val interceptFlag = if (fitIntercept) 1.0 else 0.0 + + val xi = data.features.toBreeze val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma lossSum += math.log(sigma) * delta lossSum += (math.exp(epsilon) - delta * epsilon) @@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) assert(!lossSum.isInfinity, s"AFTAggregator loss sum is infinity. Error for unknown reason.") - gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma - gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + val deltaMinusExpEps = delta - math.exp(epsilon) + gradientCoefficientSum += xi * deltaMinusExpEps / sigma + gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma + gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon totalCnt += 1 this @@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientBetaSum += other.gradientBetaSum + gradientCoefficientSum += other.gradientCoefficientSum + gradientInterceptSum += other.gradientInterceptSum gradientLogSigmaSum += other.gradientLogSigmaSum } this