diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 0dc084fdd12f7c2a25e0982a637e96071229d774..dd09667ef5a0f944bb2061abf7607ba9df05e94e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -170,9 +170,12 @@ class LinearRegression(override val uid: String) val intercept = yMean val model = new LinearRegressionModel(uid, coefficients, intercept) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset), - $(predictionCol), + summaryModel.transform(dataset), + predictionColName, $(labelCol), $(featuresCol), Array(0D)) @@ -262,9 +265,12 @@ class LinearRegression(override val uid: String) if (handlePersistence) instances.unpersist() val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset), - $(predictionCol), + summaryModel.transform(dataset), + predictionColName, $(labelCol), $(featuresCol), objectiveHistory) @@ -316,13 +322,26 @@ class LinearRegressionModel private[ml] ( */ // TODO: decide on a good name before exposing to public API private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { - val t = udf { features: Vector => predict(features) } - val predictionAndObservations = dataset - .select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol))) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() + new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol)) + } - new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol)) + /** + * If the prediction column is set returns the current model and prediction column, + * otherwise generates a new column and sets it as the prediction column on a new copy + * of the current model. + */ + private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { + $(predictionCol) match { + case "" => + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) + case p => (this, p) + } } + override protected def predict(features: Vector): Double = { dot(features, weights) + intercept } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 32729470d5bff1156eb3e618a41bedbdec66c95c..73a0a5caf86408bb46e2c3c02ff7fb78dcf76e57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -462,9 +462,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { test("linear regression model training summary") { val trainer = new LinearRegression val model = trainer.fit(dataset) + val trainerNoPredictionCol = trainer.setPredictionCol("") + val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset) + // Training results for the model should be available assert(model.hasSummary) + assert(modelNoPredictionCol.hasSummary) + + // Schema should be a superset of the input dataset + assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf( + model.summary.predictions.schema.fieldNames.toSet)) + // Validate that we re-insert a prediction column for evaluation + val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + modelNoPredictionColFieldNames.toSet)) + assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) // Residuals in [[LinearRegressionResults]] should equal those manually computed val expectedResiduals = dataset.select("features", "label")