diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index ea31c68e4c943516d571fb6997d327b5a71c3b2e..757d52052d87fa10bc929882cc7e72914750e398 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -48,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
-  with HasStandardization with HasWeightCol with HasThreshold {
+  with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
 
   /**
    * Set threshold in binary classification, in range [0, 1].
@@ -256,6 +256,17 @@ class LogisticRegression @Since("1.2.0") (
   @Since("1.5.0")
   override def getThresholds: Array[Double] = super.getThresholds
 
+  /**
+   * Suggested depth for treeAggregate (>= 2).
+   * If the dimensions of features or the number of partitions are large,
+   * this param could be adjusted to a larger size.
+   * Default is 2.
+   * @group expertSetParam
+   */
+  @Since("2.1.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
   private var optInitialModel: Option[LogisticRegressionModel] = None
 
   /** @group setParam */
@@ -294,7 +305,8 @@ class LogisticRegression @Since("1.2.0") (
           (c1._1.merge(c2._1), c1._2.merge(c2._2))
 
       instances.treeAggregate(
-        new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp)
+        new MultivariateOnlineSummarizer, new MultiClassSummarizer
+      )(seqOp, combOp, $(aggregationDepth))
     }
 
     val histogram = labelSummarizer.histogram
@@ -358,7 +370,7 @@ class LogisticRegression @Since("1.2.0") (
 
         val bcFeaturesStd = instances.context.broadcast(featuresStd)
         val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), bcFeaturesStd, regParamL2, multinomial = false)
+          $(standardization), bcFeaturesStd, regParamL2, multinomial = false, $(aggregationDepth))
 
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -1331,8 +1343,8 @@ private class LogisticCostFun(
     standardization: Boolean,
     bcFeaturesStd: Broadcast[Array[Double]],
     regParamL2: Double,
-    multinomial: Boolean) extends DiffFunction[BDV[Double]] {
-
+    multinomial: Boolean,
+    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
 
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val coeffs = Vectors.fromBreeze(coefficients)
@@ -1347,7 +1359,7 @@ private class LogisticCostFun(
       instances.treeAggregate(
         new LogisticAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept,
           multinomial)
-      )(seqOp, combOp)
+      )(seqOp, combOp, aggregationDepth)
     }
 
     val totalGradientArray = logisticAggregator.gradient.toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
index dfadd68c5f476f5a47cd2ca85850e44b6869f233..f85ac76a8d129cf7e0f79c1f3f000f40f6a41f8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
@@ -44,7 +44,8 @@ import org.apache.spark.storage.StorageLevel
  */
 private[classification] trait MultinomialLogisticRegressionParams
   extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter
