Skip to content
Snippets Groups Projects
Commit 182e1190 authored by Yanbo Liang's avatar Yanbo Liang
Browse files

[SPARK-16933][ML] Fix AFTAggregator in AFTSurvivalRegression serializes unnecessary data.

## What changes were proposed in this pull request?
Similar to ```LeastSquaresAggregator``` in #14109, ```AFTAggregator``` used for ```AFTSurvivalRegression``` ends up serializing the ```parameters``` and ```featuresStd```, which is not necessary and can cause performance issues for high dimensional data. This patch removes this serialization. This PR is highly inspired by #14109.

## How was this patch tested?
I tested this locally and verified the serialization reduction.

Before patch
![image](https://cloud.githubusercontent.com/assets/1962026/17512035/abb93f04-5dda-11e6-97d3-8ae6b61a0dfd.png)

After patch
![image](https://cloud.githubusercontent.com/assets/1962026/17512024/9e0dc44c-5dda-11e6-93d0-6e130ba0d6aa.png)

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14519 from yanboliang/spark-16933.
parent 511f52f8
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path ...@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
...@@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S ...@@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
"columns. This behavior is different from R survival::survreg.") "columns. This behavior is different from R survival::survreg.")
} }
val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
/* /*
...@@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S ...@@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
state.x.toArray.clone() state.x.toArray.clone()
} }
bcFeaturesStd.destroy(blocking = false)
if (handlePersistence) instances.unpersist() if (handlePersistence) instances.unpersist()
val rawCoefficients = parameters.slice(2, parameters.length) val rawCoefficients = parameters.slice(2, parameters.length)
...@@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] ...@@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
* $$ * $$
* </blockquote></p> * </blockquote></p>
* *
* @param parameters including three parts: The log of scale parameter, the intercept and * @param bcParameters The broadcasted value includes three parts: The log of scale parameter,
* regression coefficients corresponding to the features. * the intercept and regression coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term. * @param fitIntercept Whether to fit an intercept term.
* @param featuresStd The standard deviation values of the features. * @param bcFeaturesStd The broadcast standard deviation values of the features.
*/ */
private class AFTAggregator( private class AFTAggregator(
parameters: BDV[Double], bcParameters: Broadcast[BDV[Double]],
fitIntercept: Boolean, fitIntercept: Boolean,
featuresStd: Array[Double]) extends Serializable { bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {
private val length = bcParameters.value.length
// make transient so we do not serialize between aggregation stages
@transient private lazy val parameters = bcParameters.value
// the regression coefficients to the covariates // the regression coefficients to the covariates
private val coefficients = parameters.slice(2, parameters.length) @transient private lazy val coefficients = parameters.slice(2, length)
private val intercept = parameters(1) @transient private lazy val intercept = parameters(1)
// sigma is the scale parameter of the AFT model // sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0)) @transient private lazy val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L private var totalCnt: Long = 0L
private var lossSum = 0.0 private var lossSum = 0.0
// Here we optimize loss function over log(sigma), intercept and coefficients // Here we optimize loss function over log(sigma), intercept and coefficients
private val gradientSumArray = Array.ofDim[Double](parameters.length) private val gradientSumArray = Array.ofDim[Double](length)
def count: Long = totalCnt def count: Long = totalCnt
def loss: Double = { def loss: Double = {
...@@ -524,11 +531,13 @@ private class AFTAggregator( ...@@ -524,11 +531,13 @@ private class AFTAggregator(
val ti = data.label val ti = data.label
val delta = data.censor val delta = data.censor
val localFeaturesStd = bcFeaturesStd.value
val margin = { val margin = {
var sum = 0.0 var sum = 0.0
xi.foreachActive { (index, value) => xi.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) { if (localFeaturesStd(index) != 0.0 && value != 0.0) {
sum += coefficients(index) * (value / featuresStd(index)) sum += coefficients(index) * (value / localFeaturesStd(index))
} }
} }
sum + intercept sum + intercept
...@@ -542,8 +551,8 @@ private class AFTAggregator( ...@@ -542,8 +551,8 @@ private class AFTAggregator(
gradientSumArray(0) += delta + multiplier * sigma * epsilon gradientSumArray(0) += delta + multiplier * sigma * epsilon
gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
xi.foreachActive { (index, value) => xi.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) { if (localFeaturesStd(index) != 0.0 && value != 0.0) {
gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index))
} }
} }
...@@ -565,8 +574,7 @@ private class AFTAggregator( ...@@ -565,8 +574,7 @@ private class AFTAggregator(
lossSum += other.lossSum lossSum += other.lossSum
var i = 0 var i = 0
val len = this.gradientSumArray.length while (i < length) {
while (i < len) {
this.gradientSumArray(i) += other.gradientSumArray(i) this.gradientSumArray(i) += other.gradientSumArray(i)
i += 1 i += 1
} }
...@@ -583,12 +591,14 @@ private class AFTAggregator( ...@@ -583,12 +591,14 @@ private class AFTAggregator(
private class AFTCostFun( private class AFTCostFun(
data: RDD[AFTPoint], data: RDD[AFTPoint],
fitIntercept: Boolean, fitIntercept: Boolean,
featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
val bcParameters = data.context.broadcast(parameters)
val aftAggregator = data.treeAggregate( val aftAggregator = data.treeAggregate(
new AFTAggregator(parameters, fitIntercept, featuresStd))( new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))(
seqOp = (c, v) => (c, v) match { seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance) case (aggregator, instance) => aggregator.add(instance)
}, },
...@@ -596,6 +606,7 @@ private class AFTCostFun( ...@@ -596,6 +606,7 @@ private class AFTCostFun(
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
}) })
bcParameters.destroy(blocking = false)
(aftAggregator.loss, aftAggregator.gradient) (aftAggregator.loss, aftAggregator.gradient)
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment