Skip to content
Snippets Groups Projects
Commit d8813fa0 authored by Bryan Cutler's avatar Bryan Cutler Committed by Joseph K. Bradley
Browse files

[SPARK-13625][PYSPARK][ML] Added a check to see if an attribute is a property...

[SPARK-13625][PYSPARK][ML] Added a check to see if an attribute is a property when getting param list

## What changes were proposed in this pull request?

Added a check in pyspark.ml.param.Param.params() to see if an attribute is a property (decorated with `property`) before checking if it is a `Param` instance.  This prevents the property from being invoked to 'get' this attribute, which could possibly cause an error.

## How was this patch tested?

Added a test case with a class has a property that will raise an error when invoked and then call`Param.params` to verify that the property is not invoked, but still able to find another property in the class.  Also ran pyspark-ml test before fix that will trigger an error, and again after the fix to verify that the error was resolved and the method was working properly.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #11476 from BryanCutler/pyspark-ml-property-attr-SPARK-13625.
parent 81f54acc
No related branches found
No related tags found
No related merge requests found
......@@ -109,7 +109,8 @@ class Params(Identifiable):
"""
if self._params is None:
self._params = list(filter(lambda attr: isinstance(attr, Param),
[getattr(self, x) for x in dir(self) if x != "params"]))
[getattr(self, x) for x in dir(self) if x != "params" and
not isinstance(getattr(type(self), x, None), property)]))
return self._params
@since("1.4.0")
......
......@@ -271,6 +271,12 @@ class ParamTests(PySparkTestCase):
# Check that a different class has a different seed
self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
def test_param_property_error(self):
param_store = HasThrowableProperty()
self.assertRaises(RuntimeError, lambda: param_store.test_property)
params = param_store.params # should not invoke the property 'test_property'
self.assertEqual(len(params), 1)
class FeatureTests(PySparkTestCase):
......@@ -494,6 +500,17 @@ class PersistenceTest(PySparkTestCase):
pass
class HasThrowableProperty(Params):
def __init__(self):
super(HasThrowableProperty, self).__init__()
self.p = Param(self, "none", "empty param")
@property
def test_property(self):
raise RuntimeError("Test property to raise error when invoked")
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment