From 57cd1e86d1d450f85fc9e296aff498a940452113 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng <meng@databricks.com> Date: Wed, 15 Apr 2015 23:49:42 -0700 Subject: [PATCH] [SPARK-6893][ML] default pipeline parameter handling in python Same as #5431 but for Python. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #5534 from mengxr/SPARK-6893 and squashes the following commits: d3b519b [Xiangrui Meng] address comments ebaccc6 [Xiangrui Meng] style update fce244e [Xiangrui Meng] update explainParams with test 4d6b07a [Xiangrui Meng] add tests 5294500 [Xiangrui Meng] update default param handling in python --- .../org/apache/spark/ml/Identifiable.scala | 2 +- .../apache/spark/ml/param/TestParams.scala | 9 +- python/pyspark/ml/classification.py | 3 +- python/pyspark/ml/feature.py | 19 +-- python/pyspark/ml/param/__init__.py | 146 +++++++++++++++--- ...d_params.py => _shared_params_code_gen.py} | 42 ++--- python/pyspark/ml/param/shared.py | 106 ++++++------- python/pyspark/ml/pipeline.py | 6 +- python/pyspark/ml/tests.py | 52 ++++++- python/pyspark/ml/util.py | 4 +- python/pyspark/ml/wrapper.py | 2 +- 11 files changed, 270 insertions(+), 121 deletions(-) rename python/pyspark/ml/param/{_gen_shared_params.py => _shared_params_code_gen.py} (70%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala index a50090671a..a1d49095c2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -25,7 +25,7 @@ import java.util.UUID private[ml] trait Identifiable extends Serializable { /** - * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * A unique id for the object. The default implementation concatenates the class name, "_", and 8 * random hex chars. */ private[ml] val uid: String = diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 8f9ab687c0..641b64b42a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -17,16 +17,13 @@ package org.apache.spark.ml.param +import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter} + /** A subclass of Params for testing. */ -class TestParams extends Params { +class TestParams extends Params with HasMaxIter with HasInputCol { - val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = getOrDefault(maxIter) - - val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = getOrDefault(inputCol) setDefault(maxIter -> 10) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7f42de531f..d7bc09fd77 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -59,6 +59,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti maxIter=100, regParam=0.1) """ super(LogisticRegression, self).__init__() + self._setDefault(maxIter=100, regParam=0.1) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -71,7 +72,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Sets params for logistic regression. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def _create_model(self, java_model): return LogisticRegressionModel(java_model) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 1cfcd019df..263fe2a5bc 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): _java_class = "org.apache.spark.ml.feature.Tokenizer" @keyword_only - def __init__(self, inputCol="input", outputCol="output"): + def __init__(self, inputCol=None, outputCol=None): """ - __init__(self, inputCol="input", outputCol="output") + __init__(self, inputCol=None, outputCol=None) """ super(Tokenizer, self).__init__() kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, inputCol="input", outputCol="output"): + def setParams(self, inputCol=None, outputCol=None): """ setParams(self, inputCol="input", outputCol="output") Sets params for this Tokenizer. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) @inherit_doc @@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): _java_class = "org.apache.spark.ml.feature.HashingTF" @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() + self._setDefault(numFeatures=1 << 18) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"): + def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): """ - setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output") + setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None) Sets params for this HashingTF. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index e3a53dd780..5c62620562 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -25,23 +25,21 @@ __all__ = ['Param', 'Params'] class Param(object): """ - A param with self-contained documentation and optionally default value. + A param with self-contained documentation. """ - def __init__(self, parent, name, doc, defaultValue=None): - if not isinstance(parent, Identifiable): - raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + def __init__(self, parent, name, doc): + if not isinstance(parent, Params): + raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__) self.parent = parent self.name = str(name) self.doc = str(doc) - self.defaultValue = defaultValue def __str__(self): - return str(self.parent) + "-" + self.name + return str(self.parent) + "__" + self.name def __repr__(self): - return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ - (self.parent, self.name, self.doc, self.defaultValue) + return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc) class Params(Identifiable): @@ -52,26 +50,128 @@ class Params(Identifiable): __metaclass__ = ABCMeta - def __init__(self): - super(Params, self).__init__() - #: embedded param map - self.paramMap = {} + #: internal param map for user-supplied values param map + paramMap = {} + + #: internal param map for default values + defaultParamMap = {} @property def params(self): """ - Returns all params. The default implementation uses - :py:func:`dir` to get all attributes of type + Returns all params ordered by name. The default implementation + uses :py:func:`dir` to get all attributes of type :py:class:`Param`. """ return filter(lambda attr: isinstance(attr, Param), [getattr(self, x) for x in dir(self) if x != "params"]) - def _merge_params(self, params): - paramMap = self.paramMap.copy() - paramMap.update(params) + def _explain(self, param): + """ + Explains a single param and returns its name, doc, and optional + default value and user-supplied value in a string. + """ + param = self._resolveParam(param) + values = [] + if self.isDefined(param): + if param in self.defaultParamMap: + values.append("default: %s" % self.defaultParamMap[param]) + if param in self.paramMap: + values.append("current: %s" % self.paramMap[param]) + else: + values.append("undefined") + valueStr = "(" + ", ".join(values) + ")" + return "%s: %s %s" % (param.name, param.doc, valueStr) + + def explainParams(self): + """ + Returns the documentation of all params with their optionally + default values and user-supplied values. + """ + return "\n".join([self._explain(param) for param in self.params]) + + def getParam(self, paramName): + """ + Gets a param by its name. + """ + param = getattr(self, paramName) + if isinstance(param, Param): + return param + else: + raise ValueError("Cannot find param with name %s." % paramName) + + def isSet(self, param): + """ + Checks whether a param is explicitly set by user. + """ + param = self._resolveParam(param) + return param in self.paramMap + + def hasDefault(self, param): + """ + Checks whether a param has a default value. + """ + param = self._resolveParam(param) + return param in self.defaultParamMap + + def isDefined(self, param): + """ + Checks whether a param is explicitly set by user or has a default value. + """ + return self.isSet(param) or self.hasDefault(param) + + def getOrDefault(self, param): + """ + Gets the value of a param in the user-supplied param map or its + default value. Raises an error if either is set. + """ + if isinstance(param, Param): + if param in self.paramMap: + return self.paramMap[param] + else: + return self.defaultParamMap[param] + elif isinstance(param, str): + return self.getOrDefault(self.getParam(param)) + else: + raise KeyError("Cannot recognize %r as a param." % param) + + def extractParamMap(self, extraParamMap={}): + """ + Extracts the embedded default param values and user-supplied + values, and then merges them with extra values from input into + a flat param map, where the latter value is used if there exist + conflicts, i.e., with ordering: default param values < + user-supplied values < extraParamMap. + :param extraParamMap: extra param values + :return: merged param map + """ + paramMap = self.defaultParamMap.copy() + paramMap.update(self.paramMap) + paramMap.update(extraParamMap) return paramMap + def _shouldOwn(self, param): + """ + Validates that the input param belongs to this Params instance. + """ + if param.parent is not self: + raise ValueError("Param %r does not belong to %r." % (param, self)) + + def _resolveParam(self, param): + """ + Resolves a param and validates the ownership. + :param param: param name or the param instance, which must + belong to this Params instance + :return: resolved param instance + """ + if isinstance(param, Param): + self._shouldOwn(param) + return param + elif isinstance(param, str): + return self.getParam(param) + else: + raise ValueError("Cannot resolve %r as a param." % param) + @staticmethod def _dummy(): """ @@ -81,10 +181,18 @@ class Params(Identifiable): dummy.uid = "undefined" return dummy - def _set_params(self, **kwargs): + def _set(self, **kwargs): """ - Sets params. + Sets user-supplied params. """ for param, value in kwargs.iteritems(): self.paramMap[getattr(self, param)] = value return self + + def _setDefault(self, **kwargs): + """ + Sets default params. + """ + for param, value in kwargs.iteritems(): + self.defaultParamMap[getattr(self, param)] = value + return self diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_shared_params_code_gen.py similarity index 70% rename from python/pyspark/ml/param/_gen_shared_params.py rename to python/pyspark/ml/param/_shared_params_code_gen.py index 5eb81106f1..55f4224976 100644 --- a/python/pyspark/ml/param/_gen_shared_params.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -32,29 +32,34 @@ header = """# # limitations under the License. #""" +# Code generator for shared params (shared.py). Run under this folder with: +# python _shared_params_code_gen.py > shared.py -def _gen_param_code(name, doc, defaultValue): + +def _gen_param_code(name, doc, defaultValueStr): """ Generates Python code for a shared param class. :param name: param name :param doc: param doc - :param defaultValue: string representation of the param + :param defaultValueStr: string representation of the default value :return: code string """ # TODO: How to correctly inherit instance attributes? template = '''class Has$Name(Params): """ - Params with $name. + Mixin for param $name: $doc. """ # a placeholder to make it appear in the generated doc - $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + $name = Param(Params._dummy(), "$name", "$doc") def __init__(self): super(Has$Name, self).__init__() #: param for $doc - self.$name = Param(self, "$name", "$doc", $defaultValue) + self.$name = Param(self, "$name", "$doc") + if $defaultValueStr is not None: + self._setDefault($name=$defaultValueStr) def set$Name(self, value): """ @@ -67,32 +72,29 @@ def _gen_param_code(name, doc, defaultValue): """ Gets the value of $name or its default value. """ - if self.$name in self.paramMap: - return self.paramMap[self.$name] - else: - return self.$name.defaultValue''' + return self.getOrDefault(self.$name)''' - upperCamelName = name[0].upper() + name[1:] + Name = name[0].upper() + name[1:] return template \ .replace("$name", name) \ - .replace("$Name", upperCamelName) \ + .replace("$Name", Name) \ .replace("$doc", doc) \ - .replace("$defaultValue", defaultValue) + .replace("$defaultValueStr", str(defaultValueStr)) if __name__ == "__main__": print header - print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" + print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n" print "from pyspark.ml.param import Param, Params\n\n" shared = [ - ("maxIter", "max number of iterations", "100"), - ("regParam", "regularization constant", "0.1"), + ("maxIter", "max number of iterations", None), + ("regParam", "regularization constant", None), ("featuresCol", "features column name", "'features'"), ("labelCol", "label column name", "'label'"), ("predictionCol", "prediction column name", "'prediction'"), - ("inputCol", "input column name", "'input'"), - ("outputCol", "output column name", "'output'"), - ("numFeatures", "number of features", "1 << 18")] + ("inputCol", "input column name", None), + ("outputCol", "output column name", None), + ("numFeatures", "number of features", None)] code = [] - for name, doc, defaultValue in shared: - code.append(_gen_param_code(name, doc, defaultValue)) + for name, doc, defaultValueStr in shared: + code.append(_gen_param_code(name, doc, defaultValueStr)) print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 586822f2de..13b6749998 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -15,23 +15,25 @@ # limitations under the License. # -# DO NOT MODIFY. The code is generated by _gen_shared_params.py. +# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py. from pyspark.ml.param import Param, Params class HasMaxIter(Params): """ - Params with maxIter. + Mixin for param maxIter: max number of iterations. """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations") def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations - self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + self.maxIter = Param(self, "maxIter", "max number of iterations") + if None is not None: + self._setDefault(maxIter=None) def setMaxIter(self, value): """ @@ -44,24 +46,23 @@ class HasMaxIter(Params): """ Gets the value of maxIter or its default value. """ - if self.maxIter in self.paramMap: - return self.paramMap[self.maxIter] - else: - return self.maxIter.defaultValue + return self.getOrDefault(self.maxIter) class HasRegParam(Params): """ - Params with regParam. + Mixin for param regParam: regularization constant. """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + regParam = Param(Params._dummy(), "regParam", "regularization constant") def __init__(self): super(HasRegParam, self).__init__() #: param for regularization constant - self.regParam = Param(self, "regParam", "regularization constant", 0.1) + self.regParam = Param(self, "regParam", "regularization constant") + if None is not None: + self._setDefault(regParam=None) def setRegParam(self, value): """ @@ -74,24 +75,23 @@ class HasRegParam(Params): """ Gets the value of regParam or its default value. """ - if self.regParam in self.paramMap: - return self.paramMap[self.regParam] - else: - return self.regParam.defaultValue + return self.getOrDefault(self.regParam) class HasFeaturesCol(Params): """ - Params with featuresCol. + Mixin for param featuresCol: features column name. """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + featuresCol = Param(Params._dummy(), "featuresCol", "features column name") def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name - self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + self.featuresCol = Param(self, "featuresCol", "features column name") + if 'features' is not None: + self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ @@ -104,24 +104,23 @@ class HasFeaturesCol(Params): """ Gets the value of featuresCol or its default value. """ - if self.featuresCol in self.paramMap: - return self.paramMap[self.featuresCol] - else: - return self.featuresCol.defaultValue + return self.getOrDefault(self.featuresCol) class HasLabelCol(Params): """ - Params with labelCol. + Mixin for param labelCol: label column name. """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + labelCol = Param(Params._dummy(), "labelCol", "label column name") def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name - self.labelCol = Param(self, "labelCol", "label column name", 'label') + self.labelCol = Param(self, "labelCol", "label column name") + if 'label' is not None: + self._setDefault(labelCol='label') def setLabelCol(self, value): """ @@ -134,24 +133,23 @@ class HasLabelCol(Params): """ Gets the value of labelCol or its default value. """ - if self.labelCol in self.paramMap: - return self.paramMap[self.labelCol] - else: - return self.labelCol.defaultValue + return self.getOrDefault(self.labelCol) class HasPredictionCol(Params): """ - Params with predictionCol. + Mixin for param predictionCol: prediction column name. """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name") def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name - self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + self.predictionCol = Param(self, "predictionCol", "prediction column name") + if 'prediction' is not None: + self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ @@ -164,24 +162,23 @@ class HasPredictionCol(Params): """ Gets the value of predictionCol or its default value. """ - if self.predictionCol in self.paramMap: - return self.paramMap[self.predictionCol] - else: - return self.predictionCol.defaultValue + return self.getOrDefault(self.predictionCol) class HasInputCol(Params): """ - Params with inputCol. + Mixin for param inputCol: input column name. """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + inputCol = Param(Params._dummy(), "inputCol", "input column name") def __init__(self): super(HasInputCol, self).__init__() #: param for input column name - self.inputCol = Param(self, "inputCol", "input column name", 'input') + self.inputCol = Param(self, "inputCol", "input column name") + if None is not None: + self._setDefault(inputCol=None) def setInputCol(self, value): """ @@ -194,24 +191,23 @@ class HasInputCol(Params): """ Gets the value of inputCol or its default value. """ - if self.inputCol in self.paramMap: - return self.paramMap[self.inputCol] - else: - return self.inputCol.defaultValue + return self.getOrDefault(self.inputCol) class HasOutputCol(Params): """ - Params with outputCol. + Mixin for param outputCol: output column name. """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + outputCol = Param(Params._dummy(), "outputCol", "output column name") def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name - self.outputCol = Param(self, "outputCol", "output column name", 'output') + self.outputCol = Param(self, "outputCol", "output column name") + if None is not None: + self._setDefault(outputCol=None) def setOutputCol(self, value): """ @@ -224,24 +220,23 @@ class HasOutputCol(Params): """ Gets the value of outputCol or its default value. """ - if self.outputCol in self.paramMap: - return self.paramMap[self.outputCol] - else: - return self.outputCol.defaultValue + return self.getOrDefault(self.outputCol) class HasNumFeatures(Params): """ - Params with numFeatures. + Mixin for param numFeatures: number of features. """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + numFeatures = Param(Params._dummy(), "numFeatures", "number of features") def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features - self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + self.numFeatures = Param(self, "numFeatures", "number of features") + if None is not None: + self._setDefault(numFeatures=None) def setNumFeatures(self, value): """ @@ -254,7 +249,4 @@ class HasNumFeatures(Params): """ Gets the value of numFeatures or its default value. """ - if self.numFeatures in self.paramMap: - return self.paramMap[self.numFeatures] - else: - return self.numFeatures.defaultValue + return self.getOrDefault(self.numFeatures) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 83880a5afc..d94ecfff09 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -124,10 +124,10 @@ class Pipeline(Estimator): Sets params for Pipeline. """ kwargs = self.setParams._input_kwargs - return self._set_params(**kwargs) + return self._set(**kwargs) def fit(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) stages = paramMap[self.stages] for stage in stages: if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): @@ -164,7 +164,7 @@ class PipelineModel(Transformer): self.transformers = transformers def transform(self, dataset, params={}): - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for t in self.transformers: dataset = t.transform(dataset, paramMap) return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index b627c2b4e9..3a42bcf723 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -33,6 +33,7 @@ else: from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame from pyspark.ml.param import Param +from pyspark.ml.param.shared import HasMaxIter, HasInputCol from pyspark.ml.pipeline import Transformer, Estimator, Pipeline @@ -46,7 +47,7 @@ class MockTransformer(Transformer): def __init__(self): super(MockTransformer, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None @@ -62,7 +63,7 @@ class MockEstimator(Estimator): def __init__(self): super(MockEstimator, self).__init__() - self.fake = Param(self, "fake", "fake", None) + self.fake = Param(self, "fake", "fake") self.dataset_index = None self.fake_param_value = None self.model = None @@ -111,5 +112,52 @@ class PipelineTests(PySparkTestCase): self.assertEqual(6, dataset.index) +class TestParams(HasMaxIter, HasInputCol): + """ + A subclass of Params mixed with HasMaxIter and HasInputCol. + """ + + def __init__(self): + super(TestParams, self).__init__() + self._setDefault(maxIter=10) + + +class ParamTests(PySparkTestCase): + + def test_param(self): + testParams = TestParams() + maxIter = testParams.maxIter + self.assertEqual(maxIter.name, "maxIter") + self.assertEqual(maxIter.doc, "max number of iterations") + self.assertTrue(maxIter.parent is testParams) + + def test_params(self): + testParams = TestParams() + maxIter = testParams.maxIter + inputCol = testParams.inputCol + + params = testParams.params + self.assertEqual(params, [inputCol, maxIter]) + + self.assertTrue(testParams.hasDefault(maxIter)) + self.assertFalse(testParams.isSet(maxIter)) + self.assertTrue(testParams.isDefined(maxIter)) + self.assertEqual(testParams.getMaxIter(), 10) + testParams.setMaxIter(100) + self.assertTrue(testParams.isSet(maxIter)) + self.assertEquals(testParams.getMaxIter(), 100) + + self.assertFalse(testParams.hasDefault(inputCol)) + self.assertFalse(testParams.isSet(inputCol)) + self.assertFalse(testParams.isDefined(inputCol)) + with self.assertRaises(KeyError): + testParams.getInputCol() + + self.assertEquals( + testParams.explainParams(), + "\n".join(["inputCol: input column name (undefined)", + "maxIter: max number of iterations (default: 10, current: 100)"])) + + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 6f7f39c40e..d3cb100a9e 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -40,8 +40,8 @@ class Identifiable(object): def __init__(self): #: A unique id for the object. The default implementation - #: concatenates the class name, "-", and 8 random hex chars. - self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + #: concatenates the class name, "_", and 8 random hex chars. + self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8] def __repr__(self): return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 31a66b3d2f..394f23c5e9 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -64,7 +64,7 @@ class JavaWrapper(Params): :param params: additional params (overwriting embedded values) :param java_obj: Java object to receive the params """ - paramMap = self._merge_params(params) + paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: java_obj.set(param.name, paramMap[param]) -- GitLab