Skip to content
Snippets Groups Projects
Commit 0903c648 authored by Holden Karau's avatar Holden Karau Committed by Joseph K. Bradley
Browse files

[SPARK-9718] [ML] linear regression training summary all columns

LinearRegression training summary: the transformed dataset should retain all columns of the input dataset, not just selected ones such as prediction and label. There is no real need to drop any of them, and the user may find them useful.

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8564 from holdenk/SPARK-9718-LinearRegressionTrainingSummary-all-columns.
parent dcbd58a9
No related branches found
No related tags found
No related merge requests found
......@@ -170,9 +170,12 @@ class LinearRegression(override val uid: String)
val intercept = yMean
val model = new LinearRegressionModel(uid, coefficients, intercept)
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset),
$(predictionCol),
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
Array(0D))
......@@ -262,9 +265,12 @@ class LinearRegression(override val uid: String)
if (handlePersistence) instances.unpersist()
val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol()
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset),
$(predictionCol),
summaryModel.transform(dataset),
predictionColName,
$(labelCol),
$(featuresCol),
objectiveHistory)
......@@ -316,13 +322,26 @@ class LinearRegressionModel private[ml] (
*/
// TODO: decide on a good name before exposing to public API
private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = {
val t = udf { features: Vector => predict(features) }
val predictionAndObservations = dataset
.select(col($(labelCol)), t(col($(featuresCol))).as($(predictionCol)))
// Handle possible missing or invalid prediction columns
val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, $(labelCol))
}
new LinearRegressionSummary(predictionAndObservations, $(predictionCol), $(labelCol))
/**
 * Chooses the model and the prediction column name to use when building a summary.
 *
 * When `predictionCol` is anything other than the empty string, this model and that column
 * name are returned unchanged. When it is the empty string, a column name of the form
 * `prediction_<random UUID>` is generated and returned together with a copy of this model
 * that has the generated name set as its prediction column.
 */
private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
  val configuredCol = $(predictionCol)
  if (configuredCol == "") {
    // No prediction column configured: make up a unique one on a copy so this model is untouched.
    val generatedCol = "prediction_" + java.util.UUID.randomUUID.toString()
    val modelWithCol = copy(ParamMap.empty).setPredictionCol(generatedCol)
    (modelWithCol, generatedCol)
  } else {
    (this, configuredCol)
  }
}
/**
 * Predicts the value for a single feature vector: the result of `dot(features, weights)`
 * plus `intercept`.
 */
override protected def predict(features: Vector): Double = {
  val weightedSum = dot(features, weights)
  weightedSum + intercept
}
......
......@@ -462,9 +462,22 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
test("linear regression model training summary") {
val trainer = new LinearRegression
val model = trainer.fit(dataset)
val trainerNoPredictionCol = trainer.setPredictionCol("")
val modelNoPredictionCol = trainerNoPredictionCol.fit(dataset)
// Training results for the model should be available
assert(model.hasSummary)
assert(modelNoPredictionCol.hasSummary)
// Schema should be a superset of the input dataset
assert((dataset.schema.fieldNames.toSet + "prediction").subsetOf(
model.summary.predictions.schema.fieldNames.toSet))
// Validate that we re-insert a prediction column for evaluation
val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames
assert((dataset.schema.fieldNames.toSet).subsetOf(
modelNoPredictionColFieldNames.toSet))
assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
// Residuals in [[LinearRegressionResults]] should equal those manually computed
val expectedResiduals = dataset.select("features", "label")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment