diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 6d5e398dfe155191823bc03a1265709752ce12fa..76be4204e90508448b4bff2760e307b3bd9a3027 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -82,6 +83,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set the regularization parameter.
    * Default is 0.0.
+   *
    * @group setParam
    */
   @Since("1.3.0")
@@ -91,6 +93,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set if we should fit the intercept
    * Default is true.
+   *
    * @group setParam
    */
   @Since("1.5.0")
@@ -104,6 +107,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * the models should be always converged to the same solution when no regularization
    * is applied. In R's GLMNET package, the default behavior is true as well.
    * Default is true.
+   *
    * @group setParam
    */
   @Since("1.5.0")
@@ -115,6 +119,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
    * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
    * Default is 0.0 which is an L2 penalty.
+   *
    * @group setParam
    */
   @Since("1.4.0")
@@ -124,6 +129,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set the maximum number of iterations.
    * Default is 100.
+   *
    * @group setParam
    */
   @Since("1.3.0")
@@ -134,6 +140,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * Set the convergence tolerance of iterations.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
    * Default is 1E-6.
+   *
    * @group setParam
    */
   @Since("1.4.0")
@@ -144,6 +151,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * Whether to over-/under-sample training instances according to the given weights in weightCol.
    * If not set or empty, all instances are treated equally (weight 1.0).
    * Default is not set, so all instances have weight one.
+   *
    * @group setParam
    */
   @Since("1.6.0")
@@ -157,6 +165,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
    * solution to the linear regression problem.
    * The default value is "auto" which means that the solver algorithm is
    * selected automatically.
+   *
    * @group setParam
    */
   @Since("1.6.0")
@@ -270,6 +279,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean)
     val featuresMean = featuresSummarizer.mean.toArray
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
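+    // Broadcast the feature statistics so executors read one shared copy per node
+    // rather than having the arrays serialized into every task closure.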
+    val bcFeaturesMean = instances.context.broadcast(featuresMean)
+    val bcFeaturesStd = instances.context.broadcast(featuresStd)
 
     if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
       featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
@@ -285,7 +296,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
     val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
-      $(standardization), featuresStd, featuresMean, effectiveL2RegParam)
+      $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam)
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -330,6 +341,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
         throw new SparkException(msg)
       }
 
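+      // Training has finished, so the broadcast statistics are no longer needed;
+      // destroy them (non-blocking) to release memory on the executors.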
+      bcFeaturesMean.destroy(blocking = false)
+      bcFeaturesStd.destroy(blocking = false)
+
       /*
          The coefficients are trained in the scaled space; we're converting them back to
          the original space.
@@ -419,6 +433,7 @@ class LinearRegressionModel private[ml] (
 
   /**
    * Evaluates the model on a test dataset.
+   *
    * @param dataset Test dataset to evaluate model on.
    */
   @Since("2.0.0")
@@ -544,6 +559,7 @@ class LinearRegressionTrainingSummary private[regression] (
    * Number of training iterations until termination
    *
    * This value is only available when using the "l-bfgs" solver.
+   *
    * @see [[LinearRegression.solver]]
    */
   @Since("1.5.0")
@@ -862,27 +878,31 @@ class LinearRegressionSummary private[regression] (
  *    $$
  * </blockquote></p>
  *
- * @param coefficients The coefficients corresponding to the features.
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
  * @param labelStd The standard deviation value of the label.
  * @param labelMean The mean value of the label.
  * @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
- * @param featuresMean The mean values of the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
+ * @param bcFeaturesMean The broadcast mean values of the features.
  */
 private class LeastSquaresAggregator(
-    coefficients: Vector,
+    bcCoefficients: Broadcast[Vector],
     labelStd: Double,
     labelMean: Double,
     fitIntercept: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double]) extends Serializable {
+    bcFeaturesStd: Broadcast[Array[Double]],
+    bcFeaturesMean: Broadcast[Array[Double]]) extends Serializable {
 
   private var totalCnt: Long = 0L
   private var weightSum: Double = 0.0
   private var lossSum = 0.0
 
-  private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = {
-    val coefficientsArray = coefficients.toArray.clone()
+  private val dim = bcCoefficients.value.size
+  // Make these transient lazy vals so they are not serialized between aggregation
+  // stages; each executor rederives them from the broadcast variables instead.
+  @transient private lazy val featuresStd = bcFeaturesStd.value
+  @transient private lazy val effectiveCoefAndOffset = {
+    val coefficientsArray = bcCoefficients.value.toArray.clone()
+    val featuresMean = bcFeaturesMean.value
     var sum = 0.0
     var i = 0
     val len = coefficientsArray.length
@@ -896,10 +916,11 @@ private class LeastSquaresAggregator(
       i += 1
     }
     val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0
-    (coefficientsArray, offset, coefficientsArray.length)
+    (Vectors.dense(coefficientsArray), offset)
   }
-
-  private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray)
+  // Do not destructure effectiveCoefAndOffset with a tuple assignment above: the
+  // compiler would generate a hidden field for the tuple that is not marked
+  // @transient, circumventing the tag.
+  @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1
+  @transient private lazy val offset = effectiveCoefAndOffset._2
 
   private val gradientSumArray = Array.ofDim[Double](dim)
 
@@ -922,9 +943,10 @@ private class LeastSquaresAggregator(
 
       if (diff != 0) {
         val localGradientSumArray = gradientSumArray
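+        // Hold the transient lazy val in a local reference so the per-element
+        // loop below avoids repeated lazy-val accessor calls.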
+        val localFeaturesStd = featuresStd
         features.foreachActive { (index, value) =>
-          if (featuresStd(index) != 0.0 && value != 0.0) {
-            localGradientSumArray(index) += weight * diff * value / featuresStd(index)
+          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+            localGradientSumArray(index) += weight * diff * value / localFeaturesStd(index)
           }
         }
         lossSum += weight * diff * diff / 2.0
@@ -992,23 +1014,26 @@ private class LeastSquaresCostFun(
     labelMean: Double,
     fitIntercept: Boolean,
     standardization: Boolean,
-    featuresStd: Array[Double],
-    featuresMean: Array[Double],
+    bcFeaturesStd: Broadcast[Array[Double]],
+    bcFeaturesMean: Broadcast[Array[Double]],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
 
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val coeffs = Vectors.fromBreeze(coefficients)
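+    // Broadcast this iteration's coefficients; the broadcast is destroyed below
+    // once the gradient has been aggregated.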
+    val bcCoeffs = instances.context.broadcast(coeffs)
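+    // Read the broadcast feature standard deviations once on the driver for the
+    // regularization loop below.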
+    val localFeaturesStd = bcFeaturesStd.value
 
     val leastSquaresAggregator = {
       val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance)
       val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2)
 
       instances.treeAggregate(
-        new LeastSquaresAggregator(coeffs, labelStd, labelMean, fitIntercept, featuresStd,
-          featuresMean))(seqOp, combOp)
+        new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd,
+          bcFeaturesMean))(seqOp, combOp)
     }
 
     val totalGradientArray = leastSquaresAggregator.gradient.toArray
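+    // The aggregated gradient has been collected, so this iteration's coefficient
+    // broadcast can be released.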
+    bcCoeffs.destroy(blocking = false)
 
     val regVal = if (effectiveL2regParam == 0.0) {
       0.0
@@ -1022,13 +1047,13 @@ private class LeastSquaresCostFun(
             totalGradientArray(index) += effectiveL2regParam * value
             value * value
           } else {
-            if (featuresStd(index) != 0.0) {
+            if (localFeaturesStd(index) != 0.0) {
               // If `standardization` is false, we still standardize the data
               // to improve the rate of convergence; as a result, we have to
               // perform this reverse standardization by penalizing each component
               // differently to get effectively the same objective function when
               // the training dataset is not standardized.
-              val temp = value / (featuresStd(index) * featuresStd(index))
+              val temp = value / (localFeaturesStd(index) * localFeaturesStd(index))
               totalGradientArray(index) += effectiveL2regParam * temp
               value * temp
             } else {