diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index be234f7fea44fb77547149d07f03ab4eac475228..3179f4882fd49145a17dd4a26b52416994bd0aba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
         "columns. This behavior is different from R survival::survreg.")
     }
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+
+    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
     /*
@@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       state.x.toArray.clone()
     }
 
+    bcFeaturesStd.destroy(blocking = false)
     if (handlePersistence) instances.unpersist()
 
     val rawCoefficients = parameters.slice(2, parameters.length)
@@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
  * $$
  * </blockquote></p>
  *
- * @param parameters including three part: The log of scale parameter, the intercept and
- *                   regression coefficients corresponding to the features.
+ * @param bcParameters The broadcast value includes three parts: the log of scale parameter,
+ *                     the intercept and the regression coefficients corresponding to the features.
  * @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
  */
 private class AFTAggregator(
-    parameters: BDV[Double],
+    bcParameters: Broadcast[BDV[Double]],
     fitIntercept: Boolean,
-    featuresStd: Array[Double]) extends Serializable {
+    bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {
 
+  private val length = bcParameters.value.length
+  // make transient so we do not serialize between aggregation stages
+  @transient private lazy val parameters = bcParameters.value
   // the regression coefficients to the covariates
-  private val coefficients = parameters.slice(2, parameters.length)
-  private val intercept = parameters(1)
+  @transient private lazy val coefficients = parameters.slice(2, length)
+  @transient private lazy val intercept = parameters(1)
   // sigma is the scale parameter of the AFT model
-  private val sigma = math.exp(parameters(0))
+  @transient private lazy val sigma = math.exp(parameters(0))
 
   private var totalCnt: Long = 0L
   private var lossSum = 0.0
   // Here we optimize loss function over log(sigma), intercept and coefficients
-  private val gradientSumArray = Array.ofDim[Double](parameters.length)
+  private val gradientSumArray = Array.ofDim[Double](length)
 
   def count: Long = totalCnt
   def loss: Double = {
@@ -524,11 +531,13 @@ private class AFTAggregator(
     val ti = data.label
     val delta = data.censor
+
+    val localFeaturesStd = bcFeaturesStd.value
 
     val margin = {
       var sum = 0.0
       xi.foreachActive { (index, value) =>
-        if (featuresStd(index) != 0.0 && value != 0.0) {
-          sum += coefficients(index) * (value / featuresStd(index))
+        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+          sum += coefficients(index) * (value / localFeaturesStd(index))
         }
       }
       sum + intercept
@@ -542,8 +551,8 @@ private class AFTAggregator(
     gradientSumArray(0) += delta + multiplier * sigma * epsilon
     gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
     xi.foreachActive { (index, value) =>
-      if (featuresStd(index) != 0.0 && value != 0.0) {
-        gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+        gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index))
       }
     }
 
@@ -565,8 +574,7 @@ private class AFTAggregator(
       lossSum += other.lossSum
 
       var i = 0
-      val len = this.gradientSumArray.length
-      while (i < len) {
+      while (i < length) {
         this.gradientSumArray(i) += other.gradientSumArray(i)
         i += 1
       }
@@ -583,12 +591,14 @@ private class AFTAggregator(
 private class AFTCostFun(
     data: RDD[AFTPoint],
     fitIntercept: Boolean,
-    featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
+    bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
+    val bcParameters = data.context.broadcast(parameters)
+
     val aftAggregator = data.treeAggregate(
-      new AFTAggregator(parameters, fitIntercept, featuresStd))(
+      new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))(
      seqOp = (c, v) => (c, v) match {
        case (aggregator, instance) => aggregator.add(instance)
      },
@@ -596,6 +606,7 @@ private class AFTCostFun(
        case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
      })
 
+    bcParameters.destroy(blocking = false)
     (aftAggregator.loss, aftAggregator.gradient)
   }
 }
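
Note on the pattern this patch applies: featuresStd and the LBFGS parameter vector are read-only on the executors, so the patch ships each of them once per node as a Broadcast and keeps only the lightweight broadcast handle inside AFTAggregator. The @transient lazy vals re-materialize the values on the executor side and are excluded when partial aggregators are serialized between treeAggregate levels, and destroy releases the broadcast blocks once the driver has the result. The sketch below shows the same three steps outside MLlib. It is illustrative only: WeightedSumAggregator, BroadcastAggregateDemo, and the toy data are invented names, and it calls the public destroy() because the non-blocking destroy(blocking = false) used by the patch is private[spark].

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.SparkSession

// Aggregator holding a Broadcast handle; the handle itself is tiny, and the
// deserialized array is cached in a @transient lazy val on each executor so
// it is never written back into the serialized form of this aggregator.
class WeightedSumAggregator(bcWeights: Broadcast[Array[Double]]) extends Serializable {
  @transient private lazy val weights = bcWeights.value

  var sum = 0.0

  def add(point: (Int, Double)): this.type = {
    sum += weights(point._1) * point._2
    this
  }

  def merge(other: WeightedSumAggregator): this.type = {
    sum += other.sum
    this
  }
}

object BroadcastAggregateDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
    val sc = spark.sparkContext

    val weights = Array(0.5, 2.0, 1.5)     // plays the role of featuresStd
    val bcWeights = sc.broadcast(weights)  // shipped to each executor once

    val data = sc.parallelize(Seq((0, 4.0), (1, 3.0), (2, 2.0)))
    val agg = data.treeAggregate(new WeightedSumAggregator(bcWeights))(
      seqOp = (acc, v) => acc.add(v),
      combOp = (acc1, acc2) => acc1.merge(acc2))
    println(agg.sum)  // 0.5 * 4.0 + 2.0 * 3.0 + 1.5 * 2.0 = 11.0

    bcWeights.destroy()  // release executor copies once the result is back
    spark.stop()
  }
}

Without the @transient markers, the array materialized by bcWeights.value would travel inside every serialized aggregator during the inter-level combine of treeAggregate, which is exactly the per-iteration overhead the patch removes.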