Skip to content
Snippets Groups Projects
Commit e3554605 authored by hyukjinkwon's avatar hyukjinkwon Committed by Joseph K. Bradley
Browse files

[SPARK-15892][ML] Incorrectly merged AFTAggregator with zero total count

## What changes were proposed in this pull request?

Currently, `AFTAggregator` is not being merged correctly. For example, if there is any single empty partition in the data, this creates an `AFTAggregator` with zero total count which causes the exception below:

```
IllegalArgumentException: u'requirement failed: The number of instances should be greater than 0.0, but got 0.'
```

Please see [AFTSurvivalRegression.scala#L573-L575](https://github.com/apache/spark/blob/6ecedf39b44c9acd58cdddf1a31cf11e8e24428c/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala#L573-L575) as well.

Just to be clear, the python example `aft_survival_regression.py` seems using 5 rows. So, if there exist partitions more than 5, it throws the exception above since it contains empty partitions which results in an incorrectly merged `AFTAggregator`.

Executing `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py` on a machine with CPUs more than 5 is being failed because it creates tasks with some empty partitions with defualt  configurations (AFAIK, it sets the parallelism level to the number of CPU cores).

## How was this patch tested?

An unit test in `AFTSurvivalRegressionSuite.scala` and manually tested by `bin/spark-submit examples/src/main/python/ml/aft_survival_regression.py`.

Author: hyukjinkwon <gurwls223@gmail.com>
Author: Hyukjin Kwon <gurwls223@gmail.com>

Closes #13619 from HyukjinKwon/SPARK-15892.
parent 0ff8a68b
No related branches found
No related tags found
No related merge requests found
......@@ -538,7 +538,7 @@ private class AFTAggregator(
* @return This AFTAggregator object.
*/
def merge(other: AFTAggregator): this.type = {
if (totalCnt != 0) {
if (other.count != 0) {
totalCnt += other.totalCnt
lossSum += other.lossSum
......
......@@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
}
test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
// This `dataset` will contain an empty partition because it has two rows but
// the parallelism is bigger than that. Because the issue was about `AFTAggregator`s
// being merged incorrectly when it has an empty partition, running the codes below
// should not throw an exception.
val dataset = spark.createDataFrame(
sc.parallelize(generateAFTInput(
1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3))
val trainer = new AFTSurvivalRegression()
trainer.fit(dataset)
}
}
object AFTSurvivalRegressionSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment