Skip to content
Snippets Groups Projects
Commit d8662cd9 authored by leahmcguire's avatar leahmcguire Committed by Joseph K. Bradley
Browse files

[SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting

Added stats from cross validation as a val in the cross validation model to save them for user access.

Author: leahmcguire <lmcguire@salesforce.com>

Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits:

49b507b [leahmcguire] fixed tyle error
67537b1 [leahmcguire] rebased
85907f0 [leahmcguire] fixed name
59987cc [leahmcguire] changed param name and test according to comments
36e71e3 [leahmcguire] rebasing
4b8223e [leahmcguire] fixed name
4ddffc6 [leahmcguire] changed param name and test according to comments
3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
parent 26c9d7a0
No related branches found
No related tags found
No related merge requests found
...@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM ...@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.") logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
copyValues(new CrossValidatorModel(uid, bestModel).setParent(this)) copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
} }
override def transformSchema(schema: StructType): StructType = { override def transformSchema(schema: StructType): StructType = {
...@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM ...@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
@Experimental @Experimental
class CrossValidatorModel private[ml] ( class CrossValidatorModel private[ml] (
override val uid: String, override val uid: String,
val bestModel: Model[_]) val bestModel: Model[_],
val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams { extends Model[CrossValidatorModel] with CrossValidatorParams {
override def validateParams(): Unit = { override def validateParams(): Unit = {
...@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] ( ...@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
} }
override def copy(extra: ParamMap): CrossValidatorModel = { override def copy(extra: ParamMap): CrossValidatorModel = {
val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]]) val copied = new CrossValidatorModel(
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
avgMetrics.clone())
copyValues(copied, extra) copyValues(copied, extra)
} }
} }
...@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001) assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10) assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)
} }
test("validateParams should check estimatorParamMaps") { test("validateParams should check estimatorParamMaps") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment