diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3179f4882fd49145a17dd4a26b52416994bd0aba..9d5ba999781f60bf275b086490ccc3ccfa26979d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,7 @@ import org.apache.spark.storage.StorageLevel
  */
 private[regression] trait AFTSurvivalRegressionParams extends Params
   with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
-  with HasTol with HasFitIntercept with Logging {
+  with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
 
   /**
    * Param for censor column name.
@@ -183,6 +183,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
   def setTol(value: Double): this.type = set(tol, value)
   setDefault(tol -> 1E-6)
 
+  /**
+   * Sets the suggested depth for treeAggregate (>= 2).
+   * If the feature dimension or the number of partitions is large,
+   * this param can be set to a larger value.
+   * Default is 2.
+   * @group expertSetParam
+   */
+  @Since("2.1.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
   /**
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
    * and put it in an RDD with strong types.
@@ -207,7 +218,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
         c1.merge(c2)
       }
-      instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+      instances.treeAggregate(
+        new MultivariateOnlineSummarizer
+      )(seqOp, combOp, $(aggregationDepth))
     }
 
     val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
@@ -222,7 +235,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
 
     val bcFeaturesStd = instances.context.broadcast(featuresStd)
 
-    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
+    val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth))
     val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
 
     /*
@@ -591,7 +604,8 @@ private class AFTAggregator(
 private class AFTCostFun(
     data: RDD[AFTPoint],
     fitIntercept: Boolean,
-    bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
+    bcFeaturesStd: Broadcast[Array[Double]],
+    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
 
   override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
 
@@ -604,7 +618,7 @@ private class AFTCostFun(
       },
       combOp = (c1, c2) => (c1, c2) match {
         case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-      })
+      }, depth = aggregationDepth)
 
     bcParameters.destroy(blocking = false)
     (aftAggregator.loss, aftAggregator.gradient)
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 19afc723bb7847cc39d506a308871424880e0fd9..55d38033ef72a3c292303b9923493a209b66423a 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1088,7 +1088,8 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
 
 @inherit_doc
 class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
-                            HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
+                            HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth,
+                            JavaMLWritable, JavaMLReadable):
     """
     .. note:: Experimental
 
@@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
-                 quantilesCol=None):
+                 quantilesCol=None, aggregationDepth=2):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                  quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
-                 quantilesCol=None)
+                 quantilesCol=None, aggregationDepth=2)
         """
         super(AFTSurvivalRegression, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -1174,12 +1175,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
                   quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
-                  quantilesCol=None):
+                  quantilesCol=None, aggregationDepth=2):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
                   quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
-                  quantilesCol=None):
+                  quantilesCol=None, aggregationDepth=2)
         """
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)