Skip to content
Snippets Groups Projects
Commit b3546738 authored by sethah's avatar sethah Committed by Xiangrui Meng
Browse files

[SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error

Pyspark Params class has a method `hasParam(paramName)` which returns `True` if the class has a parameter by that name, but throws an `AttributeError` otherwise. There is not currently a way of getting a Boolean to indicate if a class has a parameter. With Spark 2.0 we could modify the existing behavior of `hasParam` or add an additional method with this functionality.

In Python:
```python
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
print nb.hasParam("smoothing")
print nb.hasParam("notAParam")
```
produces:
> True
> AttributeError: 'NaiveBayes' object has no attribute 'notAParam'

However, in Scala:
```scala
import org.apache.spark.ml.classification.NaiveBayes
val nb  = new NaiveBayes()
nb.hasParam("smoothing")
nb.hasParam("notAParam")
```
produces:
> true
> false

cc holdenk

Author: sethah <seth.hendrickson16@gmail.com>

Closes #10962 from sethah/SPARK-13047.
parent 30e00955
No related branches found
No related tags found
No related merge requests found
......@@ -179,8 +179,11 @@ class Params(Identifiable):
Tests whether this instance contains a param with a given
(string) name.
"""
param = self._resolveParam(paramName)
return param in self.params
if isinstance(paramName, str):
p = getattr(self, paramName, None)
return isinstance(p, Param)
else:
raise TypeError("hasParam(): paramName must be a string")
@since("1.4.0")
def getOrDefault(self, param):
......
......@@ -209,6 +209,11 @@ class ParamTests(PySparkTestCase):
self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
self.assertTrue(maxIter.parent == testParams.uid)
def test_hasparam(self):
testParams = TestParams()
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
self.assertFalse(testParams.hasParam("notAParameter"))
def test_params(self):
testParams = TestParams()
maxIter = testParams.maxIter
......@@ -218,7 +223,7 @@ class ParamTests(PySparkTestCase):
params = testParams.params
self.assertEqual(params, [inputCol, maxIter, seed])
self.assertTrue(testParams.hasParam(maxIter))
self.assertTrue(testParams.hasParam(maxIter.name))
self.assertTrue(testParams.hasDefault(maxIter))
self.assertFalse(testParams.isSet(maxIter))
self.assertTrue(testParams.isDefined(maxIter))
......@@ -227,7 +232,7 @@ class ParamTests(PySparkTestCase):
self.assertTrue(testParams.isSet(maxIter))
self.assertEqual(testParams.getMaxIter(), 100)
self.assertTrue(testParams.hasParam(inputCol))
self.assertTrue(testParams.hasParam(inputCol.name))
self.assertFalse(testParams.hasDefault(inputCol))
self.assertFalse(testParams.isSet(inputCol))
self.assertFalse(testParams.isDefined(inputCol))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment