diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index be234f7fea44fb77547149d07f03ab4eac475228..3179f4882fd49145a17dd4a26b52416994bd0aba 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
         "columns. This behavior is different from R survival::survreg.")
     }
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
+
+    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
     /*
@@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       state.x.toArray.clone()
     }
 
+    bcFeaturesStd.destroy(blocking = false)
     if (handlePersistence) instances.unpersist()
 
     val rawCoefficients = parameters.slice(2, parameters.length)
@@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
  *    $$
  * </blockquote></p>
  *
- * @param parameters including three part: The log of scale parameter, the intercept and
- *                regression coefficients corresponding to the features.
+ * @param bcParameters The broadcast parameters, in three parts: the log of the scale parameter,
+ *                     the intercept, and regression coefficients corresponding to the features.
  * @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
  */
 private class AFTAggregator(
-    parameters: BDV[Double],
+    bcParameters: Broadcast[BDV[Double]],
     fitIntercept: Boolean,
-    featuresStd: Array[Double]) extends Serializable {
+    bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {
 
+  private val length = bcParameters.value.length
+  // mark transient and lazy so these derived values are not serialized between aggregation stages
+  @transient private lazy val parameters = bcParameters.value
   // the regression coefficients to the covariates
-  private val coefficients = parameters.slice(2, parameters.length)
-  private val intercept = parameters(1)
+  @transient private lazy val coefficients = parameters.slice(2, length)
+  @transient private lazy val intercept = parameters(1)
   // sigma is the scale parameter of the AFT model
-  private val sigma = math.exp(parameters(0))
+  @transient private lazy val sigma = math.exp(parameters(0))
 
   private var totalCnt: Long = 0L
   private var lossSum = 0.0
   // Here we optimize loss function over log(sigma), intercept and coefficients
-  private val gradientSumArray = Array.ofDim[Double](parameters.length)
+  private val gradientSumArray = Array.ofDim[Double](length)
 
   def count: Long = totalCnt
   def loss: Double = {
@@ -524,11 +531,13 @@ private class AFTAggregator(
     val ti = data.label
     val delta = data.censor
 
+    val localFeaturesStd = bcFeaturesStd.value
+
     val margin = {
       var sum = 0.0
       xi.foreachActive { (index, value) =>
-        if (featuresStd(index) != 0.0 && value != 0.0) {
-          sum += coefficients(index) * (value / featuresStd(index))
+        if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+          sum += coefficients(index) * (value / localFeaturesStd(index))
         }
       }
       sum + intercept
@@ -542,8 +551,8 @@ private class AFTAggregator(
     gradientSumArray(0) += delta + multiplier * sigma * epsilon
     gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
     xi.foreachActive { (index, value) =>
-      if (featuresStd(index) != 0.0 && value != 0.0) {
-        gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+      if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+        gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index))
       }
     }
 
@@ -565,8 +574,7 @@ private class AFTAggregator(
       lossSum += other.lossSum
 
       var i = 0
-      val len = this.gradientSumArray.length
-      while (i < len) {
+      while (i < length) {
         this.gradientSumArray(i) += other.gradientSumArray(i)
         i += 1
       }
@@ -583,12 +591,14 @@ private class AFTAggregator(
 private class AFTCostFun(
     data: RDD[AFTPoint],
     fitIntercept: Boolean,
-    featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
+    bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
+    val bcParameters = data.context.broadcast(parameters)
+
     val aftAggregator = data.treeAggregate(
-      new AFTAggregator(parameters, fitIntercept, featuresStd))(
+      new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))(
       seqOp = (c, v) => (c, v) match {
         case (aggregator, instance) => aggregator.add(instance)
       },
@@ -596,6 +606,7 @@ private class AFTCostFun(
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
       })
 
+    bcParameters.destroy(blocking = false)
     (aftAggregator.loss, aftAggregator.gradient)
   }
 }
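
For reference, a minimal standalone sketch of the broadcast-and-destroy pattern this patch applies in AFTCostFun.calculate: broadcast the per-iteration value once, read it inside the aggregation closures, and destroy it as soon as the result is back on the driver. The names here (BroadcastPatternSketch, sumScaled, points, weights) are illustrative, not part of the patch.

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

object BroadcastPatternSketch {
  // Sum of dot products between each record and a shared weight vector.
  def sumScaled(sc: SparkContext, points: RDD[Array[Double]], weights: Array[Double]): Double = {
    // Broadcast once per call so each executor fetches a single copy of
    // `weights` instead of shipping the array inside every task closure.
    val bcWeights: Broadcast[Array[Double]] = sc.broadcast(weights)

    val total = points.treeAggregate(0.0)(
      seqOp = (acc, xs) => {
        val w = bcWeights.value  // executor-local read of the broadcast value
        var i = 0
        var dot = 0.0
        while (i < xs.length) { dot += xs(i) * w(i); i += 1 }
        acc + dot
      },
      combOp = _ + _)

    // Release executor-side copies as soon as the result is collected;
    // blocking = false makes the cleanup asynchronous.
    bcWeights.destroy(blocking = false)
    total
  }
}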
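Similarly, a sketch of the @transient lazy val idiom the patch introduces in AFTAggregator, assuming a broadcast weight vector created on the driver (WeightedAggregator and bcWeights are hypothetical names). Fields derived from the broadcast are dropped during serialization between treeAggregate stages and recomputed lazily from bcWeights.value on whichever executor deserializes the aggregator.

import org.apache.spark.broadcast.Broadcast

class WeightedAggregator(bcWeights: Broadcast[Array[Double]]) extends Serializable {
  // @transient: excluded from serialization; lazy: recomputed from the
  // broadcast value on first access after deserialization.
  @transient private lazy val weights = bcWeights.value

  private var weightedSum = 0.0

  def add(xs: Array[Double]): this.type = {
    var i = 0
    while (i < xs.length) { weightedSum += xs(i) * weights(i); i += 1 }
    this
  }

  def merge(other: WeightedAggregator): this.type = {
    weightedSum += other.weightedSum
    this
  }

  def sum: Double = weightedSum
}

Only the small Broadcast handle and the accumulated state travel between stages; the full array is fetched at most once per executor.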