Skip to content
Snippets Groups Projects
Commit 4f87e956 authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

[SPARK-7429] [ML] Params cleanups

Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does.

CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5960 from jkbradley/params-cleanups and squashes the following commits:

118b158 [Joseph K. Bradley] Params.setDefault taking a set of ParamPairs should be annotated with varargs. I thought it would not work before, but it apparently does. CrossValidator.transform should call transformSchema since the underlying Model might be a PipelineModel
parent 8b6b46e4
No related branches found
No related tags found
No related merge requests found
......@@ -366,13 +366,11 @@ trait Params extends Identifiable with Serializable {
/**
* Sets default values for a list of params.
*
* Note: Java developers should use the single-parameter [[setDefault()]].
* Annotating this with varargs causes compilation failures.
*
* @param paramPairs a list of param pairs that specify params and their default values to set
* respectively. Make sure that the params are initialized before this method
* gets called.
*/
@varargs
protected final def setDefault(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
setDefault(p.param.asInstanceOf[Param[Any]], p.value)
......
......@@ -105,7 +105,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(dataset.schema, logging = true)
transformSchema(schema, logging = true)
val sqlCtx = dataset.sqlContext
val est = $(estimator)
val eval = $(evaluator)
......@@ -159,6 +159,7 @@ class CrossValidatorModel private[ml] (
}
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
bestModel.transform(dataset)
}
......
......@@ -59,5 +59,6 @@ public class JavaTestParams extends JavaParams {
ParamValidators.inArray(validStrings));
setDefault(myIntParam, 1);
setDefault(myDoubleParam, 0.5);
setDefault(myIntParam.w(1), myDoubleParam.w(0.5));
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment