Skip to content
Snippets Groups Projects
Commit d8662cd9 authored by leahmcguire's avatar leahmcguire Committed by Joseph K. Bradley
Browse files

[SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting

Added stats from cross validation as a val in the cross validation model to save them for user access.

Author: leahmcguire <lmcguire@salesforce.com>

Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits:

49b507b [leahmcguire] fixed tyle error
67537b1 [leahmcguire] rebased
85907f0 [leahmcguire] fixed name
59987cc [leahmcguire] changed param name and test according to comments
36e71e3 [leahmcguire] rebasing
4b8223e [leahmcguire] fixed name
4ddffc6 [leahmcguire] changed param name and test according to comments
3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
parent 26c9d7a0
No related branches found
No related tags found
No related merge requests found
......@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
......@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
val bestModel: Model[_])
val bestModel: Model[_],
val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {
override def validateParams(): Unit = {
......@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
}
override def copy(extra: ParamMap): CrossValidatorModel = {
val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
val copied = new CrossValidatorModel(
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
avgMetrics.clone())
copyValues(copied, extra)
}
}
......@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
assert(cvModel.avgMetrics.length === lrParamMaps.length)
}
test("validateParams should check estimatorParamMaps") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment