Skip to content
Snippets Groups Projects
Commit 2f69e3f6 authored by Bryan Cutler's avatar Bryan Cutler Committed by Joseph K. Bradley
Browse files

[SPARK-14772][PYTHON][ML] Fixed Params.copy method to match Scala implementation

## What changes were proposed in this pull request?
Fixed the PySpark `Params.copy` method to behave like the Scala implementation. The main issue was that it did not handle `_defaultParamMap` separately: default values were merged into the copy's explicitly set param map instead of being copied as defaults.

## How was this patch tested?
Added a new unit test to verify that the copy method correctly copies the uid, the explicitly set params, and the default params.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #16772 from BryanCutler/pyspark-ml-param_copy-Scala_sync-SPARK-14772.
parent d0276245
No related branches found
No related tags found
No related merge requests found
...@@ -372,6 +372,7 @@ class Params(Identifiable): ...@@ -372,6 +372,7 @@ class Params(Identifiable):
extra = dict() extra = dict()
that = copy.copy(self) that = copy.copy(self)
that._paramMap = {} that._paramMap = {}
that._defaultParamMap = {}
return self._copyValues(that, extra) return self._copyValues(that, extra)
def _shouldOwn(self, param): def _shouldOwn(self, param):
...@@ -452,12 +453,16 @@ class Params(Identifiable): ...@@ -452,12 +453,16 @@ class Params(Identifiable):
:param extra: extra params to be copied :param extra: extra params to be copied
:return: the target instance with param values copied :return: the target instance with param values copied
""" """
if extra is None: paramMap = self._paramMap.copy()
extra = dict() if extra is not None:
paramMap = self.extractParamMap(extra) paramMap.update(extra)
for p in self.params: for param in self.params:
if p in paramMap and to.hasParam(p.name): # copy default params
to._set(**{p.name: paramMap[p]}) if param in self._defaultParamMap and to.hasParam(param.name):
to._defaultParamMap[to.getParam(param.name)] = self._defaultParamMap[param]
# copy explicitly set params
if param in paramMap and to.hasParam(param.name):
to._set(**{param.name: paramMap[param]})
return to return to
def _resetUid(self, newUid): def _resetUid(self, newUid):
......
...@@ -389,6 +389,22 @@ class ParamTests(PySparkTestCase): ...@@ -389,6 +389,22 @@ class ParamTests(PySparkTestCase):
# Check windowSize is set properly # Check windowSize is set properly
self.assertEqual(model.getWindowSize(), 6) self.assertEqual(model.getWindowSize(), 6)
def test_copy_param_extras(self):
    """Verify ``copy(extra=...)`` preserves uid, params, set values and defaults.

    A ``TestParams`` instance is created with ``seed=42`` and copied with an
    extra value for its ``inputCol`` param.  The copy must:

    * share the uid and the declared params of the source,
    * have every extra param defined and resolving to the supplied value,
    * hold exactly the source's explicitly set params once the extras
      are filtered out of its ``_paramMap``,
    * carry an equal ``_defaultParamMap``.
    """
    source = TestParams(seed=42)
    input_col = source.getParam(TestParams.inputCol.name)
    extra = {input_col: "copy_input"}
    duplicate = source.copy(extra=extra)

    # Identity and declared params are unchanged by the copy.
    self.assertEqual(source.uid, duplicate.uid)
    self.assertEqual(source.params, duplicate.params)

    # Every extra param is visible on the copy with the supplied value.
    for param, value in extra.items():
        self.assertTrue(duplicate.isDefined(param))
        self.assertEqual(duplicate.getOrDefault(param), value)

    # Drop the extras so that only the values inherited from the source
    # remain, then compare them with the source's explicitly set params.
    copied_no_extra = {
        param: value
        for param, value in duplicate._paramMap.items()
        if param not in extra
    }
    self.assertEqual(source._paramMap, copied_no_extra)

    # Defaults are compared as their own map, separately from _paramMap.
    self.assertEqual(source._defaultParamMap, duplicate._defaultParamMap)
class EvaluatorTests(SparkSessionTestCase): class EvaluatorTests(SparkSessionTestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment