diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d68f5ff0053c94f2eb265736787ab31460154c9a..91c0a5631319d7981e0caf0c86a2ef428a24cad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable { /** * Copies param values from this instance to another instance for params shared by them. - * @param to the target instance - * @param extra extra params to be copied + * + * This handles default Params and explicitly set Params separately. + * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are + * copied from and to [[paramMap]]. + * Warning: This implicitly assumes that this [[Params]] instance and the target instance + * share the same set of default Params. + * + * @param to the target instance, which should work with the same set of default Params as this + * source instance + * @param extra extra params to be copied to the target's [[paramMap]] * @return the target instance with param values copied */ protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = { - val map = extractParamMap(extra) + val map = paramMap ++ extra params.foreach { param => + // copy default Params + if (defaultParamMap.contains(param) && to.hasParam(param.name)) { + to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param)) + } + // copy explicitly set Params if (map.contains(param) && to.hasParam(param.name)) { to.set(param.name, map(param)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea017443918e411bcac751145c2f02d1..be95638d81686e541a7f6d745301bf90fcc766ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) + } } object ParamsSuite extends SparkFunSuite {