diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 5259ee419445fd087cae417bc1ab1c06903c6c27..f19ad7a5a6938b0979c8238a78c6cd868e2d4a90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w */ protected def validateAndTransformSchema(schema: StructType): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) - SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) - SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 0891994530f889d36bbda9a34c8515fcfdb83b9a..16821f317760ea9a5799879f365fb167d642f435 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkNumericType(schema, $(censorCol)) SchemaUtils.checkNumericType(schema, $(labelCol)) } - if (hasQuantilesCol) { + + val schemaWithQuantilesCol = if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } else schema + + SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType) } }