-    with HasFitIntercept with HasTol with HasStandardization with HasWeightCol {
+    with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
+    with HasAggregationDepth {
 
   /**
    * Set thresholds in multiclass (or binary) classification to adjust the probability of
@@ -163,6 +164,17 @@ class MultinomialLogisticRegression @Since("2.1.0") (
   @Since("2.1.0")
   override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
 
+  /**
+   * Suggested depth for treeAggregate (>= 2).
+   * If the dimensions of features or the number of partitions are large,
+   * this param could be adjusted to a larger size.
+   * Default is 2.
+   * @group expertSetParam
+   */
+  @Since("2.1.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
   override protected[spark] def train(dataset: Dataset[_]): MultinomialLogisticRegressionModel = {
     val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
@@ -245,7 +257,7 @@ class MultinomialLogisticRegression @Since("2.1.0") (
 
         val bcFeaturesStd = instances.context.broadcast(featuresStd)
         val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), bcFeaturesStd, regParamL2, multinomial = true)
+          $(standardization), bcFeaturesStd, regParamL2, multinomial = true, $(aggregationDepth))
 
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 4ab0c16a1b4d00d1a2e22f53cfe878fcb21323ce..0f48a16a429ffe5a6cea1c6ce6b22f371faf26eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -78,7 +78,9 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
         "all instance weights as 1.0"),
       ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
-        "empty, default value is 'auto'", Some("\"auto\"")))
+        "empty, default value is 'auto'", Some("\"auto\"")),
+      ParamDesc[Int]("aggregationDepth", "suggested depth for treeAggregate (>= 2)", Some("2"),
+        isValid = "ParamValidators.gtEq(2)"))
 
     val code = genSharedParams(params)
     val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 64d6af2766ca913b5141382cf51e8a803199d72c..6803772c63d62523fbefc8c966243d3d43e3f3b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -334,10 +334,10 @@ private[ml] trait HasElasticNetParam extends Params {
 private[ml] trait HasTol extends Params {
 
   /**
-   * Param for the convergence tolerance for iterative algorithms.
+   * Param for the convergence tolerance for iterative algorithms (>= 0).
    * @group param
    */
-  final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+  final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms (>= 0)", ParamValidators.gtEq(0))
 
   /** @group getParam */
   final def getTol: Double = $(tol)
@@ -349,10 +349,10 @@ private[ml] trait HasTol extends Params {
 private[ml] trait HasStepSize extends Params {
 
   /**
-   * Param for Step size to be used for each iteration of optimization.
+   * Param for Step size to be used for each iteration of optimization (> 0).
    * @group param
    */
-  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization")
+  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", ParamValidators.gt(0))
 
   /** @group getParam */
   final def getStepSize: Double = $(stepSize)
@@ -389,4 +389,21 @@ private[ml] trait HasSolver extends Params {
   /** @group getParam */
   final def getSolver: String = $(solver)
 }
+
+/**
+ * Trait for shared param aggregationDepth (default: 2).
+ */
+private[ml] trait HasAggregationDepth extends Params {
+
+  /**
+   * Param for suggested depth for treeAggregate (>= 2).
+   * @group param
+   */
+  final val aggregationDepth: IntParam = new IntParam(this, "aggregationDepth", "suggested depth for treeAggregate (>= 2)", ParamValidators.gtEq(2))
+
+  setDefault(aggregationDepth, 2)
+
+  /** @group getParam */
+  final def getAggregationDepth: Int = $(aggregationDepth)
+}
 // scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 76be4204e90508448b4bff2760e307b3bd9a3027..b1bb9b9fe0058f2dd0daca1a3f3f1dd948755f7c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -53,6 +53,7 @@ import org.apache.spark.storage.StorageLevel
 private[regression] trait LinearRegressionParams extends PredictorParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
     with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
+    with HasAggregationDepth
 
 /**
  * Linear regression.
@@ -172,6 +173,17 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setSolver(value: String): this.type = set(solver, value)
   setDefault(solver -> "auto")
 
+  /**
+   * Suggested depth for treeAggregate (>= 2).
+   * If the dimensions of features or the number of partitions are large,
+   * this param could be adjusted to a larger size.
+   * Default is 2.
+   * @group expertSetParam
+   */
+  @Since("2.1.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
   override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).first().getAs[Vector](0).size
@@ -230,7 +242,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
           (c1._1.merge(c2._1), c1._2.merge(c2._2))
 
       instances.treeAggregate(
-        new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp)
+        new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer
+      )(seqOp, combOp, $(aggregationDepth))
     }
 
     val yMean = ySummarizer.mean(0)
@@ -296,7 +309,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
 
     val costFun = new LeastSquaresCostFun(instances, yStd, yMean, $(fitIntercept),
-      $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam)
+      $(standardization), bcFeaturesStd, bcFeaturesMean, effectiveL2RegParam, $(aggregationDepth))
 
     val optimizer = if ($(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
       new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -1016,7 +1029,8 @@ private class LeastSquaresCostFun(
     standardization: Boolean,
     bcFeaturesStd: Broadcast[Array[Double]],
     bcFeaturesMean: Broadcast[Array[Double]],
-    effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
+    effectiveL2regParam: Double,
+    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
 
   override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
     val coeffs = Vectors.fromBreeze(coefficients)
@@ -1029,7 +1043,7 @@ private class LeastSquaresCostFun(
 
       instances.treeAggregate(
         new LeastSquaresAggregator(bcCoeffs, labelStd, labelMean, fitIntercept, bcFeaturesStd,
-          bcFeaturesMean))(seqOp, combOp)
+          bcFeaturesMean))(seqOp, combOp, aggregationDepth)
     }
 
     val totalGradientArray = leastSquaresAggregator.gradient.toArray