From 73e04ecc4f29a0fe51687ed1337c61840c976f89 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Pelvet?= <cedric.pelvet@gmail.com>
Date: Sun, 20 Aug 2017 11:05:54 +0100
Subject: [PATCH] [MINOR] Correct validateAndTransformSchema in GaussianMixture
 and AFTSurvivalRegression
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

The line SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) did not modify the variable schema, hence only the last line had any effect. A temporary variable is used to correctly append the two columns predictionCol and probabilityCol.

## How was this patch tested?

Manually.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Cédric Pelvet <cedric.pelvet@gmail.com>

Closes #18980 from sharp-pixel/master.
---
 .../org/apache/spark/ml/clustering/GaussianMixture.scala  | 4 ++--
 .../spark/ml/regression/AFTSurvivalRegression.scala       | 8 +++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 5259ee4194..f19ad7a5a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
-    SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
-    SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+    val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+    SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT)
   }
 }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 0891994530..16821f3177 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
       SchemaUtils.checkNumericType(schema, $(censorCol))
       SchemaUtils.checkNumericType(schema, $(labelCol))
     }
-    if (hasQuantilesCol) {
+
+    val schemaWithQuantilesCol = if (hasQuantilesCol) {
       SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
-    }
-    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+    } else schema
+
+    SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType)
   }
 }
 
-- 
GitLab